From 009a4ed8eae03ff438ffc9f32be9bb7ac380c59a Mon Sep 17 00:00:00 2001
From: Diptorup Deb
Date: Sat, 2 Aug 2025 12:22:06 -0400
Subject: [PATCH 001/109] Initial stub for prefill.cuh

---
 .../flashinfer/attention/generic/prefill.cuh | 3159 +++++++++++++++++
 1 file changed, 3159 insertions(+)
 create mode 100644 libflashinfer/include/flashinfer/attention/generic/prefill.cuh

diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh
new file mode 100644
index 0000000000..578484221c
--- /dev/null
+++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh
@@ -0,0 +1,3159 @@
// SPDX-FileCopyrightText: 2023-2025 FlashInfer team.
// SPDX-FileCopyrightText: 2025 Advanced Micro Devices, Inc.
//
// SPDX-License-Identifier: Apache-2.0
#ifndef FLASHINFER_PREFILL_CUH_
#define FLASHINFER_PREFILL_CUH_

#include "gpu_iface/cooperative_groups.h"
#include "gpu_iface/fastdiv.cuh"
#include "gpu_iface/math_ops.hpp"
#include "gpu_iface/memory_ops.hpp"
#include "gpu_iface/mma_ops.hpp"
#include "gpu_iface/platform.hpp"

#ifdef FP16_QK_REDUCTION_SUPPORTED
#include "../../fp16.h"
#endif
#include "frag_layout_swizzle.cuh"

#include "cascade.cuh"
#include "dispatch.cuh"
#include "page.cuh"
#include "permuted_smem.cuh"
#include "pos_enc.cuh"
#include "utils.cuh"
#include "variants.cuh"

namespace flashinfer
{

DEFINE_HAS_MEMBER(maybe_q_rope_offset)
DEFINE_HAS_MEMBER(maybe_k_rope_offset)

namespace cg = flashinfer::gpu_iface::cg;
namespace memory = flashinfer::gpu_iface::memory;
namespace mma = gpu_iface::mma;

using gpu_iface::vec_dtypes::vec_cast;
using mma::MMAMode;

constexpr uint32_t WARP_SIZE = gpu_iface::kWarpSize;
// NOTE(review): warp fragment step differs per vendor — 16 lanes per fragment
// row on AMD/CDNA3 (wave64) vs 8 on NVIDIA. Confirm against mma_ops.hpp.
#if defined(PLATFORM_HIP_DEVICE)
constexpr uint32_t WARP_STEP_SIZE = 16;
#else
constexpr uint32_t WARP_STEP_SIZE = 8; // NVIDIA
#endif

// Number of warps along the Q (query tile) dimension for a given CTA Q-tile.
constexpr uint32_t get_num_warps_q(const uint32_t cta_tile_q)
{
    if (cta_tile_q > 16) {
        return 4;
    }
    else {
        return 1;
    }
}

// Number of warps along the KV dimension; the CTA always uses 4 warps total.
// NOTE(review): the KV tile size is forwarded to get_num_warps_q, which
// partitions by the *Q* tile size — confirm this is intentional.
constexpr uint32_t get_num_warps_kv(const uint32_t cta_tile_kv)
{
    return 4 / get_num_warps_q(cta_tile_kv);
}

// Number of MMA tiles along the Q dimension per warp for a given CTA Q-tile.
constexpr uint32_t get_num_mma_q(const uint32_t cta_tile_q)
{
    if (cta_tile_q > 64) {
        return 2;
    }
    else {
        return 1;
    }
}

// Shared-memory layout for the Q/K/V staging buffers and the output buffer.
// The three union members alias the same storage: (1) q/k/v tiles during the
// main loop, (2) cross-warp merge buffers (only sized when NUM_WARPS_KV > 1),
// (3) the output tile during the epilogue.
// NOTE(review): template parameter list lost in patch extraction — restore
// from the original commit before use.
template
struct SharedStorageQKVO
{
    union
    {
        struct
        {
            alignas(16) DTypeQ q_smem[CTA_TILE_Q * HEAD_DIM_QK];
            alignas(16) DTypeKV k_smem[CTA_TILE_KV * HEAD_DIM_QK];
            alignas(16) DTypeKV v_smem[CTA_TILE_KV * HEAD_DIM_VO];
        };
        struct
        { // NOTE(Zihao): synchronize attention states across warps
            alignas(16) std::conditional_t<
                NUM_WARPS_KV == 1,
                float[1],
                float[NUM_WARPS_KV * CTA_TILE_Q * HEAD_DIM_VO]> cta_sync_o_smem;
            alignas(16) std::conditional_t<
                NUM_WARPS_KV == 1,
                float2[1],
                float2[NUM_WARPS_KV * CTA_TILE_Q]> cta_sync_md_smem;
        };
        alignas(16) DTypeO smem_o[CTA_TILE_Q * HEAD_DIM_VO];
    };
};

// Compile-time kernel configuration: tile shapes, warp layout, swizzle modes,
// derived strides, and element types for the prefill attention kernel.
// NOTE(review): template parameter list lost in patch extraction — restore
// from the original commit before use.
template
struct KernelTraits
{
    static constexpr MaskMode MASK_MODE = MASK_MODE_;
    static constexpr uint32_t NUM_MMA_Q = NUM_MMA_Q_;
    static constexpr uint32_t NUM_MMA_KV = NUM_MMA_KV_;
    static constexpr uint32_t NUM_MMA_D_QK = NUM_MMA_D_QK_;
    static constexpr uint32_t NUM_MMA_D_VO = NUM_MMA_D_VO_;
    static constexpr uint32_t NUM_WARPS_Q = NUM_WARPS_Q_;
    static constexpr uint32_t NUM_WARPS_KV = NUM_WARPS_KV_;
    static constexpr uint32_t NUM_THREADS =
        NUM_WARPS_Q * NUM_WARPS_KV * WARP_SIZE;
    static constexpr uint32_t NUM_WARPS = NUM_WARPS_Q * NUM_WARPS_KV;
    // Each MMA tile covers 16 elements of the head dimension.
    static constexpr uint32_t HEAD_DIM_QK = NUM_MMA_D_QK * 16;
    static constexpr uint32_t HEAD_DIM_VO = NUM_MMA_D_VO * 16;
    static constexpr uint32_t UPCAST_STRIDE_Q =
        HEAD_DIM_QK / upcast_size();
    static constexpr uint32_t UPCAST_STRIDE_K =
        HEAD_DIM_QK / upcast_size();
    static constexpr uint32_t UPCAST_STRIDE_V =
        HEAD_DIM_VO / upcast_size();
    static constexpr uint32_t UPCAST_STRIDE_O =
        HEAD_DIM_VO / upcast_size();
    static constexpr uint32_t CTA_TILE_Q = CTA_TILE_Q_;
    static constexpr uint32_t CTA_TILE_KV = NUM_MMA_KV * NUM_WARPS_KV * 16;

    static constexpr SwizzleMode SWIZZLE_MODE_Q = SwizzleMode::k128B;
    // fp8 KV with head_dim 64 falls back to the narrower 64B swizzle.
    static constexpr SwizzleMode SWIZZLE_MODE_KV =
        (sizeof(DTypeKV_) == 1 && HEAD_DIM_VO == 64) ? SwizzleMode::k64B
                                                     : SwizzleMode::k128B;
    static constexpr uint32_t KV_THR_LAYOUT_ROW =
        SWIZZLE_MODE_KV == SwizzleMode::k128B ? 4 : 8;
    static constexpr uint32_t KV_THR_LAYOUT_COL =
        SWIZZLE_MODE_KV == SwizzleMode::k128B ? 8 : 4;
    static constexpr PosEncodingMode POS_ENCODING_MODE = POS_ENCODING_MODE_;
    using DTypeQ = DTypeQ_;
    using DTypeKV = DTypeKV_;
    using DTypeO = DTypeO_;
    using DTypeQKAccum = DTypeQKAccum_;
    using IdType = IdType_;
    using AttentionVariant = AttentionVariant_;

    // Rejects tile/type combinations the kernel cannot support (register
    // budget, RoPE layout constraints, fp8 restrictions).
    static constexpr bool IsInvalid()
    {
        return ((NUM_MMA_D_VO < 4) ||
                (NUM_MMA_D_VO == 4 && NUM_MMA_KV % 2 == 1) ||
                (POS_ENCODING_MODE == PosEncodingMode::kRoPELlama &&
                 NUM_MMA_D_VO > 4 && NUM_MMA_D_VO % (2 * NUM_WARPS_Q) != 0) ||
                (NUM_MMA_Q * (8 * NUM_MMA_D_VO +
                              2 * sizeof(DTypeQKAccum) * NUM_MMA_KV) >=
                 256) ||
                (sizeof(DTypeKV) == 1 && NUM_MMA_KV * 2 % NUM_WARPS_Q != 0) ||
                (sizeof(DTypeKV) == 1 &&
                 POS_ENCODING_MODE == PosEncodingMode::kRoPELlama));
    }

    using SharedStorage = SharedStorageQKVO;
#ifdef FP16_QK_REDUCTION_SUPPORTED
    // Value used to fill masked-out logits: -inf under softmax (so exp() is
    // zero), 0 otherwise. The fp16 path builds -inf via a bit cast.
    template static constexpr DT getNegInf()
    {
        if constexpr (std::is_same::value) {
            return std::bit_cast(
                fp16_ieee_from_fp32_value(-gpu_iface::math::inf));
        }
        else {
            return static_cast(-gpu_iface::math::inf);
        }
    }

    static constexpr DTypeQKAccum MaskFillValue =
        AttentionVariant::use_softmax ? getNegInf()
                                      : DTypeQKAccum(0.f);
#else
    static_assert(!std::is_same::value,
                  "Set -DFP16_QK_REDUCTION_SUPPORTED and install boost_math "
                  "then recompile to support fp16 reduction");
    static constexpr DTypeQKAccum MaskFillValue =
        AttentionVariant::use_softmax ? DTypeQKAccum(-gpu_iface::math::inf)
                                      : DTypeQKAccum(0.f);
#endif
};

namespace
{

// Warp index along the Q dimension (0 when there is a single Q warp).
template
__device__ __forceinline__ uint32_t
get_warp_idx_q(const uint32_t tid_y = threadIdx.y)
{
    if constexpr (KTraits::NUM_WARPS_Q == 1) {
        return 0;
    }
    else {
        return tid_y;
    }
}

// Warp index along the KV dimension (0 when there is a single KV warp).
template
__device__ __forceinline__ uint32_t
get_warp_idx_kv(const uint32_t tid_z = threadIdx.z)
{
    if constexpr (KTraits::NUM_WARPS_KV == 1) {
        return 0;
    }
    else {
        return tid_z;
    }
}

// Flattened warp index: KV-major over the (NUM_WARPS_KV x NUM_WARPS_Q) grid.
template
__device__ __forceinline__ uint32_t
get_warp_idx(const uint32_t tid_y = threadIdx.y,
             const uint32_t tid_z = threadIdx.z)
{
    return get_warp_idx_kv(tid_z) * KTraits::NUM_WARPS_Q +
           get_warp_idx_q(tid_y);
}

/*!
 * \brief Apply Llama style rotary embedding to two 16x16 fragments.
 * \tparam T The data type of the input fragments.
 * \param x_first_half First fragment x[offset:offset+16, j*16:(j+1)*16]
 * \param x_second_half Second fragment x[offset:offset*16,
 * j*16+d/2:(j+1)*16+d/2]
 * \param rope_freq Rope frequency
 * \param offset The offset of the first row in both fragments.
 * \note The sin/cos computation is slow, especially for A100 GPUs which has low
 * non tensor-ops flops, will optimize in the future.
 */
template
__device__ __forceinline__ void
k_frag_apply_llama_rope(T *x_first_half,
                        T *x_second_half,
                        const float *rope_freq,
                        const uint32_t kv_offset)
{
    static_assert(sizeof(T) == 2);
#pragma unroll
    for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) {
        float cos, sin, tmp;
        // K-fragment register layout:
        // 0 1 | 2 3
        // ---------
        // 4 5 | 6 7
        uint32_t i = reg_id / 4, j = (reg_id % 4) / 2;
        __sincosf(float(kv_offset + 8 * i) * rope_freq[2 * j + reg_id % 2],
                  &sin, &cos);
        // Standard 2D rotation of the (first_half, second_half) pair.
        tmp = x_first_half[reg_id];
        x_first_half[reg_id] = (tmp * cos - (float)x_second_half[reg_id] * sin);
        x_second_half[reg_id] =
            ((float)x_second_half[reg_id] * cos + tmp * sin);
    }
}

// Same rotation as k_frag_apply_llama_rope, but for Q fragments: the packed
// query offset is first divided by group_size (GQA) to recover the position.
template
__device__ __forceinline__ void
q_frag_apply_llama_rope(T *x_first_half,
                        T *x_second_half,
                        const float *rope_freq,
                        const uint32_t qo_packed_offset,
                        const uint_fastdiv group_size)
{
#pragma unroll
    for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) {
        float cos, sin, tmp;
        // Q-fragment register layout:
        // 0 1 | 4 5
        // ---------
        // 2 3 | 6 7
        uint32_t i = ((reg_id % 4) / 2), j = (reg_id / 4);
        __sincosf(float((qo_packed_offset + 8 * i) / group_size) *
                      rope_freq[2 * j + reg_id % 2],
                  &sin, &cos);
        tmp = x_first_half[reg_id];
        x_first_half[reg_id] = (tmp * cos - (float)x_second_half[reg_id] * sin);
        x_second_half[reg_id] =
            ((float)x_second_half[reg_id] * cos + tmp * sin);
    }
}

// Q-fragment RoPE with explicit per-request position offsets (q_rope_offset)
// instead of positions derived from the packed index alone.
template
__device__ __forceinline__ void
q_frag_apply_llama_rope_with_pos(T *x_first_half,
                                 T *x_second_half,
                                 const float *rope_freq,
                                 const uint32_t qo_packed_offset,
                                 const uint_fastdiv group_size,
                                 const IdType *q_rope_offset)
{
    float pos[2] = {
        static_cast(q_rope_offset[qo_packed_offset / group_size]),
        static_cast(q_rope_offset[(qo_packed_offset + 8) / group_size])};
#pragma unroll
    for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) {
        float cos, sin, tmp;
        // 0 1 | 4 5
        // ---------
        // 2 3 | 6 7
        uint32_t i = ((reg_id % 4) / 2), j = (reg_id / 4);
        __sincosf(pos[i] * rope_freq[2 * j + reg_id % 2], &sin, &cos);
        tmp = x_first_half[reg_id];
        x_first_half[reg_id] = (tmp * cos - (float)x_second_half[reg_id] * sin);
        x_second_half[reg_id] =
            ((float)x_second_half[reg_id] * cos + tmp * sin);
    }
}

/*!
 * \brief Produce k/v fragments from global memory to shared memory.
 * \tparam fill_mode The fill mode of the shared memory.
 * \tparam NUM_MMA_D_VO The number of fragments in y dimension.
 * \tparam NUM_MMA_KV The number of fragments in z dimension.
 * \tparam num_warps The number of warps in the threadblock.
 * \tparam T The data type of the input tensor.
 * \param smem The shared memory to store kv fragments.
 * \param gptr The global memory pointer.
 * \param kv_idx_base The base kv index.
 * \param kv_len The length of kv tensor.
 */
template
__device__ __forceinline__ void
produce_kv(smem_t smem,
           uint32_t *smem_offset,
           typename KTraits::DTypeKV **gptr,
           const uint32_t stride_n,
           const uint32_t kv_idx_base,
           const uint32_t kv_len,
           const dim3 tid = threadIdx)
{
    // NOTE: for fp8, this function doesn't work for head_dim = 64 at the moment
    using DTypeKV = typename KTraits::DTypeKV;
    constexpr uint32_t CTA_TILE_KV = KTraits::CTA_TILE_KV;
    constexpr uint32_t NUM_WARPS = KTraits::NUM_WARPS;
    constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q;
    constexpr uint32_t NUM_MMA_D =
        produce_v ? KTraits::NUM_MMA_D_VO : KTraits::NUM_MMA_D_QK;
    constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV;
    constexpr uint32_t UPCAST_STRIDE =
        produce_v ? KTraits::UPCAST_STRIDE_V : KTraits::UPCAST_STRIDE_K;
    const uint32_t warp_idx = get_warp_idx(tid.y, tid.z),
                   lane_idx = tid.x;

    if constexpr (KTraits::SWIZZLE_MODE_KV == SwizzleMode::k128B) {
        uint32_t kv_idx =
            kv_idx_base + warp_idx * 4 + lane_idx / WARP_STEP_SIZE;
        // NOTE: NUM_MMA_KV * 4 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 4 /
        // num_warps
        static_assert(NUM_MMA_KV * 4 % NUM_WARPS_Q == 0);
#pragma unroll
        for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) {
#pragma unroll
            for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(DTypeKV)); ++j) {
                // Async 128-bit copy; predicated off beyond kv_len.
                smem.load_128b_async(*smem_offset, *gptr,
                                     kv_idx < kv_len);
                *smem_offset =
                    smem.template advance_offset_by_column(
                        *smem_offset, j);
                *gptr += WARP_STEP_SIZE * upcast_size();
            }
            kv_idx += NUM_WARPS * 4;
            // Step to the next row group and rewind the column offset.
            *smem_offset = smem.template advance_offset_by_row(
                               *smem_offset) -
                           sizeof(DTypeKV) * NUM_MMA_D;
            *gptr += NUM_WARPS * 4 * stride_n -
                     sizeof(DTypeKV) * NUM_MMA_D * upcast_size();
        }
        *smem_offset -= CTA_TILE_KV * UPCAST_STRIDE;
    }
    else {
#if defined(PLATFORM_HIP_DEVICE)
        static_assert(false,
                      "SwizzleMode::k64B is not supported on AMD/CDNA3.");
#else
        uint32_t kv_idx = kv_idx_base + warp_idx * 8 + lane_idx / 4;
        // NOTE: NUM_MMA_KV * 2 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 2 /
        // num_warps
        static_assert(NUM_MMA_KV * 2 % NUM_WARPS_Q == 0);
#pragma unroll
        for (uint32_t i = 0; i < NUM_MMA_KV * 2 / NUM_WARPS_Q; ++i) {
            smem.load_128b_async(*smem_offset, *gptr,
                                 kv_idx < kv_len);
            *smem_offset = smem.template advance_offset_by_row(
                *smem_offset);
            kv_idx += NUM_WARPS * 8;
            *gptr += NUM_WARPS * 8 * stride_n;
        }
        *smem_offset -= KTraits::CTA_TILE_KV * UPCAST_STRIDE;
#endif
    }
}

// Paged-KV variant of produce_kv: gathers K or V rows from the paged cache
// using precomputed per-thread offsets (thr_local_kv_offset) instead of a
// contiguous global pointer.
template
__device__ __forceinline__ void
page_produce_kv(smem_t smem,
                uint32_t *smem_offset,
                const paged_kv_t &paged_kv,
                const uint32_t kv_idx_base,
                const size_t *thr_local_kv_offset,
                const uint32_t kv_len,
                const dim3 tid = threadIdx)
{
    // NOTE: for fp8, this function doesn't work for head_dim = 64 at the moment
    using DType = typename KTraits::DTypeKV;
    using IdType = typename KTraits::IdType;
    // V is zero-filled out of range (it feeds a matmul accumulate); K is not.
    constexpr SharedMemFillMode fill_mode =
        produce_v ? SharedMemFillMode::kFillZero : SharedMemFillMode::kNoFill;
    constexpr uint32_t NUM_WARPS = KTraits::NUM_WARPS;
    constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q;
    constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV;
    constexpr uint32_t NUM_MMA_D =
        produce_v ? KTraits::NUM_MMA_D_VO : KTraits::NUM_MMA_D_QK;
    constexpr uint32_t UPCAST_STRIDE =
        produce_v ? KTraits::UPCAST_STRIDE_V : KTraits::UPCAST_STRIDE_K;
    const uint32_t warp_idx = get_warp_idx(tid.y, tid.z),
                   lane_idx = tid.x;
    if constexpr (KTraits::SWIZZLE_MODE_KV == SwizzleMode::k128B) {
        uint32_t kv_idx = kv_idx_base + warp_idx * 4 + lane_idx / 8;
        // NOTE: NUM_MMA_KV * 4 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 4 /
        // num_warps
        static_assert(NUM_MMA_KV * 4 % NUM_WARPS_Q == 0);
#pragma unroll
        for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) {
            DType *gptr = produce_v ? paged_kv.v_data + thr_local_kv_offset[i]
                                    : paged_kv.k_data + thr_local_kv_offset[i];
#pragma unroll
            for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(DType)); ++j) {
                smem.load_128b_async(*smem_offset, gptr,
                                     kv_idx < kv_len);
                *smem_offset =
                    smem.template advance_offset_by_column<8>(*smem_offset, j);
                gptr += 8 * upcast_size();
            }
            kv_idx += NUM_WARPS * 4;
            *smem_offset = smem.template advance_offset_by_row(
                               *smem_offset) -
                           sizeof(DType) * NUM_MMA_D;
        }
        *smem_offset -= KTraits::CTA_TILE_KV * UPCAST_STRIDE;
    }
    else {
        uint32_t kv_idx = kv_idx_base + warp_idx * 8 + lane_idx / 4;
        // NOTE: NUM_MMA_KV * 2 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 2 /
        // num_warps
        static_assert(NUM_MMA_KV * 2 % NUM_WARPS_Q == 0);
#pragma unroll
        for (uint32_t i = 0; i < NUM_MMA_KV * 2 / NUM_WARPS_Q; ++i) {
            DType *gptr = produce_v ? paged_kv.v_data + thr_local_kv_offset[i]
                                    : paged_kv.k_data + thr_local_kv_offset[i];
            smem.load_128b_async(*smem_offset, gptr,
                                 kv_idx < kv_len);
            kv_idx += NUM_WARPS * 8;
            *smem_offset = smem.template advance_offset_by_row(
                *smem_offset);
        }
        *smem_offset -= KTraits::CTA_TILE_KV * UPCAST_STRIDE;
    }
}

// Precompute per-lane RoPE frequencies: rope_rcp_scale * theta^(-2d/HEAD_DIM)
// for the head-dimension slots this lane owns in each MMA fragment.
template
__device__ __forceinline__ void
init_rope_freq(float (*rope_freq)[4],
               const float rope_rcp_scale,
               const float rope_rcp_theta,
               const uint32_t tid_x = threadIdx.x)
{
    constexpr uint32_t HEAD_DIM = KTraits::NUM_MMA_D_QK * 16;
    const uint32_t lane_idx = tid_x;
#pragma unroll
    for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO / 2; ++mma_d) {
#pragma unroll
        for (uint32_t j = 0; j < 4; ++j) {
            rope_freq[mma_d][j] =
                rope_rcp_scale *
                __powf(rope_rcp_theta,
                       float(2 * ((mma_d * 16 + (j / 2) * 8 +
                                   (lane_idx % 4) * 2 + (j % 2)) %
                                  (HEAD_DIM / 2))) /
                       float(HEAD_DIM));
        }
    }
}

// Zero the output accumulator fragments; when the variant uses softmax, also
// reset the running row max (m = -inf) and denominator (d = 1).
template
__device__ __forceinline__ void
init_states(typename KTraits::AttentionVariant variant,
            float (*o_frag)[KTraits::NUM_MMA_D_VO][8],
            typename KTraits::DTypeQKAccum (*m)[2],
            float (*d)[2])
{
#pragma unroll
    for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
#pragma unroll
        for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) {
#pragma unroll
            for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) {
                o_frag[mma_q][mma_d][reg_id] = 0.f;
            }
        }
    }

    if constexpr (variant.use_softmax) {
#pragma unroll
        for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
#pragma unroll
            for (uint32_t j = 0; j < 2; ++j) {
                m[mma_q][j] =
                    typename KTraits::DTypeQKAccum(-gpu_iface::math::inf);
                d[mma_q][j] = 1.f;
            }
        }
    }
}

// Asynchronously load this CTA's Q tile from global memory into shared memory.
// Only the first KV warp issues loads; rows past qo_upper_bound are predicated
// off. packed_offset encodes (query position * group_size + head-in-group).
template
__device__ __forceinline__ void
load_q_global_smem(uint32_t packed_offset,
                   const uint32_t qo_upper_bound,
                   typename KTraits::DTypeQ *q_ptr_base,
                   const uint32_t q_stride_n,
                   const uint32_t q_stride_h,
                   const uint_fastdiv group_size,
                   smem_t *q_smem,
                   const dim3 tid = threadIdx)
{
    using DTypeQ = typename KTraits::DTypeQ;
    constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q;
    const uint32_t lane_idx = tid.x,
                   warp_idx_x = get_warp_idx_q(tid.y);

    if (get_warp_idx_kv(tid.z) == 0) {
        uint32_t row = warp_idx_x * KTraits::NUM_MMA_Q * WARP_STEP_SIZE +
                       lane_idx / WARP_STEP_SIZE;
        uint32_t col = lane_idx % WARP_STEP_SIZE;
        uint32_t q_smem_offset_w =
            q_smem->get_permuted_offset(row, col);

#pragma unroll
        for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
#pragma unroll
            for (uint32_t j = 0; j < 2 * 2; ++j) {
                uint32_t q, r;
                // Split packed index into query position q and head-in-group r.
                group_size.divmod(packed_offset + lane_idx / WARP_STEP_SIZE +
                                      mma_q * 16 + j * 4,
                                  q, r);
                const uint32_t q_idx = q;
                DTypeQ *q_ptr =
                    q_ptr_base + q * q_stride_n + r * q_stride_h +
                    (lane_idx % WARP_STEP_SIZE) * upcast_size();
#pragma unroll
                for (uint32_t mma_do = 0; mma_do < KTraits::NUM_MMA_D_QK / 4;
                     ++mma_do)
                {
                    // load q fragment from gmem to smem
                    q_smem->load_128b_async(
                        q_smem_offset_w, q_ptr, q_idx < qo_upper_bound);
                    q_smem_offset_w =
                        q_smem
                            ->template advance_offset_by_column(
                                q_smem_offset_w, mma_do);
                    q_ptr += WARP_STEP_SIZE * upcast_size();
                }
                q_smem_offset_w =
                    q_smem->template advance_offset_by_row<4, UPCAST_STRIDE_Q>(
                        q_smem_offset_w) -
                    2 * KTraits::NUM_MMA_D_QK;
            }
        }
    }
}

// Apply Llama RoPE to the Q tile in shared memory in place, pairing each
// column block with the block HEAD_DIM/2 away. Positions are derived from the
// packed query index (kv_len - qo_len alignment for causal decoding).
template
__device__ __forceinline__ void
q_smem_inplace_apply_rotary(const uint32_t q_packed_idx,
                            const uint32_t qo_len,
                            const uint32_t kv_len,
                            const uint_fastdiv group_size,
                            smem_t *q_smem,
                            uint32_t *q_smem_offset_r,
                            float (*rope_freq)[4],
                            const dim3 tid = threadIdx)
{
    if (get_warp_idx_kv(tid.z) == 0) {
        constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q;
        const uint32_t lane_idx = tid.x;
        uint32_t q_frag_local[2][4];
        static_assert(KTraits::NUM_MMA_D_QK % 4 == 0,
                      "NUM_MMA_D_QK must be a multiple of 4");
#pragma unroll
        for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
            uint32_t q_smem_offset_r_first_half = *q_smem_offset_r;
#pragma unroll
            for (uint32_t mma_di = 0; mma_di < KTraits::NUM_MMA_D_QK / 2;
                 ++mma_di)
            {
                q_smem->ldmatrix_m8n8x4(q_smem_offset_r_first_half,
                                        q_frag_local[0]);
                uint32_t q_smem_offset_r_last_half =
                    q_smem->template advance_offset_by_column<
                        KTraits::NUM_MMA_D_QK>(q_smem_offset_r_first_half, 0);
                q_smem->ldmatrix_m8n8x4(q_smem_offset_r_last_half,
                                        q_frag_local[1]);
                q_frag_apply_llama_rope(
                    (typename KTraits::DTypeQ *)q_frag_local[0],
                    (typename KTraits::DTypeQ *)q_frag_local[1],
                    rope_freq[mma_di],
                    q_packed_idx + kv_len * group_size - qo_len * group_size +
                        mma_q * 16 + lane_idx / 4,
                    group_size);
                q_smem->stmatrix_m8n8x4(q_smem_offset_r_last_half,
                                        q_frag_local[1]);
                q_smem->stmatrix_m8n8x4(q_smem_offset_r_first_half,
                                        q_frag_local[0]);
                q_smem_offset_r_first_half =
                    q_smem->template advance_offset_by_column<2>(
                        q_smem_offset_r_first_half, mma_di);
            }
            *q_smem_offset_r += 16 * UPCAST_STRIDE_Q;
        }
        *q_smem_offset_r -= KTraits::NUM_MMA_Q * 16 * UPCAST_STRIDE_Q;
    }
}

// As q_smem_inplace_apply_rotary, but positions come from an explicit
// q_rope_offset table rather than being derived from qo_len/kv_len.
template
__device__ __forceinline__ void q_smem_inplace_apply_rotary_with_pos(
    const uint32_t q_packed_idx_base,
    const typename KTraits::IdType *q_rope_offset,
    smem_t *q_smem,
    const uint_fastdiv group_size,
    uint32_t *q_smem_offset_r,
    float (*rope_freq)[4],
    const dim3 tid = threadIdx)
{
    if (get_warp_idx_kv(tid.z) == 0) {
        constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q;
        const uint32_t lane_idx = tid.x;
        uint32_t q_frag_local[2][4];
        static_assert(KTraits::NUM_MMA_D_QK % 4 == 0,
                      "NUM_MMA_D_QK must be a multiple of 4");
#pragma unroll
        for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
            uint32_t q_smem_offset_r_first_half = *q_smem_offset_r;
#pragma unroll
            for (uint32_t mma_di = 0; mma_di < KTraits::NUM_MMA_D_QK / 2;
                 ++mma_di)
            {
                q_smem->ldmatrix_m8n8x4(q_smem_offset_r_first_half,
                                        q_frag_local[0]);
                uint32_t q_smem_offset_r_last_half =
                    q_smem->template advance_offset_by_column<
                        KTraits::NUM_MMA_D_QK>(q_smem_offset_r_first_half, 0);
                q_smem->ldmatrix_m8n8x4(q_smem_offset_r_last_half,
                                        q_frag_local[1]);
                q_frag_apply_llama_rope_with_pos(
                    (typename KTraits::DTypeQ *)q_frag_local[0],
                    (typename KTraits::DTypeQ *)q_frag_local[1],
                    rope_freq[mma_di],
                    q_packed_idx_base + mma_q * 16 + lane_idx / 4, group_size,
                    q_rope_offset);
                q_smem->stmatrix_m8n8x4(q_smem_offset_r_last_half,
                                        q_frag_local[1]);
                q_smem->stmatrix_m8n8x4(q_smem_offset_r_first_half,
                                        q_frag_local[0]);
                q_smem_offset_r_first_half =
                    q_smem->template advance_offset_by_column<2>(
                        q_smem_offset_r_first_half, mma_di);
            }
            *q_smem_offset_r += 16 * UPCAST_STRIDE_Q;
        }
        *q_smem_offset_r -= KTraits::NUM_MMA_Q * 16 * UPCAST_STRIDE_Q;
    }
}

// Apply Llama RoPE to the K tile in shared memory in place. Two layouts: a
// special case for NUM_MMA_D_QK == 4 with 4 Q-warps, and the general case
// where warps stripe across the head dimension.
template
__device__ __forceinline__ void
k_smem_inplace_apply_rotary(const uint32_t kv_idx_base,
                            smem_t *k_smem,
                            uint32_t *k_smem_offset_r,
                            float (*rope_freq)[4],
                            const dim3 tid = threadIdx)
{
    using DTypeKV = typename KTraits::DTypeKV;
    static_assert(sizeof(DTypeKV) == 2);
    constexpr uint32_t UPCAST_STRIDE_K = KTraits::UPCAST_STRIDE_K;
    uint32_t k_frag_local[2][4];
    const uint32_t lane_idx = tid.x;
    if constexpr (KTraits::NUM_MMA_D_QK == 4 && KTraits::NUM_WARPS_Q == 4) {
        static_assert(KTraits::NUM_WARPS_KV == 1);
        const uint32_t warp_idx = get_warp_idx_q(tid.y);
        // horizontal-axis: y
        // vertical-axis: z
        //         | 1-16       | 16-32      | 32-48      | 48-64      |
        // | 1-16  | warp_idx=0 | warp_idx=1 | warp_idx=0 | warp_idx=1 |
        // | 16-32 | warp_idx=2 | warp_idx=3 | warp_idx=2 | warp_idx=3 |
        static_assert(
            KTraits::NUM_MMA_KV % 2 == 0,
            "when NUM_MMA_D_QK == 4, NUM_MMA_KV must be a multiple of 2");
        uint32_t kv_idx = kv_idx_base + (warp_idx / 2) * 16 + lane_idx / 4;
        *k_smem_offset_r = (*k_smem_offset_r ^ (0x2 * (warp_idx % 2))) +
                           (warp_idx / 2) * 16 * UPCAST_STRIDE_K;
#pragma unroll
        for (uint32_t i = 0; i < KTraits::NUM_MMA_KV / 2; ++i) {
            uint32_t k_smem_offset_r_first_half = *k_smem_offset_r;
            uint32_t mma_di = (warp_idx % 2);
            k_smem->ldmatrix_m8n8x4(k_smem_offset_r_first_half,
                                    k_frag_local[0]);
            uint32_t k_smem_offset_r_last_half =
                k_smem->template advance_offset_by_column<4>(
                    k_smem_offset_r_first_half, 0);
            k_smem->ldmatrix_m8n8x4(k_smem_offset_r_last_half, k_frag_local[1]);
            k_frag_apply_llama_rope((DTypeKV *)k_frag_local[0],
                                    (DTypeKV *)k_frag_local[1],
                                    rope_freq[mma_di], kv_idx);
            k_smem->stmatrix_m8n8x4(k_smem_offset_r_last_half, k_frag_local[1]);
            k_smem->stmatrix_m8n8x4(k_smem_offset_r_first_half,
                                    k_frag_local[0]);
            *k_smem_offset_r += 32 * UPCAST_STRIDE_K;
            kv_idx += 32;
        }
        *k_smem_offset_r =
            (*k_smem_offset_r ^ (0x2 * (warp_idx % 2))) -
            ((warp_idx / 2) + KTraits::NUM_MMA_KV) * 16 * UPCAST_STRIDE_K;
    }
    else {
        const uint32_t warp_idx_x = get_warp_idx_q(tid.y),
                       warp_idx_z = get_warp_idx_kv(tid.z);
        static_assert(KTraits::NUM_MMA_D_QK % (2 * KTraits::NUM_WARPS_Q) == 0);
        // horizontal axis: y
        // vertical axis: z
        // | (warp_idx_z, warp_idx_x) | 1-16   | 16-32  | 32-48  | 48-64
        // | ... | 1-16*NUM_MMA_KV | (0, 0) | (0, 1) | (0, 2) |
        // (0, 3) | ... | 16*NUM_MMA_KV-32*NUM_MMA_KV | (1, 0) | (1, 1) | (1,
        // 2) | (1, 3) | ...
        // ...
        uint32_t kv_idx = kv_idx_base +
                          (warp_idx_z * KTraits::NUM_MMA_KV * 16) +
                          lane_idx / 4;
        *k_smem_offset_r = *k_smem_offset_r ^ (0x2 * warp_idx_x);
#pragma unroll
        for (uint32_t i = 0; i < KTraits::NUM_MMA_KV; ++i) {
            uint32_t k_smem_offset_r_first_half = *k_smem_offset_r;
#pragma unroll
            for (uint32_t j = 0;
                 j < KTraits::NUM_MMA_D_QK / (2 * KTraits::NUM_WARPS_Q); ++j)
            {
                uint32_t mma_di = warp_idx_x + j * KTraits::NUM_WARPS_Q;
                k_smem->ldmatrix_m8n8x4(k_smem_offset_r_first_half,
                                        k_frag_local[0]);
                uint32_t k_smem_offset_r_last_half =
                    k_smem->template advance_offset_by_column<
                        KTraits::NUM_MMA_D_QK>(k_smem_offset_r_first_half, 0);
                k_smem->ldmatrix_m8n8x4(k_smem_offset_r_last_half,
                                        k_frag_local[1]);
                k_frag_apply_llama_rope((DTypeKV *)k_frag_local[0],
                                        (DTypeKV *)k_frag_local[1],
                                        rope_freq[mma_di], kv_idx);
                k_smem->stmatrix_m8n8x4(k_smem_offset_r_last_half,
                                        k_frag_local[1]);
                k_smem->stmatrix_m8n8x4(k_smem_offset_r_first_half,
                                        k_frag_local[0]);
                k_smem_offset_r_first_half =
                    k_smem->template advance_offset_by_column<
                        2 * KTraits::NUM_WARPS_Q>(k_smem_offset_r_first_half,
                                                  mma_di);
            }
            *k_smem_offset_r += 16 * UPCAST_STRIDE_K;
            kv_idx += 16;
        }
        *k_smem_offset_r = (*k_smem_offset_r ^ (0x2 * warp_idx_x)) -
                           KTraits::NUM_MMA_KV * 16 * UPCAST_STRIDE_K;
    }
}

// Compute the attention-score fragments s = Q @ K^T for the current tiles,
// iterating over the QK head dimension and accumulating via tensor-core MMA.
// fp8 KV is upcast to the Q dtype in registers before the MMA.
template
__device__ __forceinline__ void
compute_qk(smem_t *q_smem,
           uint32_t *q_smem_offset_r,
           smem_t *k_smem,
           uint32_t *k_smem_offset_r,
           typename KTraits::DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][8])
{
    constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q;
    constexpr uint32_t UPCAST_STRIDE_K = KTraits::UPCAST_STRIDE_K;
    uint32_t a_frag[KTraits::NUM_MMA_Q][4], b_frag[4];
    // compute q*k^T
#pragma unroll
    for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_QK; ++mma_d) {
#pragma unroll
        for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
            q_smem->ldmatrix_m8n8x4(*q_smem_offset_r, a_frag[mma_q]);
            *q_smem_offset_r =
                q_smem->template advance_offset_by_row<16, UPCAST_STRIDE_Q>(
                    *q_smem_offset_r);
        }

        *q_smem_offset_r = q_smem->template advance_offset_by_column<2>(
                               *q_smem_offset_r, mma_d) -
                           KTraits::NUM_MMA_Q * 16 * UPCAST_STRIDE_Q;

#pragma unroll
        for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) {
            if constexpr (sizeof(typename KTraits::DTypeKV) == 1) {
                // fp8 path: load half fragments, swizzle to 8-bit layout,
                // then upcast to the Q dtype for the MMA.
                uint32_t b_frag_f8[2];
                if (mma_d % 2 == 0) {
                    k_smem->ldmatrix_m8n8x4_left_half(*k_smem_offset_r,
                                                      b_frag_f8);
                }
                else {
                    k_smem->ldmatrix_m8n8x4_right_half(*k_smem_offset_r,
                                                       b_frag_f8);
                }
                b_frag_f8[0] = frag_layout_swizzle_16b_to_8b(b_frag_f8[0]);
                b_frag_f8[1] = frag_layout_swizzle_16b_to_8b(b_frag_f8[1]);
                vec_cast::
                    cast<8>((typename KTraits::DTypeQ *)b_frag,
                            (typename KTraits::DTypeKV *)b_frag_f8);
            }
            else {
                k_smem->ldmatrix_m8n8x4(*k_smem_offset_r, b_frag);
            }
            *k_smem_offset_r =
                k_smem->template advance_offset_by_row<16, UPCAST_STRIDE_K>(
                    *k_smem_offset_r);

#pragma unroll
            for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
                if constexpr (std::is_same_v)
                {
                    // f32 accumulation; first slice initializes the fragment.
                    if (mma_d == 0) {
                        mma::mma_sync_m16n16k16_row_col_f16f16f32<
                            typename KTraits::DTypeQ, MMAMode::kInit>(
                            s_frag[mma_q][mma_kv], a_frag[mma_q], b_frag);
                    }
                    else {
                        mma::mma_sync_m16n16k16_row_col_f16f16f32<
                            typename KTraits::DTypeQ>(s_frag[mma_q][mma_kv],
                                                      a_frag[mma_q], b_frag);
                    }
                }
                else if (std::is_same_v) {
                    // f16 accumulation path.
                    if (mma_d == 0) {
                        mma::mma_sync_m16n16k16_row_col_f16f16f16<
                            MMAMode::kInit>((uint32_t *)s_frag[mma_q][mma_kv],
                                            a_frag[mma_q], b_frag);
                    }
                    else {
                        mma::mma_sync_m16n16k16_row_col_f16f16f16(
                            (uint32_t *)s_frag[mma_q][mma_kv], a_frag[mma_q],
                            b_frag);
                    }
                }
            }
        }
        if constexpr (sizeof(typename KTraits::DTypeKV) == 1) {
            if (mma_d % 2 == 1) {
                *k_smem_offset_r = k_smem->template advance_offset_by_column<2>(
                    *k_smem_offset_r, mma_d / 2);
            }
            *k_smem_offset_r -= KTraits::NUM_MMA_KV * 16 * UPCAST_STRIDE_K;
        }
        else {
            *k_smem_offset_r = k_smem->template advance_offset_by_column<2>(
                                   *k_smem_offset_r, mma_d) -
                               KTraits::NUM_MMA_KV * 16 * UPCAST_STRIDE_K;
        }
    }
    *q_smem_offset_r -= KTraits::NUM_MMA_D_QK * 2;
    *k_smem_offset_r -=
        KTraits::NUM_MMA_D_QK * sizeof(typename KTraits::DTypeKV);
}

// Apply the variant's LogitsTransform to every score register, mapping each
// register to its (q_idx, kv_idx, qo_head_idx) coordinates. fp16 accumulators
// round-trip through fp32 for the transform.
template
__device__ __forceinline__ void
logits_transform(const Params &params,
                 typename KTraits::AttentionVariant variant,
                 const uint32_t batch_idx,
                 const uint32_t qo_packed_idx_base,
                 const uint32_t kv_idx_base,
                 const uint32_t qo_len,
                 const uint32_t kv_len,
                 const uint_fastdiv group_size,
                 DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][8],
                 const dim3 tid = threadIdx,
                 const uint32_t kv_head_idx = blockIdx.z)
{
    const uint32_t lane_idx = tid.x;
    uint32_t q[KTraits::NUM_MMA_Q][2], r[KTraits::NUM_MMA_Q][2];
    float logits = 0., logitsTransformed = 0.;

#pragma unroll
    for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
#pragma unroll
        for (uint32_t j = 0; j < 2; ++j) {
            group_size.divmod(qo_packed_idx_base + mma_q * 16 + lane_idx / 4 +
                                  8 * j,
                              q[mma_q][j], r[mma_q][j]);
        }
    }

#pragma unroll
    for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
#pragma unroll
        for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) {
#pragma unroll
            for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) {
                const uint32_t q_idx = q[mma_q][(reg_id % 4) / 2],
                               kv_idx = kv_idx_base + mma_kv * 16 +
                                        2 * (lane_idx % 4) + 8 * (reg_id / 4) +
                                        reg_id % 2;
                const uint32_t qo_head_idx =
                    kv_head_idx * group_size + r[mma_q][(reg_id % 4) / 2];

#ifdef FP16_QK_REDUCTION_SUPPORTED
                if constexpr (std::is_same::value) {
                    logits = std::bit_cast(
                        fp16_ieee_to_fp32_value(s_frag[mma_q][mma_kv][reg_id]));
                }
                else if constexpr (!std::is_same::value) {
                    logits = s_frag[mma_q][mma_kv][reg_id];
                }
#else
                static_assert(
                    !std::is_same::value,
                    "Set -DFP16_QK_REDUCTION_SUPPORTED and install boost_math "
                    "then recompile to support fp16 reduction");
                logits = s_frag[mma_q][mma_kv][reg_id];
#endif
                logitsTransformed =
                    variant.LogitsTransform(params, logits, batch_idx, q_idx,
                                            kv_idx, qo_head_idx, kv_head_idx);
#ifdef FP16_QK_REDUCTION_SUPPORTED
                if constexpr (std::is_same::value) {
                    s_frag[mma_q][mma_kv][reg_id] = std::bit_cast(
                        fp16_ieee_from_fp32_value(logitsTransformed));
                }
                else if constexpr (!std::is_same::value) {
                    s_frag[mma_q][mma_kv][reg_id] = logitsTransformed;
                }
#else
                s_frag[mma_q][mma_kv][reg_id] = logitsTransformed;
#endif
            }
        }
    }
}

// Mask out-of-range / causally-forbidden scores: a register survives only if
// it passes the causal (or chunk-boundary) test AND the variant's LogitsMask;
// otherwise it is set to KTraits::MaskFillValue.
template
__device__ __forceinline__ void
logits_mask(const Params &params,
            typename KTraits::AttentionVariant variant,
            const uint32_t batch_idx,
            const uint32_t qo_packed_idx_base,
            const uint32_t kv_idx_base,
            const uint32_t qo_len,
            const uint32_t kv_len,
            const uint32_t chunk_end,
            const uint_fastdiv group_size,
            typename KTraits::DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][8],
            const dim3 tid = threadIdx,
            const uint32_t kv_head_idx = blockIdx.z)
{
    const uint32_t lane_idx = tid.x;
    constexpr uint32_t NUM_MMA_Q = KTraits::NUM_MMA_Q;
    constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV;
    using DTypeQKAccum = typename KTraits::DTypeQKAccum;
    constexpr MaskMode MASK_MODE = KTraits::MASK_MODE;
    uint32_t q[NUM_MMA_Q][2], r[NUM_MMA_Q][2];
#pragma unroll
    for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) {
#pragma unroll
        for (uint32_t j = 0; j < 2; ++j) {
            group_size.divmod(qo_packed_idx_base + mma_q * 16 + lane_idx / 4 +
                                  8 * j,
                              q[mma_q][j], r[mma_q][j]);
        }
    }

#pragma unroll
    for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) {
#pragma unroll
        for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) {
#pragma unroll
            for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) {
                const uint32_t q_idx = q[mma_q][(reg_id % 4) / 2],
                               kv_idx = kv_idx_base + mma_kv * 16 +
                                        2 * (lane_idx % 4) + 8 * (reg_id / 4) +
                                        reg_id % 2;
                const uint32_t qo_head_idx =
                    kv_head_idx * group_size + r[mma_q][(reg_id % 4) / 2];
                const bool mask =
                    (!(MASK_MODE == MaskMode::kCausal
                           ? (kv_idx + qo_len > kv_len + q_idx ||
                              (kv_idx >= chunk_end))
                           : kv_idx >= chunk_end)) &&
                    variant.LogitsMask(params, batch_idx, q_idx, kv_idx,
                                       qo_head_idx, kv_head_idx);
                s_frag[mma_q][mma_kv][reg_id] =
                    (mask) ? s_frag[mma_q][mma_kv][reg_id]
                           : (KTraits::MaskFillValue);
            }
        }
    }
}

// Online-softmax state update: fold the new score fragments into the running
// row max m and denominator d, rescale the existing output accumulator by
// exp2(m_prev - m_new), and replace scores with exp2(score - m). Separate
// float and half2 code paths for the accumulator type.
template
__device__ __forceinline__ void update_mdo_states(
    typename KTraits::AttentionVariant variant,
    typename KTraits::DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][8],
    float (*o_frag)[KTraits::NUM_MMA_D_VO][8],
    typename KTraits::DTypeQKAccum (*m)[2],
    float (*d)[2])
{
    using DTypeQKAccum = typename KTraits::DTypeQKAccum;
    using AttentionVariant = typename KTraits::AttentionVariant;
    constexpr bool use_softmax = AttentionVariant::use_softmax;

    if constexpr (use_softmax) {
        const float sm_scale = variant.sm_scale_log2;
        if constexpr (std::is_same_v) {
#pragma unroll
            for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
#pragma unroll
                for (uint32_t j = 0; j < 2; ++j) {
                    float m_prev = m[mma_q][j];
#pragma unroll
                    for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV;
                         ++mma_kv)
                    {
                        float m_local =
                            max(max(s_frag[mma_q][mma_kv][j * 2 + 0],
                                    s_frag[mma_q][mma_kv][j * 2 + 1]),
                                max(s_frag[mma_q][mma_kv][j * 2 + 4],
                                    s_frag[mma_q][mma_kv][j * 2 + 5]));
                        m[mma_q][j] = max(m[mma_q][j], m_local);
                    }
                    // Butterfly reduction: share the row max across the 4
                    // lanes that own this row.
                    m[mma_q][j] =
                        max(m[mma_q][j],
                            gpu_iface::math::shfl_xor_sync(m[mma_q][j], 0x2));
                    m[mma_q][j] =
                        max(m[mma_q][j],
                            gpu_iface::math::shfl_xor_sync(m[mma_q][j], 0x1));

                    float o_scale = gpu_iface::math::ptx_exp2(
                        m_prev * sm_scale - m[mma_q][j] * sm_scale);
                    d[mma_q][j] *= o_scale;
#pragma unroll
                    for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO;
                         ++mma_d)
                    {
                        o_frag[mma_q][mma_d][j * 2 + 0] *= o_scale;
                        o_frag[mma_q][mma_d][j * 2 + 1] *= o_scale;
                        o_frag[mma_q][mma_d][j * 2 + 4] *= o_scale;
                        o_frag[mma_q][mma_d][j * 2 + 5] *= o_scale;
                    }
#pragma unroll
                    for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV;
                         ++mma_kv)
                    {
                        s_frag[mma_q][mma_kv][j * 2 + 0] =
                            gpu_iface::math::ptx_exp2(
                                s_frag[mma_q][mma_kv][j * 2 + 0] * sm_scale -
                                m[mma_q][j] * sm_scale);
                        s_frag[mma_q][mma_kv][j * 2 + 1] =
                            gpu_iface::math::ptx_exp2(
                                s_frag[mma_q][mma_kv][j * 2 + 1] * sm_scale -
                                m[mma_q][j] * sm_scale);
                        s_frag[mma_q][mma_kv][j * 2 + 4] =
                            gpu_iface::math::ptx_exp2(
                                s_frag[mma_q][mma_kv][j * 2 + 4] * sm_scale -
                                m[mma_q][j] * sm_scale);
                        s_frag[mma_q][mma_kv][j * 2 + 5] =
                            gpu_iface::math::ptx_exp2(
                                s_frag[mma_q][mma_kv][j * 2 + 5] * sm_scale -
                                m[mma_q][j] * sm_scale);
                    }
                }
            }
        }
        else if constexpr (std::is_same_v) {
            const half2 sm_scale = __float2half2_rn(variant.sm_scale_log2);
#pragma unroll
            for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
                half m_prev[2];
#pragma unroll
                for (uint32_t j = 0; j < 2; ++j) {
                    m_prev[j] = m[mma_q][j];
#pragma unroll
                    for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV;
                         ++mma_kv)
                    {
                        half2 m_local = gpu_iface::math::hmax2(
                            *(half2 *)&s_frag[mma_q][mma_kv][j * 2],
                            *(half2 *)&s_frag[mma_q][mma_kv][j * 2 + 4]);
                        m[mma_q][j] =
                            __hmax(m[mma_q][j], __hmax(m_local.x, m_local.y));
                    }
                }
                *(half2 *)&m[mma_q] = gpu_iface::math::hmax2(
                    *(half2 *)&m[mma_q],
                    gpu_iface::math::shfl_xor_sync(*(half2 *)&m[mma_q], 0x2));
                *(half2 *)&m[mma_q] = gpu_iface::math::hmax2(
                    *(half2 *)&m[mma_q],
                    gpu_iface::math::shfl_xor_sync(*(half2 *)&m[mma_q], 0x1));
#pragma unroll
                for (uint32_t j = 0; j < 2; ++j) {
                    float o_scale = gpu_iface::math::ptx_exp2(float(
                        m_prev[j] * sm_scale.x - m[mma_q][j] * sm_scale.x));
                    d[mma_q][j] *= o_scale;
#pragma unroll
                    for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO;
                         ++mma_d)
                    {
                        o_frag[mma_q][mma_d][j * 2 + 0] *= o_scale;
                        o_frag[mma_q][mma_d][j * 2 + 1] *= o_scale;
                        o_frag[mma_q][mma_d][j * 2 + 4] *= o_scale;
                        o_frag[mma_q][mma_d][j * 2 + 5] *= o_scale;
                    }
                    half2 m2 = make_half2(m[mma_q][j], m[mma_q][j]);
#pragma unroll
                    for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV;
                         ++mma_kv)
                    {
                        *(half2 *)&s_frag[mma_q][mma_kv][j * 2] =
                            gpu_iface::math::ptx_exp2(
                                *(half2 *)&s_frag[mma_q][mma_kv][j * 2] *
                                    sm_scale -
                                m2 * sm_scale);
                        *(half2 *)&s_frag[mma_q][mma_kv][j * 2 + 4] =
                            gpu_iface::math::ptx_exp2(
                                *(half2 *)&s_frag[mma_q][mma_kv][j * 2 + 4] *
                                    sm_scale -
                                m2 * sm_scale);
                    }
                }
            }
        }
    }
}

// Accumulate o += softmax(s) @ V and fold row sums into d.
// NOTE(review): definition continues beyond this patch chunk — truncated here.
template
__device__ __forceinline__ void
compute_sfm_v(smem_t *v_smem,
              uint32_t *v_smem_offset_r,
              typename KTraits::DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][8],
              float (*o_frag)[KTraits::NUM_MMA_D_VO][8],
              float (*d)[2])
{
    constexpr uint32_t UPCAST_STRIDE_V = KTraits::UPCAST_STRIDE_V;

    typename KTraits::DTypeQ s_frag_f16[KTraits::NUM_MMA_Q][KTraits::NUM_MMA_KV]
                                       [8];
    if constexpr (std::is_same_v) {
#pragma unroll
        for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
#pragma unroll
            for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) {
                vec_cast::cast<8>(
                    s_frag_f16[mma_q][mma_kv], s_frag[mma_q][mma_kv]);
            }
        }
    }

    if constexpr (KTraits::AttentionVariant::use_softmax) {
#pragma unroll
        for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
#pragma unroll
            for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) {
                if constexpr (std::is_same_v)
                {
                    mma::m16k16_rowsum_f16f16f32(d[mma_q],
                                                 s_frag_f16[mma_q][mma_kv]);
                }
                else {
                    mma::m16k16_rowsum_f16f16f32(d[mma_q],
                                                 s_frag[mma_q][mma_kv]);
                }
            }
        }
    }

#pragma unroll
    for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) {
#pragma unroll
        for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) {
            uint32_t b_frag[4];
            if constexpr (sizeof(typename KTraits::DTypeKV) == 1) {
                uint32_t b_frag_f8[2];
                if (mma_d % 2 == 0) {
                    v_smem->ldmatrix_m8n8x4_trans_left_half(*v_smem_offset_r,
                                                            b_frag_f8);
                }
                else {
                    v_smem->ldmatrix_m8n8x4_trans_right_half(*v_smem_offset_r,
b_frag_f8); + } + b_frag_f8[0] = + frag_layout_swizzle_16b_to_8b_trans(b_frag_f8[0]); + b_frag_f8[1] = + frag_layout_swizzle_16b_to_8b_trans(b_frag_f8[1]); + vec_cast:: + cast<8>((typename KTraits::DTypeQ *)b_frag, + (typename KTraits::DTypeKV *)b_frag_f8); + swap(b_frag[1], b_frag[2]); + } + else { + v_smem->ldmatrix_m8n8x4_trans(*v_smem_offset_r, b_frag); + } +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { + if constexpr (std::is_same_v) + { + mma::mma_sync_m16n16k16_row_col_f16f16f32< + typename KTraits::DTypeQ>( + o_frag[mma_q][mma_d], + (uint32_t *)s_frag_f16[mma_q][mma_kv], b_frag); + } + else { + mma::mma_sync_m16n16k16_row_col_f16f16f32< + typename KTraits::DTypeQ>( + o_frag[mma_q][mma_d], (uint32_t *)s_frag[mma_q][mma_kv], + b_frag); + } + } + if constexpr (sizeof(typename KTraits::DTypeKV) == 1) { + if (mma_d % 2 == 1) { + *v_smem_offset_r = + v_smem->template advance_offset_by_column<2>( + *v_smem_offset_r, mma_d / 2); + } + } + else { + *v_smem_offset_r = v_smem->template advance_offset_by_column<2>( + *v_smem_offset_r, mma_d); + } + } + *v_smem_offset_r = + v_smem->template advance_offset_by_row<16, UPCAST_STRIDE_V>( + *v_smem_offset_r) - + sizeof(typename KTraits::DTypeKV) * KTraits::NUM_MMA_D_VO; + } + *v_smem_offset_r -= 16 * KTraits::NUM_MMA_KV * UPCAST_STRIDE_V; +} + +template +__device__ __forceinline__ void +normalize_d(float (*o_frag)[KTraits::NUM_MMA_D_VO][8], + typename KTraits::DTypeQKAccum (*m)[2], + float (*d)[2]) +{ + using AttentionVariant = typename KTraits::AttentionVariant; + if constexpr (AttentionVariant::use_softmax) { + float d_rcp[KTraits::NUM_MMA_Q][2]; + // compute reciprocal of d +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + d_rcp[mma_q][j] = + (m[mma_q][j] != + typename KTraits::DTypeQKAccum(-gpu_iface::math::inf)) + ? 
gpu_iface::math::ptx_rcp(d[mma_q][j]) + : 0.f; + } + } + +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + o_frag[mma_q][mma_d][reg_id] = + o_frag[mma_q][mma_d][reg_id] * + d_rcp[mma_q][(reg_id % 4) / 2]; + } + } + } + } +} + +template +__device__ __forceinline__ void +finalize_m(typename KTraits::AttentionVariant variant, + typename KTraits::DTypeQKAccum (*m)[2]) +{ + if constexpr (variant.use_softmax) { +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + if (m[mma_q][j] != + typename KTraits::DTypeQKAccum(-gpu_iface::math::inf)) + { + m[mma_q][j] *= variant.sm_scale_log2; + } + } + } + } +} + +/*! + * \brief Synchronize the states of the MDO kernel across the threadblock along + * threadIdx.z. + */ +template +__device__ __forceinline__ void +threadblock_sync_mdo_states(float (*o_frag)[KTraits::NUM_MMA_D_VO][8], + typename KTraits::SharedStorage *smem_storage, + typename KTraits::DTypeQKAccum (*m)[2], + float (*d)[2], + const uint32_t warp_idx, + const uint32_t lane_idx, + const dim3 tid = threadIdx) +{ + // only necessary when blockDim.z > 1 + if constexpr (KTraits::NUM_WARPS_KV > 1) { + float *smem_o = smem_storage->cta_sync_o_smem; + float2 *smem_md = smem_storage->cta_sync_md_smem; + // o: [num_warps, NUM_MMA_Q, NUM_MMA_D_VO, WARP_SIZE(32), 8] + // md: [num_warps, NUM_MMA_Q, 16, 2 (m/d)] +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { + vec_t::memcpy( + smem_o + (((warp_idx * KTraits::NUM_MMA_Q + mma_q) * + KTraits::NUM_MMA_D_VO + + mma_d) * + WARP_SIZE + + lane_idx) * + 8, + o_frag[mma_q][mma_d]); + } + } + + if constexpr (KTraits::AttentionVariant::use_softmax) { +#pragma 
unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + smem_md[((warp_idx * KTraits::NUM_MMA_Q + mma_q) * 2 + j) * + 8 + + lane_idx / 4] = + make_float2(float(m[mma_q][j]), d[mma_q][j]); + } + } + + // synchronize m,d first + __syncthreads(); +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { + float o_scale[2][KTraits::NUM_WARPS_KV]; +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + float m_new = -gpu_iface::math::inf, d_new = 1.f; +#pragma unroll + for (uint32_t i = 0; i < KTraits::NUM_WARPS_KV; ++i) { + float2 md = smem_md[(((i * KTraits::NUM_WARPS_Q + + get_warp_idx_q(tid.y)) * + KTraits::NUM_MMA_Q + + mma_q) * + 2 + + j) * + 8 + + lane_idx / 4]; + float m_prev = m_new, d_prev = d_new; + m_new = max(m_new, md.x); + d_new = + d_prev * gpu_iface::math::ptx_exp2(m_prev - m_new) + + md.y * gpu_iface::math::ptx_exp2(md.x - m_new); + } + +#pragma unroll + for (uint32_t i = 0; i < KTraits::NUM_WARPS_KV; ++i) { + float2 md = smem_md[(((i * KTraits::NUM_WARPS_Q + + get_warp_idx_q(tid.y)) * + KTraits::NUM_MMA_Q + + mma_q) * + 2 + + j) * + 8 + + lane_idx / 4]; + float mi = md.x; + o_scale[j][i] = + gpu_iface::math::ptx_exp2(float(mi - m_new)); + } + m[mma_q][j] = typename KTraits::DTypeQKAccum(m_new); + d[mma_q][j] = d_new; + } + +#pragma unroll + for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) + { + vec_t o_new; + o_new.fill(0.f); +#pragma unroll + for (uint32_t i = 0; i < KTraits::NUM_WARPS_KV; ++i) { + vec_t oi; + oi.load(smem_o + ((((i * KTraits::NUM_WARPS_Q + + get_warp_idx_q(tid.y)) * + KTraits::NUM_MMA_Q + + mma_q) * + KTraits::NUM_MMA_D_VO + + mma_d) * + WARP_SIZE + + lane_idx) * + 8); + +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + o_new[reg_id] += + oi[reg_id] * o_scale[(reg_id % 4) / 2][i]; + } + } + o_new.store(o_frag[mma_q][mma_d]); + } + } + } + else { + // synchronize m,d first + __syncthreads(); +#pragma 
unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) + { + vec_t o_new; + o_new.fill(0.f); +#pragma unroll + for (uint32_t i = 0; i < KTraits::NUM_WARPS_KV; ++i) { + vec_t oi; + oi.load(smem_o + ((((i * KTraits::NUM_WARPS_Q + + get_warp_idx_q(tid.y)) * + KTraits::NUM_MMA_Q + + mma_q) * + KTraits::NUM_MMA_D_VO + + mma_d) * + WARP_SIZE + + lane_idx) * + 8); +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + o_new[reg_id] += oi[reg_id]; + } + } + o_new.store(o_frag[mma_q][mma_d]); + } + } + } + } +} + +template +__device__ __forceinline__ void +write_o_reg_gmem(float (*o_frag)[KTraits::NUM_MMA_D_VO][8], + smem_t *o_smem, + typename KTraits::DTypeO *o_ptr_base, + const uint32_t o_packed_idx_base, + const uint32_t qo_upper_bound, + const uint32_t o_stride_n, + const uint32_t o_stride_h, + const uint_fastdiv group_size, + const dim3 tid = threadIdx) +{ + using DTypeO = typename KTraits::DTypeO; + constexpr uint32_t UPCAST_STRIDE_O = KTraits::UPCAST_STRIDE_O; + const uint32_t warp_idx_x = get_warp_idx_q(tid.y); + const uint32_t lane_idx = tid.x; + + if constexpr (sizeof(DTypeO) == 4) { +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + uint32_t q, r; + group_size.divmod(o_packed_idx_base + lane_idx / 4 + + mma_q * 16 + j * 8, + q, r); + const uint32_t o_idx = q; +#pragma unroll + for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) + { + if (o_idx < qo_upper_bound) { + *reinterpret_cast( + o_ptr_base + q * o_stride_n + r * o_stride_h + + mma_d * 16 + (lane_idx % 4) * 2) = + *reinterpret_cast( + &o_frag[mma_q][mma_d][j * 2]); + *reinterpret_cast( + o_ptr_base + q * o_stride_n + r * o_stride_h + + mma_d * 16 + 8 + (lane_idx % 4) * 2) = + *reinterpret_cast( + &o_frag[mma_q][mma_d][4 + j * 2]); + } + } + } + } + } + else { + if (get_warp_idx_kv(tid.z) == 0) { 
+#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) + { + uint32_t o_frag_f16[8 / 2]; + vec_cast::cast<8>((DTypeO *)o_frag_f16, + o_frag[mma_q][mma_d]); + +#ifdef FLASHINFER_STMATRIX_M8N8X4_ENABLED + uint32_t o_smem_offset_w = + o_smem->get_permuted_offset( + (warp_idx_x * KTraits::NUM_MMA_Q + mma_q) * 16 + + lane_idx % 16, + mma_d * 2 + lane_idx / 16); + o_smem->stmatrix_m8n8x4(o_smem_offset_w, o_frag_f16); +#else + uint32_t o_smem_offset_w = + o_smem->get_permuted_offset( + (warp_idx_x * KTraits::NUM_MMA_Q + mma_q) * 16 + + lane_idx / 4, + mma_d * 2); + ((uint32_t *)(o_smem->base + + o_smem_offset_w))[lane_idx % 4] = + o_frag_f16[0]; + ((uint32_t *)(o_smem->base + o_smem_offset_w + + 8 * UPCAST_STRIDE_O))[lane_idx % 4] = + o_frag_f16[1]; + ((uint32_t *)(o_smem->base + + (o_smem_offset_w ^ 0x1)))[lane_idx % 4] = + o_frag_f16[2]; + ((uint32_t *)(o_smem->base + (o_smem_offset_w ^ 0x1) + + 8 * UPCAST_STRIDE_O))[lane_idx % 4] = + o_frag_f16[3]; +#endif + } + } + + uint32_t o_smem_offset_w = + o_smem->get_permuted_offset( + warp_idx_x * KTraits::NUM_MMA_Q * 16 + lane_idx / 8, + lane_idx % 8); + +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t j = 0; j < 2 * 2; ++j) { + uint32_t q, r; + group_size.divmod(o_packed_idx_base + lane_idx / 8 + + mma_q * 16 + j * 4, + q, r); + const uint32_t o_idx = q; + DTypeO *o_ptr = o_ptr_base + q * o_stride_n + + r * o_stride_h + + (lane_idx % 8) * upcast_size(); +#pragma unroll + for (uint32_t mma_do = 0; + mma_do < KTraits::NUM_MMA_D_VO / 4; ++mma_do) + { + if (o_idx < qo_upper_bound) { + o_smem->store_128b(o_smem_offset_w, o_ptr); + } + o_ptr += 8 * upcast_size(); + o_smem_offset_w = + o_smem->template advance_offset_by_column<8>( + o_smem_offset_w, mma_do); + } + o_smem_offset_w = o_smem->template advance_offset_by_row< + 4, UPCAST_STRIDE_O>(o_smem_offset_w) 
- + 2 * KTraits::NUM_MMA_D_VO; + } + } + } + } +} + +} // namespace + +/*! + * \brief FlashAttention prefill CUDA kernel for a single request. + * \tparam partition_kv Whether to split kv_len into chunks. + * \tparam mask_mode The mask mode used in the attention operation. + * \tparam POS_ENCODING_MODE The positional encoding mode. + * \tparam NUM_MMA_Q The number of fragments in x dimension. + * \tparam NUM_MMA_D_VO The number of fragments in y dimension. + * \tparam NUM_MMA_KV The number of fragments in z dimension. + * \tparam num_warps The number of warps in the threadblock. + * \tparam DTypeQ The data type of the query tensor. + * \tparam DTypeKV The data type of the key/value tensor. + * \tparam DTypeO The data type of the output tensor. + * \param q The query tensor. + * \param k The key tensor. + * \param v The value tensor. + * \param o The output tensor. + * \param tmp The temporary buffer (used when partition_kv is true). + * \param lse The logsumexp value. + * \param rope_rcp_scale 1/(rope_scale), where rope_scale is the scaling + * factor used in RoPE interpolation. + * \param rope_rcp_theta 1/(rope_theta), where rope_theta is the theta + * used in RoPE. 
+ */ +template +__device__ __forceinline__ void +SinglePrefillWithKVCacheDevice(const Params params, + typename KTraits::SharedStorage &smem_storage, + const dim3 tid = threadIdx, + const uint32_t bx = blockIdx.x, + const uint32_t chunk_idx = blockIdx.y, + const uint32_t kv_head_idx = blockIdx.z, + const uint32_t num_chunks = gridDim.y, + const uint32_t num_kv_heads = gridDim.z) +{ + using DTypeQ = typename Params::DTypeQ; +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) + if constexpr (std::is_same_v) { + FLASHINFER_RUNTIME_ASSERT( + "Prefill kernels do not support bf16 on sm75."); + } + else { +#endif + using DTypeKV = typename Params::DTypeKV; + using DTypeO = typename Params::DTypeO; + using DTypeQKAccum = typename KTraits::DTypeQKAccum; + using AttentionVariant = typename KTraits::AttentionVariant; + [[maybe_unused]] constexpr uint32_t NUM_MMA_Q = KTraits::NUM_MMA_Q; + [[maybe_unused]] constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; + [[maybe_unused]] constexpr uint32_t NUM_MMA_D_QK = + KTraits::NUM_MMA_D_QK; + [[maybe_unused]] constexpr uint32_t NUM_MMA_D_VO = + KTraits::NUM_MMA_D_VO; + [[maybe_unused]] constexpr uint32_t HEAD_DIM_QK = KTraits::HEAD_DIM_QK; + [[maybe_unused]] constexpr uint32_t HEAD_DIM_VO = KTraits::HEAD_DIM_VO; + [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_Q = + KTraits::UPCAST_STRIDE_Q; + [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_K = + KTraits::UPCAST_STRIDE_K; + [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_V = + KTraits::UPCAST_STRIDE_V; + [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_O = + KTraits::UPCAST_STRIDE_O; + [[maybe_unused]] constexpr uint32_t CTA_TILE_Q = KTraits::CTA_TILE_Q; + [[maybe_unused]] constexpr uint32_t CTA_TILE_KV = KTraits::CTA_TILE_KV; + [[maybe_unused]] constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q; + [[maybe_unused]] constexpr uint32_t NUM_WARPS_KV = + KTraits::NUM_WARPS_KV; + [[maybe_unused]] constexpr SwizzleMode SWIZZLE_MODE_Q = + KTraits::SWIZZLE_MODE_Q; + 
[[maybe_unused]] constexpr SwizzleMode SWIZZLE_MODE_KV = + KTraits::SWIZZLE_MODE_KV; + [[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_ROW = + KTraits::KV_THR_LAYOUT_ROW; + [[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_COL = + KTraits::KV_THR_LAYOUT_COL; + [[maybe_unused]] constexpr MaskMode MASK_MODE = KTraits::MASK_MODE; + + DTypeQ *q = params.q; + DTypeKV *k = params.k; + DTypeKV *v = params.v; + DTypeO *o = params.o; + float *lse = params.lse; + const uint32_t qo_len = params.qo_len; + const uint32_t kv_len = params.kv_len; + const bool partition_kv = params.partition_kv; + const uint32_t q_stride_n = params.q_stride_n; + const uint32_t q_stride_h = params.q_stride_h; + const uint32_t k_stride_n = params.k_stride_n; + const uint32_t k_stride_h = params.k_stride_h; + const uint32_t v_stride_n = params.v_stride_n; + const uint32_t v_stride_h = params.v_stride_h; + const uint_fastdiv &group_size = params.group_size; + + static_assert(sizeof(DTypeQ) == 2); + const uint32_t lane_idx = tid.x, + warp_idx = get_warp_idx(tid.y, tid.z); + const uint32_t num_qo_heads = num_kv_heads * group_size; + + const uint32_t max_chunk_size = + partition_kv ? ceil_div(kv_len, num_chunks) : kv_len; + const uint32_t chunk_start = + partition_kv ? chunk_idx * max_chunk_size : 0; + const uint32_t chunk_end = + partition_kv ? 
min((chunk_idx + 1) * max_chunk_size, kv_len) + : kv_len; + const uint32_t chunk_size = chunk_end - chunk_start; + + auto block = cg::this_thread_block(); + auto smem = reinterpret_cast(&smem_storage); + AttentionVariant variant(params, /*batch_idx=*/0, smem); + const uint32_t window_left = variant.window_left; + + DTypeQKAccum s_frag[NUM_MMA_Q][NUM_MMA_KV][8]; + alignas(16) float o_frag[NUM_MMA_Q][NUM_MMA_D_VO][8]; + DTypeQKAccum m[NUM_MMA_Q][2]; + float d[NUM_MMA_Q][2]; + float rope_freq[NUM_MMA_D_QK / 2][4]; + if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) + { + const float rope_rcp_scale = params.rope_rcp_scale; + const float rope_rcp_theta = params.rope_rcp_theta; + init_rope_freq(rope_freq, rope_rcp_scale, rope_rcp_theta, + tid.x); + } + init_states(variant, o_frag, m, d); + + // cooperative fetch q fragment from gmem to reg + const uint32_t qo_packed_idx_base = + (bx * NUM_WARPS_Q + get_warp_idx_q(tid.y)) * NUM_MMA_Q * + 16; + smem_t qo_smem(smem_storage.q_smem); + const uint32_t o_stride_n = num_qo_heads * HEAD_DIM_VO, + o_stride_h = HEAD_DIM_VO; + DTypeQ *q_ptr_base = q + (kv_head_idx * group_size) * q_stride_h; + DTypeO *o_ptr_base = partition_kv + ? 
o + chunk_idx * o_stride_n + + (kv_head_idx * group_size) * o_stride_h + : o + (kv_head_idx * group_size) * o_stride_h; + + uint32_t q_smem_offset_r = qo_smem.get_permuted_offset( + get_warp_idx_q(tid.y) * NUM_MMA_Q * 16 + lane_idx % 16, + lane_idx / 16); + + load_q_global_smem(qo_packed_idx_base, qo_len, q_ptr_base, + q_stride_n, q_stride_h, group_size, + &qo_smem, tid); + + memory::commit_group(); + if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) + { + memory::wait_group<0>(); + block.sync(); + q_smem_inplace_apply_rotary( + qo_packed_idx_base, qo_len, kv_len, group_size, &qo_smem, + &q_smem_offset_r, rope_freq, tid); + block.sync(); + } + + smem_t k_smem(smem_storage.k_smem), + v_smem(smem_storage.v_smem); + + const uint32_t num_iterations = ceil_div( + MASK_MODE == MaskMode::kCausal + ? min(chunk_size, sub_if_greater_or_zero( + kv_len - qo_len + + ((bx + 1) * CTA_TILE_Q) / group_size, + chunk_start)) + : chunk_size, + CTA_TILE_KV); + + const uint32_t window_iteration = ceil_div( + sub_if_greater_or_zero(kv_len + (bx + 1) * CTA_TILE_Q / group_size, + qo_len + window_left + chunk_start), + CTA_TILE_KV); + + const uint32_t mask_iteration = + (MASK_MODE == MaskMode::kCausal + ? 
min(chunk_size, + sub_if_greater_or_zero( + kv_len + (bx * CTA_TILE_Q) / group_size - qo_len, + chunk_start)) + : chunk_size) / + CTA_TILE_KV; + + DTypeKV *k_ptr = + k + + (chunk_start + warp_idx * KV_THR_LAYOUT_ROW + + lane_idx / KV_THR_LAYOUT_COL) * + k_stride_n + + kv_head_idx * k_stride_h + + (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(); + DTypeKV *v_ptr = + v + + (chunk_start + warp_idx * KV_THR_LAYOUT_ROW + + lane_idx / KV_THR_LAYOUT_COL) * + v_stride_n + + kv_head_idx * v_stride_h + + (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(); + + uint32_t k_smem_offset_r = + k_smem.template get_permuted_offset( + get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + + 8 * (lane_idx / 16) + lane_idx % 8, + (lane_idx % 16) / 8), + v_smem_offset_r = + v_smem.template get_permuted_offset( + get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + + lane_idx % 16, + lane_idx / 16), + k_smem_offset_w = + k_smem.template get_permuted_offset( + warp_idx * KV_THR_LAYOUT_ROW + + lane_idx / KV_THR_LAYOUT_COL, + lane_idx % KV_THR_LAYOUT_COL), + v_smem_offset_w = + v_smem.template get_permuted_offset( + warp_idx * KV_THR_LAYOUT_ROW + + lane_idx / KV_THR_LAYOUT_COL, + lane_idx % KV_THR_LAYOUT_COL); + produce_kv( + k_smem, &k_smem_offset_w, &k_ptr, k_stride_n, 0, chunk_size, tid); + memory::commit_group(); + produce_kv( + v_smem, &v_smem_offset_w, &v_ptr, v_stride_n, 0, chunk_size, tid); + memory::commit_group(); + +#pragma unroll 1 + for (uint32_t iter = 0; iter < num_iterations; ++iter) { + memory::wait_group<1>(); + block.sync(); + + if constexpr (KTraits::POS_ENCODING_MODE == + PosEncodingMode::kRoPELlama) + { + k_smem_inplace_apply_rotary( + chunk_start + iter * CTA_TILE_KV, &k_smem, &k_smem_offset_r, + rope_freq, tid); + block.sync(); + } + + // compute attention score + compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, + &k_smem_offset_r, s_frag); + + logits_transform( + params, variant, /*batch_idx=*/0, qo_packed_idx_base, + chunk_start + + (iter * NUM_WARPS_KV + get_warp_idx_kv(tid.z)) * + 
NUM_MMA_KV * 16, + qo_len, kv_len, group_size, s_frag, tid, kv_head_idx); + + // apply mask + if (MASK_MODE == MaskMode::kCustom || + (iter >= mask_iteration || iter < window_iteration)) + { + logits_mask( + params, variant, /*batch_idx=*/0, qo_packed_idx_base, + chunk_start + (iter * NUM_WARPS_KV + + get_warp_idx_kv(tid.z)) * + NUM_MMA_KV * 16, + qo_len, kv_len, chunk_end, group_size, s_frag, tid, + kv_head_idx); + } + + // compute m,d states in online softmax + update_mdo_states(variant, s_frag, o_frag, m, d); + + block.sync(); + produce_kv( + k_smem, &k_smem_offset_w, &k_ptr, k_stride_n, + (iter + 1) * CTA_TILE_KV, chunk_size, tid); + memory::commit_group(); + memory::wait_group<1>(); + block.sync(); + + // compute sfm*v + compute_sfm_v(&v_smem, &v_smem_offset_r, s_frag, o_frag, + d); + + block.sync(); + produce_kv( + v_smem, &v_smem_offset_w, &v_ptr, v_stride_n, + (iter + 1) * CTA_TILE_KV, chunk_size, tid); + memory::commit_group(); + } + memory::wait_group<0>(); + block.sync(); + + finalize_m(variant, m); + + // threadblock synchronization + threadblock_sync_mdo_states(o_frag, &smem_storage, m, d, + warp_idx, lane_idx, tid); + + // normalize d + normalize_d(o_frag, m, d); + + // write back + write_o_reg_gmem( + o_frag, &qo_smem, o_ptr_base, qo_packed_idx_base, qo_len, + /*o_stride_n=*/ + partition_kv ? 
num_chunks * o_stride_n : o_stride_n,
            /*o_stride_h=*/o_stride_h, group_size, tid);

        // Write per-row log-sum-exp (m + log2(d)); consumed by the split-KV
        // merge step and/or returned to the caller.
        if constexpr (variant.use_softmax) {
            if (lse != nullptr || partition_kv) {
                if (get_warp_idx_kv<KTraits>(tid.z) == 0) {
#pragma unroll
                    for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) {
#pragma unroll
                        for (uint32_t j = 0; j < 2; ++j) {
                            uint32_t q, r;
                            group_size.divmod(qo_packed_idx_base + lane_idx / 4 +
                                                  j * 8 + mma_q * 16,
                                              q, r);
                            const uint32_t qo_head_idx =
                                kv_head_idx * group_size + r;
                            const uint32_t qo_idx = q;
                            if (qo_idx < qo_len) {
                                if (partition_kv) {
                                    // partitioned layout:
                                    // [qo_len, num_chunks, num_qo_heads]
                                    lse[(qo_idx * num_chunks + chunk_idx) *
                                            num_qo_heads +
                                        qo_head_idx] =
                                        gpu_iface::math::ptx_log2(d[mma_q][j]) +
                                        float(m[mma_q][j]);
                                }
                                else {
                                    lse[qo_idx * num_qo_heads + qo_head_idx] =
                                        gpu_iface::math::ptx_log2(d[mma_q][j]) +
                                        float(m[mma_q][j]);
                                }
                            }
                        }
                    }
                }
            }
        }
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
    }
#endif
}

/*!
 * \brief Thin __global__ entry point: binds the dynamic shared-memory region
 *   to the KernelTraits shared-storage layout and forwards to the device-side
 *   implementation.
 */
template <typename KTraits, typename Params>
__global__
__launch_bounds__(KTraits::NUM_THREADS) void SinglePrefillWithKVCacheKernel(
    const __grid_constant__ Params params)
{
    extern __shared__ uint8_t smem[];
    auto &smem_storage =
        reinterpret_cast<typename KTraits::SharedStorage &>(smem);
    SinglePrefillWithKVCacheDevice<KTraits>(params, smem_storage);
}

/*!
 * \brief Host-side dispatcher for single-request prefill attention.
 *
 * Selects tile sizes from the problem shape, checks occupancy/shared-memory
 * limits, optionally splits the KV dimension across blocks, launches the
 * kernel on \p stream, and merges partial results when split-KV is used.
 *
 * \param params kernel parameter pack (tensors, lengths, strides, scales)
 * \param tmp scratch buffer for split-KV partial outputs (may be nullptr to
 *   disable splitting)
 * \param stream GPU stream to launch on
 * \return gpuSuccess on success; FLASHINFER_ERROR is raised on invalid input
 */
template <uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_VO,
          PosEncodingMode POS_ENCODING_MODE, bool USE_FP16_QK_REDUCTION,
          MaskMode MASK_MODE, typename AttentionVariant, typename Params>
gpuError_t SinglePrefillWithKVCacheDispatched(Params params,
                                              typename Params::DTypeO *tmp,
                                              gpuStream_t stream)
{
    using DTypeQ = typename Params::DTypeQ;
    using DTypeKV = typename Params::DTypeKV;
    using DTypeO = typename Params::DTypeO;
    const uint32_t num_qo_heads = params.num_qo_heads;
    const uint32_t num_kv_heads = params.num_kv_heads;
    const uint32_t qo_len = params.qo_len;
    const uint32_t kv_len = params.kv_len;
    if (kv_len < qo_len && MASK_MODE == MaskMode::kCausal) {
        std::ostringstream err_msg;
        // NOTE: trailing space after "kv_len" so the message reads
        // "got kv_len <n>" instead of "kv_len<n>".
        err_msg << "When mask_mode is set to MaskMode::kCausal, kv_len must be "
                   "greater than or equal to qo_len, got kv_len "
                << kv_len << " and qo_len " << qo_len;
        FLASHINFER_ERROR(err_msg.str());
    }

    const uint32_t group_size = num_qo_heads / num_kv_heads;
    constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM_QK / 16;
    constexpr uint32_t NUM_MMA_D_VO = HEAD_DIM_VO / 16;
    // Widen BEFORE multiplying: qo_len * group_size is evaluated in uint32_t
    // otherwise and could wrap for very long sequences.
    const int64_t packed_qo_len =
        static_cast<int64_t>(qo_len) * static_cast<int64_t>(group_size);
    uint32_t cta_tile_q = FA2DetermineCtaTileQ(packed_qo_len, HEAD_DIM_VO);

    DISPATCH_CTA_TILE_Q(cta_tile_q, CTA_TILE_Q, {
        constexpr uint32_t NUM_WARPS_Q = get_num_warps_q(CTA_TILE_Q);
        constexpr uint32_t NUM_WARPS_KV = get_num_warps_kv(CTA_TILE_Q);
        constexpr uint32_t NUM_MMA_Q = get_num_mma_q(CTA_TILE_Q);

        // fp16 QK accumulation only when requested and Q itself is fp16.
        using DTypeQKAccum = typename std::conditional<
            USE_FP16_QK_REDUCTION && std::is_same_v<DTypeQ, half>, half,
            float>::type;

        int dev_id = 0;
        FI_GPU_CALL(gpuGetDevice(&dev_id));
        int max_smem_per_sm = 0;
        FI_GPU_CALL(gpuDeviceGetAttribute(
            &max_smem_per_sm, gpuDevAttrMaxSharedMemoryPerMultiProcessor,
            dev_id));
        // we expect each sm execute two threadblocks
        const int num_ctas_per_sm =
            max_smem_per_sm >= 2 * (CTA_TILE_Q * HEAD_DIM_QK * sizeof(DTypeQ) +
                                    (HEAD_DIM_QK + HEAD_DIM_VO) * 16 *
                                        NUM_WARPS_KV * sizeof(DTypeKV))
                ? 2
                : 1;
        const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm;

        // Cap NUM_MMA_KV by register pressure (heuristic) and by what fits in
        // the per-block shared-memory budget.
        const uint32_t max_num_mma_kv_reg =
            (HEAD_DIM_VO >= 128 && NUM_MMA_Q == 2 &&
             POS_ENCODING_MODE == PosEncodingMode::kRoPELlama &&
             !USE_FP16_QK_REDUCTION)
                ?
2 + : (8 / NUM_MMA_Q); + const uint32_t max_num_mma_kv_smem = + (max_smem_per_threadblock - + CTA_TILE_Q * HEAD_DIM_QK * sizeof(DTypeQ)) / + ((HEAD_DIM_QK + HEAD_DIM_VO) * 16 * NUM_WARPS_KV * sizeof(DTypeKV)); + + // control NUM_MMA_KV for maximum warp occupancy + DISPATCH_NUM_MMA_KV( + min(max_num_mma_kv_smem, max_num_mma_kv_reg), NUM_MMA_KV, { + using KTraits = + KernelTraits; + if constexpr (KTraits::IsInvalid()) { + // Invalid configuration, skip + std::ostringstream err_msg; + err_msg << "FlashInfer Internal Error: Invalid " + "configuration : NUM_MMA_Q=" + << NUM_MMA_Q << " NUM_MMA_D_QK=" << NUM_MMA_D_QK + << " NUM_MMA_D_VO=" << NUM_MMA_D_VO + << " NUM_MMA_KV=" << NUM_MMA_KV + << " NUM_WARPS_Q=" << NUM_WARPS_Q + << " NUM_WARPS_KV=" << NUM_WARPS_KV + << " please create an issue " + "(https://github.com/flashinfer-ai/flashinfer/" + "issues)" + " and report the issue to the developers."; + FLASHINFER_ERROR(err_msg.str()); + } + else { + constexpr uint32_t num_threads = + (NUM_WARPS_Q * NUM_WARPS_KV) * WARP_SIZE; + auto kernel = + SinglePrefillWithKVCacheKernel; + size_t smem_size = sizeof(typename KTraits::SharedStorage); + FI_GPU_CALL(gpuFuncSetAttribute( + kernel, gpuFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); + int num_blocks_per_sm = 0; + int num_sm = 0; + FI_GPU_CALL(gpuDeviceGetAttribute( + &num_sm, gpuDevAttrMultiProcessorCount, dev_id)); + FI_GPU_CALL(gpuOccupancyMaxActiveBlocksPerMultiprocessor( + &num_blocks_per_sm, kernel, num_threads, smem_size)); + uint32_t max_num_kv_chunks = + (num_blocks_per_sm * num_sm) / + (num_kv_heads * + ceil_div(qo_len * group_size, CTA_TILE_Q)); + uint32_t num_chunks; + if (max_num_kv_chunks > 0) { + uint32_t chunk_size = + max(ceil_div(kv_len, max_num_kv_chunks), 256); + num_chunks = ceil_div(kv_len, chunk_size); + } + else { + num_chunks = 0; + } + + if (num_chunks <= 1 || tmp == nullptr) { + // Enough parallelism, do not split-kv + params.partition_kv = false; + void *args[] = {(void *)¶ms}; + dim3 
nblks(ceil_div(qo_len * group_size, CTA_TILE_Q), 1, num_kv_heads);
                    // Lane dimension must equal the hardware warp width
                    // (WARP_SIZE is gpu_iface::kWarpSize: 64 on AMD
                    // wavefronts, 32 on NVIDIA). KTraits::NUM_THREADS, the
                    // __launch_bounds__ annotation, and the occupancy query
                    // above are all computed from WARP_SIZE, so a hard-coded
                    // 32 would under-populate each block on wave64 hardware.
                    dim3 nthrs(WARP_SIZE, NUM_WARPS_Q, NUM_WARPS_KV);
                    FI_GPU_CALL(gpuLaunchKernel((void *)kernel, nblks, nthrs,
                                                args, smem_size, stream));
                }
                else {
                    // Split the KV dimension across num_chunks blocks to
                    // raise occupancy; partial outputs/LSEs land in tmp and
                    // are merged below.
                    params.partition_kv = true;
                    float *tmp_lse =
                        (float *)(tmp + num_chunks * qo_len * num_qo_heads *
                                            HEAD_DIM_VO);
                    auto o = params.o;
                    auto lse = params.lse;
                    params.o = tmp;
                    params.lse = tmp_lse;
                    void *args[] = {(void *)&params};
                    dim3 nblks(ceil_div(qo_len * group_size, CTA_TILE_Q),
                               num_chunks, num_kv_heads);
                    dim3 nthrs(WARP_SIZE, NUM_WARPS_Q, NUM_WARPS_KV);
                    FI_GPU_CALL(gpuLaunchKernel((void *)kernel, nblks, nthrs,
                                                args, smem_size, stream));
                    if constexpr (AttentionVariant::use_softmax) {
                        FI_GPU_CALL(MergeStates(tmp, tmp_lse, o, lse,
                                                num_chunks, qo_len,
                                                num_qo_heads, HEAD_DIM_VO,
                                                stream));
                    }
                    else {
                        FI_GPU_CALL(AttentionSum(tmp, o, num_chunks, qo_len,
                                                 num_qo_heads, HEAD_DIM_VO,
                                                 stream));
                    }
                }
            }
        })
    });
    return gpuSuccess;
}

template <typename KTraits, typename Params>
__global__
__launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKVCacheKernel(
    const __grid_constant__ Params params)
{
    using DTypeQ = typename Params::DTypeQ;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
    // bf16 tensor-core MMA requires SM80+; reject at compile time on sm75.
    if constexpr (std::is_same_v<DTypeQ, nv_bfloat16>) {
        FLASHINFER_RUNTIME_ASSERT(
            "Prefill kernels do not support bf16 on sm75.");
    }
    else {
#endif
        using DTypeKV = typename Params::DTypeKV;
        using DTypeO = typename Params::DTypeO;
        using IdType = typename Params::IdType;
        using DTypeQKAccum = typename KTraits::DTypeQKAccum;
        using AttentionVariant = typename KTraits::AttentionVariant;
        [[maybe_unused]] constexpr uint32_t NUM_MMA_Q = KTraits::NUM_MMA_Q;
        [[maybe_unused]] constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV;
        [[maybe_unused]] constexpr uint32_t NUM_MMA_D_QK =
            KTraits::NUM_MMA_D_QK;
        [[maybe_unused]] constexpr uint32_t NUM_MMA_D_VO =
            KTraits::NUM_MMA_D_VO;
        [[maybe_unused]]
constexpr uint32_t HEAD_DIM_QK = KTraits::HEAD_DIM_QK; + [[maybe_unused]] constexpr uint32_t HEAD_DIM_VO = KTraits::HEAD_DIM_VO; + [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_Q = + KTraits::UPCAST_STRIDE_Q; + [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_K = + KTraits::UPCAST_STRIDE_K; + [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_V = + KTraits::UPCAST_STRIDE_V; + [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_O = + KTraits::UPCAST_STRIDE_O; + [[maybe_unused]] constexpr uint32_t CTA_TILE_Q = KTraits::CTA_TILE_Q; + [[maybe_unused]] constexpr uint32_t CTA_TILE_KV = KTraits::CTA_TILE_KV; + [[maybe_unused]] constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q; + [[maybe_unused]] constexpr uint32_t NUM_WARPS_KV = + KTraits::NUM_WARPS_KV; + [[maybe_unused]] constexpr SwizzleMode SWIZZLE_MODE_Q = + KTraits::SWIZZLE_MODE_Q; + [[maybe_unused]] constexpr SwizzleMode SWIZZLE_MODE_KV = + KTraits::SWIZZLE_MODE_KV; + [[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_ROW = + KTraits::KV_THR_LAYOUT_ROW; + [[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_COL = + KTraits::KV_THR_LAYOUT_COL; + [[maybe_unused]] constexpr MaskMode MASK_MODE = KTraits::MASK_MODE; + + DTypeQ *q = params.q; + IdType *request_indices = params.request_indices; + IdType *qo_tile_indices = params.qo_tile_indices; + IdType *kv_tile_indices = params.kv_tile_indices; + IdType *q_indptr = params.q_indptr; + IdType *kv_indptr = params.kv_indptr; + DTypeKV *k = params.k; + DTypeKV *v = params.v; + IdType *o_indptr = params.o_indptr; + DTypeO *o = params.o; + float *lse = params.lse; + bool *block_valid_mask = params.block_valid_mask; + const bool partition_kv = params.partition_kv; + const uint32_t q_stride_n = params.q_stride_n; + const uint32_t q_stride_h = params.q_stride_h; + const uint32_t k_stride_n = params.k_stride_n; + const uint32_t k_stride_h = params.k_stride_h; + const uint32_t v_stride_n = params.v_stride_n; + const uint32_t v_stride_h = params.v_stride_h; + const uint_fastdiv 
&group_size = params.group_size; + + static_assert(sizeof(DTypeQ) == 2); + const uint32_t kv_chunk_size = *(params.kv_chunk_size_ptr); + const dim3 &tid = threadIdx; + + auto block = cg::this_thread_block(); + const uint32_t bx = blockIdx.x, lane_idx = tid.x, + warp_idx = get_warp_idx(tid.y, tid.z), + kv_head_idx = blockIdx.z; + if (block_valid_mask && !block_valid_mask[bx]) { + return; + } + const uint32_t num_kv_heads = gridDim.z, + num_qo_heads = group_size * num_kv_heads; + const uint32_t request_idx = request_indices[bx], + qo_tile_idx = qo_tile_indices[bx], + kv_tile_idx = kv_tile_indices[bx]; + extern __shared__ uint8_t smem[]; + auto &smem_storage = + reinterpret_cast(smem); + AttentionVariant variant(params, /*batch_idx=*/request_idx, smem); + const uint32_t qo_len = variant.qo_len, kv_len = variant.kv_len, + window_left = variant.window_left; + const uint32_t kv_len_safe = kv_len > 0 ? kv_len : 1; + const uint32_t max_chunk_size = partition_kv ? kv_chunk_size : kv_len; + const uint32_t chunk_start = + partition_kv ? kv_tile_idx * max_chunk_size : 0; + const uint32_t chunk_end = + partition_kv ? 
min((kv_tile_idx + 1) * max_chunk_size, kv_len) + : kv_len; + const uint32_t chunk_size = chunk_end - chunk_start; + const uint32_t qo_upper_bound = + min(qo_len, ceil_div((qo_tile_idx + 1) * CTA_TILE_Q, group_size)); + + DTypeQKAccum s_frag[NUM_MMA_Q][NUM_MMA_KV][8]; + alignas(16) float o_frag[NUM_MMA_Q][NUM_MMA_D_VO][8]; + DTypeQKAccum m[NUM_MMA_Q][2]; + float d[NUM_MMA_Q][2]; + float rope_freq[NUM_MMA_D_QK / 2][4]; + + if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) + { + const float rope_rcp_scale = params.rope_rcp_scale; + const float rope_rcp_theta = params.rope_rcp_theta; + init_rope_freq(rope_freq, rope_rcp_scale, rope_rcp_theta, + tid.x); + } + init_states(variant, o_frag, m, d); + + const uint32_t qo_packed_idx_base = + (qo_tile_idx * NUM_WARPS_Q + get_warp_idx_q(tid.y)) * + NUM_MMA_Q * 16; + smem_t qo_smem(smem_storage.q_smem); + const uint32_t o_stride_n = num_qo_heads * HEAD_DIM_VO, + o_stride_h = HEAD_DIM_VO; + + DTypeQ *q_ptr_base = q + q_indptr[request_idx] * q_stride_n + + kv_head_idx * group_size * q_stride_h; + + DTypeO *o_ptr_base = + partition_kv + ? 
o + (o_indptr[request_idx] + kv_tile_idx) * o_stride_n + + (kv_head_idx * group_size) * o_stride_h + : o + o_indptr[request_idx] * o_stride_n + + (kv_head_idx * group_size) * o_stride_h; + + uint32_t q_smem_offset_r = qo_smem.get_permuted_offset( + get_warp_idx_q(tid.y) * NUM_MMA_Q * 16 + lane_idx % 16, + lane_idx / 16); + + load_q_global_smem(qo_packed_idx_base, qo_upper_bound, + q_ptr_base, q_stride_n, q_stride_h, + group_size, &qo_smem, tid); + + memory::commit_group(); + + if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) + { + memory::wait_group<0>(); + block.sync(); + IdType *q_rope_offset = nullptr; + + if constexpr (has_maybe_q_rope_offset_v) { + q_rope_offset = params.maybe_q_rope_offset; + } + if (!q_rope_offset) { + q_smem_inplace_apply_rotary( + qo_packed_idx_base, qo_len, kv_len, group_size, &qo_smem, + &q_smem_offset_r, rope_freq, tid); + } + else { + q_smem_inplace_apply_rotary_with_pos( + qo_packed_idx_base, q_rope_offset + q_indptr[request_idx], + &qo_smem, group_size, &q_smem_offset_r, rope_freq, tid); + } + block.sync(); + } + + const uint32_t num_iterations = ceil_div( + (MASK_MODE == MaskMode::kCausal + ? min(chunk_size, + sub_if_greater_or_zero( + kv_len - qo_len + + ((qo_tile_idx + 1) * CTA_TILE_Q) / group_size, + chunk_start)) + : chunk_size), + CTA_TILE_KV); + + const uint32_t window_iteration = + ceil_div(sub_if_greater_or_zero( + kv_len + (qo_tile_idx + 1) * CTA_TILE_Q / group_size, + qo_len + window_left + chunk_start), + CTA_TILE_KV); + + const uint32_t mask_iteration = + (MASK_MODE == MaskMode::kCausal + ? 
min(chunk_size, + sub_if_greater_or_zero( + kv_len + (qo_tile_idx * CTA_TILE_Q) / group_size - + qo_len, + chunk_start)) + : chunk_size) / + CTA_TILE_KV; + + smem_t k_smem(smem_storage.k_smem), + v_smem(smem_storage.v_smem); + + uint32_t k_smem_offset_r = + k_smem.template get_permuted_offset( + get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + + 8 * (lane_idx / 16) + lane_idx % 8, + (lane_idx % 16) / 8), + v_smem_offset_r = + v_smem.template get_permuted_offset( + get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + + lane_idx % 16, + lane_idx / 16), + k_smem_offset_w = + k_smem.template get_permuted_offset( + warp_idx * KV_THR_LAYOUT_ROW + + lane_idx / KV_THR_LAYOUT_COL, + lane_idx % KV_THR_LAYOUT_COL), + v_smem_offset_w = + v_smem.template get_permuted_offset( + warp_idx * KV_THR_LAYOUT_ROW + + lane_idx / KV_THR_LAYOUT_COL, + lane_idx % KV_THR_LAYOUT_COL); + + DTypeKV *k_ptr = + k + + (kv_indptr[request_idx] + chunk_start + + warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL) * + k_stride_n + + kv_head_idx * k_stride_h + + (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(); + DTypeKV *v_ptr = + v + + (kv_indptr[request_idx] + chunk_start + + warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL) * + v_stride_n + + kv_head_idx * v_stride_h + + (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(); + + produce_kv( + k_smem, &k_smem_offset_w, &k_ptr, k_stride_n, 0, chunk_size, tid); + memory::commit_group(); + produce_kv( + v_smem, &v_smem_offset_w, &v_ptr, v_stride_n, 0, chunk_size, tid); + memory::commit_group(); + +#pragma unroll 1 + for (uint32_t iter = 0; iter < num_iterations; ++iter) { + memory::wait_group<1>(); + block.sync(); + + if constexpr (KTraits::POS_ENCODING_MODE == + PosEncodingMode::kRoPELlama) + { + IdType *k_rope_offset = nullptr; + if constexpr (has_maybe_k_rope_offset_v) { + k_rope_offset = params.maybe_k_rope_offset; + } + k_smem_inplace_apply_rotary( + (k_rope_offset == nullptr ? 
0 + : k_rope_offset[request_idx]) + + chunk_start + iter * CTA_TILE_KV, + &k_smem, &k_smem_offset_r, rope_freq, tid); + block.sync(); + } + + // compute attention score + compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, + &k_smem_offset_r, s_frag); + + logits_transform( + params, variant, /*batch_idx=*/request_idx, qo_packed_idx_base, + chunk_start + + (iter * NUM_WARPS_KV + get_warp_idx_kv(tid.z)) * + NUM_MMA_KV * 16, + qo_len, kv_len, group_size, s_frag, tid, kv_head_idx); + + // apply mask + if (MASK_MODE == MaskMode::kCustom || + (iter >= mask_iteration || iter < window_iteration)) + { + logits_mask(params, variant, /*batch_idx=*/request_idx, + qo_packed_idx_base, + chunk_start + + (iter * NUM_WARPS_KV + + get_warp_idx_kv(tid.z)) * + NUM_MMA_KV * 16, + qo_len, kv_len, chunk_end, group_size, + s_frag, tid, kv_head_idx); + } + + // compute m,d states in online softmax + update_mdo_states(variant, s_frag, o_frag, m, d); + + block.sync(); + produce_kv( + k_smem, &k_smem_offset_w, &k_ptr, k_stride_n, + (iter + 1) * CTA_TILE_KV, chunk_size, tid); + memory::commit_group(); + memory::wait_group<1>(); + block.sync(); + + // compute sfm*v + compute_sfm_v(&v_smem, &v_smem_offset_r, s_frag, o_frag, + d); + + block.sync(); + produce_kv( + v_smem, &v_smem_offset_w, &v_ptr, v_stride_n, + (iter + 1) * CTA_TILE_KV, chunk_size, tid); + memory::commit_group(); + } + memory::wait_group<0>(); + block.sync(); + + finalize_m(variant, m); + + // threadblock synchronization + threadblock_sync_mdo_states(o_frag, &smem_storage, m, d, + warp_idx, lane_idx, tid); + + // normalize d + normalize_d(o_frag, m, d); + + const uint32_t num_kv_chunks = + (kv_len_safe + kv_chunk_size - 1) / kv_chunk_size; + + // write back + write_o_reg_gmem( + o_frag, &qo_smem, o_ptr_base, qo_packed_idx_base, qo_len, + /*o_stride_n=*/ + partition_kv ? 
num_kv_chunks * o_stride_n : o_stride_n, + /*o_stride_h=*/o_stride_h, group_size, tid); + + // write lse + if constexpr (AttentionVariant::use_softmax) { + if (lse != nullptr) { + if (get_warp_idx_kv(tid.z) == 0) { +#pragma unroll + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + uint32_t q, r; + group_size.divmod(qo_packed_idx_base + + lane_idx / 4 + j * 8 + + mma_q * 16, + q, r); + const uint32_t qo_head_idx = + kv_head_idx * group_size + r; + const uint32_t qo_idx = q; + if (qo_idx < qo_len) { + if (partition_kv) { + lse[(o_indptr[request_idx] + + qo_idx * num_kv_chunks + kv_tile_idx) * + num_qo_heads + + qo_head_idx] = + gpu_iface::math::ptx_log2(d[mma_q][j]) + + float(m[mma_q][j]); + } + else { + lse[(o_indptr[request_idx] + qo_idx) * + num_qo_heads + + qo_head_idx] = + gpu_iface::math::ptx_log2(d[mma_q][j]) + + float(m[mma_q][j]); + } + } + } + } + } + } + } +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) + } +#endif +} + +template +__device__ __forceinline__ void BatchPrefillWithPagedKVCacheDevice( + const Params params, + typename KTraits::SharedStorage &smem_storage, + const dim3 tid = threadIdx, + const uint32_t bx = blockIdx.x, + const uint32_t kv_head_idx = blockIdx.z, + const uint32_t num_kv_heads = gridDim.z) +{ + using DTypeQ = typename Params::DTypeQ; +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) + if constexpr (std::is_same_v) { + FLASHINFER_RUNTIME_ASSERT( + "Prefill kernels do not support bf16 on sm75."); + } + else { +#endif + using DTypeKV = typename Params::DTypeKV; + using DTypeO = typename Params::DTypeO; + using IdType = typename Params::IdType; + using DTypeQKAccum = typename KTraits::DTypeQKAccum; + using AttentionVariant = typename KTraits::AttentionVariant; + [[maybe_unused]] constexpr uint32_t NUM_MMA_Q = KTraits::NUM_MMA_Q; + [[maybe_unused]] constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; + [[maybe_unused]] constexpr uint32_t NUM_MMA_D_QK = + 
KTraits::NUM_MMA_D_QK; + [[maybe_unused]] constexpr uint32_t NUM_MMA_D_VO = + KTraits::NUM_MMA_D_VO; + [[maybe_unused]] constexpr uint32_t HEAD_DIM_QK = KTraits::HEAD_DIM_QK; + [[maybe_unused]] constexpr uint32_t HEAD_DIM_VO = KTraits::HEAD_DIM_VO; + [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_Q = + KTraits::UPCAST_STRIDE_Q; + [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_K = + KTraits::UPCAST_STRIDE_K; + [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_V = + KTraits::UPCAST_STRIDE_V; + [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_O = + KTraits::UPCAST_STRIDE_O; + [[maybe_unused]] constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q; + [[maybe_unused]] constexpr uint32_t NUM_WARPS_KV = + KTraits::NUM_WARPS_KV; + [[maybe_unused]] constexpr SwizzleMode SWIZZLE_MODE_Q = + KTraits::SWIZZLE_MODE_Q; + [[maybe_unused]] constexpr SwizzleMode SWIZZLE_MODE_KV = + KTraits::SWIZZLE_MODE_KV; + [[maybe_unused]] constexpr uint32_t CTA_TILE_Q = KTraits::CTA_TILE_Q; + [[maybe_unused]] constexpr uint32_t CTA_TILE_KV = KTraits::CTA_TILE_KV; + [[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_ROW = + KTraits::KV_THR_LAYOUT_ROW; + [[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_COL = + KTraits::KV_THR_LAYOUT_COL; + [[maybe_unused]] constexpr MaskMode MASK_MODE = KTraits::MASK_MODE; + + IdType *request_indices = params.request_indices; + IdType *qo_tile_indices = params.qo_tile_indices; + IdType *kv_tile_indices = params.kv_tile_indices; + DTypeQ *q = params.q; + IdType *q_indptr = params.q_indptr; + IdType *o_indptr = params.o_indptr; + DTypeO *o = params.o; + float *lse = params.lse; + bool *block_valid_mask = params.block_valid_mask; + const paged_kv_t &paged_kv = params.paged_kv; + const bool partition_kv = params.partition_kv; + const uint_fastdiv &group_size = params.group_size; + + static_assert(sizeof(DTypeQ) == 2); + auto block = cg::this_thread_block(); + const uint32_t kv_chunk_size = *(params.kv_chunk_size_ptr); + + const uint32_t lane_idx = tid.x, + 
warp_idx = get_warp_idx(tid.y, tid.z); + if (block_valid_mask && !block_valid_mask[bx]) { + return; + } + const uint32_t num_qo_heads = num_kv_heads * group_size; + + const uint32_t request_idx = request_indices[bx], + qo_tile_idx = qo_tile_indices[bx], + kv_tile_idx = kv_tile_indices[bx]; + auto smem = reinterpret_cast(&smem_storage); + AttentionVariant variant(params, /*batch_idx=*/request_idx, smem); + const uint32_t qo_len = variant.qo_len, kv_len = variant.kv_len, + window_left = variant.window_left; + const uint32_t kv_len_safe = kv_len > 0 ? kv_len : 1; + const uint32_t max_chunk_size = partition_kv ? kv_chunk_size : kv_len; + const uint32_t chunk_start = + partition_kv ? kv_tile_idx * max_chunk_size : 0; + const uint32_t chunk_end = + partition_kv ? min((kv_tile_idx + 1) * max_chunk_size, kv_len) + : kv_len; + const uint32_t chunk_size = chunk_end - chunk_start; + const uint32_t qo_upper_bound = + min(qo_len, ceil_div((qo_tile_idx + 1) * CTA_TILE_Q, group_size)); + + DTypeQKAccum s_frag[NUM_MMA_Q][NUM_MMA_KV][8]; + alignas(16) float o_frag[NUM_MMA_Q][NUM_MMA_D_VO][8]; + DTypeQKAccum m[NUM_MMA_Q][2]; + float d[NUM_MMA_Q][2]; + float rope_freq[NUM_MMA_D_QK / 2][4]; + + if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) + { + const float rope_rcp_scale = params.rope_rcp_scale; + const float rope_rcp_theta = params.rope_rcp_theta; + init_rope_freq(rope_freq, rope_rcp_scale, rope_rcp_theta, + tid.x); + } + init_states(variant, o_frag, m, d); + + const uint32_t qo_packed_idx_base = + (qo_tile_idx * NUM_WARPS_Q + get_warp_idx_q(tid.y)) * + NUM_MMA_Q * 16; + const uint32_t q_stride_n = params.q_stride_n, + q_stride_h = params.q_stride_h; + smem_t qo_smem(smem_storage.q_smem); + const uint32_t o_stride_n = num_qo_heads * HEAD_DIM_VO, + o_stride_h = HEAD_DIM_VO; + DTypeQ *q_ptr_base = q + q_indptr[request_idx] * q_stride_n + + (kv_head_idx * group_size) * q_stride_h; + DTypeO *o_ptr_base = + partition_kv + ? 
o + (o_indptr[request_idx] + kv_tile_idx) * o_stride_n + + (kv_head_idx * group_size) * o_stride_h + : o + o_indptr[request_idx] * o_stride_n + + (kv_head_idx * group_size) * o_stride_h; + uint32_t q_smem_offset_r = qo_smem.get_permuted_offset( + get_warp_idx_q(tid.y) * NUM_MMA_Q * 16 + lane_idx % 16, + lane_idx / 16); + + load_q_global_smem(qo_packed_idx_base, qo_upper_bound, + q_ptr_base, q_stride_n, q_stride_h, + group_size, &qo_smem, tid); + + memory::commit_group(); + + if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) + { + memory::wait_group<0>(); + block.sync(); + IdType *q_rope_offset = nullptr; + if constexpr (has_maybe_q_rope_offset_v) { + q_rope_offset = params.maybe_q_rope_offset; + } + if (q_rope_offset == nullptr) { + q_smem_inplace_apply_rotary( + qo_packed_idx_base, qo_len, kv_len, group_size, &qo_smem, + &q_smem_offset_r, rope_freq, tid); + } + else { + q_smem_inplace_apply_rotary_with_pos( + qo_packed_idx_base, q_rope_offset + q_indptr[request_idx], + &qo_smem, group_size, &q_smem_offset_r, rope_freq, tid); + } + block.sync(); + } + + smem_t k_smem(smem_storage.k_smem), + v_smem(smem_storage.v_smem); + size_t thr_local_kv_offset[NUM_MMA_KV * KV_THR_LAYOUT_COL / 2 / + NUM_WARPS_Q]; + + uint32_t k_smem_offset_r = + k_smem.template get_permuted_offset( + get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + + 8 * (lane_idx / 16) + lane_idx % 8, + (lane_idx % 16) / 8), + v_smem_offset_r = + v_smem.template get_permuted_offset( + get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + + lane_idx % 16, + lane_idx / 16), + k_smem_offset_w = + k_smem.template get_permuted_offset( + warp_idx * KV_THR_LAYOUT_ROW + + lane_idx / KV_THR_LAYOUT_COL, + lane_idx % KV_THR_LAYOUT_COL), + v_smem_offset_w = + v_smem.template get_permuted_offset( + warp_idx * KV_THR_LAYOUT_ROW + + lane_idx / KV_THR_LAYOUT_COL, + lane_idx % KV_THR_LAYOUT_COL); + const IdType last_indptr = paged_kv.indptr[paged_kv.batch_size]; + + uint32_t packed_page_iter_base = + 
paged_kv.indptr[request_idx] * paged_kv.page_size + chunk_start; +#pragma unroll + for (uint32_t i = 0; + i < NUM_MMA_KV * (SWIZZLE_MODE_KV == SwizzleMode::k128B ? 4 : 2) / + NUM_WARPS_Q; + ++i) + { + uint32_t page_iter, entry_idx; + paged_kv.page_size.divmod( + packed_page_iter_base + warp_idx * KV_THR_LAYOUT_ROW + + lane_idx / KV_THR_LAYOUT_COL + + KV_THR_LAYOUT_ROW * NUM_WARPS_Q * NUM_WARPS_KV * i, + page_iter, entry_idx); + thr_local_kv_offset[i] = paged_kv.protective_get_kv_offset( + page_iter, kv_head_idx, entry_idx, + (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(), + last_indptr); + } + page_produce_kv(k_smem, &k_smem_offset_w, paged_kv, 0, + thr_local_kv_offset, chunk_size, tid); + memory::commit_group(); + page_produce_kv(v_smem, &v_smem_offset_w, paged_kv, 0, + thr_local_kv_offset, chunk_size, tid); + memory::commit_group(); + + const uint32_t num_iterations = ceil_div( + (MASK_MODE == MaskMode::kCausal + ? min(chunk_size, + sub_if_greater_or_zero( + kv_len - qo_len + + ((qo_tile_idx + 1) * CTA_TILE_Q) / group_size, + chunk_start)) + : chunk_size), + CTA_TILE_KV); + + const uint32_t window_iteration = + ceil_div(sub_if_greater_or_zero( + kv_len + (qo_tile_idx + 1) * CTA_TILE_Q / group_size, + qo_len + window_left + chunk_start), + CTA_TILE_KV); + + const uint32_t mask_iteration = + (MASK_MODE == MaskMode::kCausal + ? min(chunk_size, + sub_if_greater_or_zero( + kv_len + (qo_tile_idx * CTA_TILE_Q) / group_size - + qo_len, + chunk_start)) + : chunk_size) / + CTA_TILE_KV; + +#pragma unroll 1 + for (uint32_t iter = 0; iter < num_iterations; ++iter) { + packed_page_iter_base += CTA_TILE_KV; +#pragma unroll + for (uint32_t i = 0; + i < NUM_MMA_KV * + (SWIZZLE_MODE_KV == SwizzleMode::k128B ? 
4 : 2) / + NUM_WARPS_Q; + ++i) + { + uint32_t page_iter, entry_idx; + paged_kv.page_size.divmod( + packed_page_iter_base + warp_idx * KV_THR_LAYOUT_ROW + + lane_idx / KV_THR_LAYOUT_COL + + KV_THR_LAYOUT_ROW * NUM_WARPS_Q * NUM_WARPS_KV * i, + page_iter, entry_idx); + thr_local_kv_offset[i] = paged_kv.protective_get_kv_offset( + page_iter, kv_head_idx, entry_idx, + (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(), + last_indptr); + } + memory::wait_group<1>(); + block.sync(); + + if constexpr (KTraits::POS_ENCODING_MODE == + PosEncodingMode::kRoPELlama) + { + k_smem_inplace_apply_rotary( + (paged_kv.rope_pos_offset == nullptr + ? 0 + : paged_kv.rope_pos_offset[request_idx]) + + chunk_start + iter * CTA_TILE_KV, + &k_smem, &k_smem_offset_r, rope_freq, tid); + block.sync(); + } + + // compute attention score + compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, + &k_smem_offset_r, s_frag); + + logits_transform( + params, variant, /*batch_idx=*/request_idx, qo_packed_idx_base, + chunk_start + + (iter * NUM_WARPS_KV + get_warp_idx_kv(tid.z)) * + NUM_MMA_KV * 16, + qo_len, kv_len, group_size, s_frag, tid, kv_head_idx); + + // apply mask + if (MASK_MODE == MaskMode::kCustom || + (iter >= mask_iteration || iter < window_iteration)) + { + logits_mask(params, variant, /*batch_idx=*/request_idx, + qo_packed_idx_base, + chunk_start + + (iter * NUM_WARPS_KV + + get_warp_idx_kv(tid.z)) * + NUM_MMA_KV * 16, + qo_len, kv_len, chunk_end, group_size, + s_frag, tid, kv_head_idx); + } + + // compute m,d states in online softmax + update_mdo_states(variant, s_frag, o_frag, m, d); + + block.sync(); + page_produce_kv( + k_smem, &k_smem_offset_w, paged_kv, (iter + 1) * CTA_TILE_KV, + thr_local_kv_offset, chunk_size, tid); + memory::commit_group(); + memory::wait_group<1>(); + block.sync(); + + // compute sfm*v + compute_sfm_v(&v_smem, &v_smem_offset_r, s_frag, o_frag, + d); + + block.sync(); + page_produce_kv( + v_smem, &v_smem_offset_w, paged_kv, (iter + 1) * CTA_TILE_KV, + 
thr_local_kv_offset, chunk_size, tid); + memory::commit_group(); + } + memory::wait_group<0>(); + block.sync(); + + finalize_m(variant, m); + + // threadblock synchronization + threadblock_sync_mdo_states(o_frag, &smem_storage, m, d, + warp_idx, lane_idx, tid); + + // normalize d + normalize_d(o_frag, m, d); + + const uint32_t num_kv_chunks = + (kv_len_safe + kv_chunk_size - 1) / kv_chunk_size; + + // write_back + write_o_reg_gmem( + o_frag, &qo_smem, o_ptr_base, qo_packed_idx_base, qo_len, + /*o_stride_n=*/ + partition_kv ? num_kv_chunks * o_stride_n : o_stride_n, + /*o_stride_h=*/o_stride_h, group_size, tid); + + // write lse + if constexpr (variant.use_softmax) { + if (lse != nullptr) { + if (get_warp_idx_kv(tid.z) == 0) { +#pragma unroll + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + uint32_t q, r; + group_size.divmod(qo_packed_idx_base + + lane_idx / 4 + j * 8 + + mma_q * 16, + q, r); + const uint32_t qo_head_idx = + kv_head_idx * group_size + r; + const uint32_t qo_idx = q; + if (qo_idx < qo_upper_bound) { + if (partition_kv) { + lse[(o_indptr[request_idx] + + qo_idx * num_kv_chunks + kv_tile_idx) * + num_qo_heads + + qo_head_idx] = + gpu_iface::math::ptx_log2(d[mma_q][j]) + + float(m[mma_q][j]); + } + else { + lse[(o_indptr[request_idx] + qo_idx) * + num_qo_heads + + qo_head_idx] = + gpu_iface::math::ptx_log2(d[mma_q][j]) + + float(m[mma_q][j]); + } + } + } + } + } + } + } +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) + } +#endif +} + +template +__global__ +__launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithPagedKVCacheKernel( + const __grid_constant__ Params params) +{ + extern __shared__ uint8_t smem[]; + auto &smem_storage = + reinterpret_cast(smem); + BatchPrefillWithPagedKVCacheDevice(params, smem_storage); +} + +template +gpuError_t +BatchPrefillWithRaggedKVCacheDispatched(Params params, + typename Params::DTypeO *tmp_v, + float *tmp_s, + gpuStream_t stream) +{ + using 
DTypeQ = typename Params::DTypeQ; + using DTypeKV = typename Params::DTypeKV; + using DTypeO = typename Params::DTypeO; + const uint32_t padded_batch_size = params.padded_batch_size; + const uint32_t num_qo_heads = params.num_qo_heads; + const uint32_t num_kv_heads = params.num_kv_heads; + constexpr uint32_t NUM_MMA_Q = get_num_mma_q(CTA_TILE_Q); + constexpr uint32_t NUM_WARPS_Q = get_num_warps_q(CTA_TILE_Q); + constexpr uint32_t NUM_WARPS_KV = get_num_warps_kv(CTA_TILE_Q); + + if (padded_batch_size == 0) { + // No request, skip + // this won't happen in CUDAGraph mode because we fixed the + // padded_batch_size + return gpuSuccess; + } + + dim3 nblks(padded_batch_size, 1, num_kv_heads); + dim3 nthrs(32, NUM_WARPS_Q, NUM_WARPS_KV); + constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM_QK / 16; + constexpr uint32_t NUM_MMA_D_VO = HEAD_DIM_VO / 16; + using DTypeQKAccum = + typename std::conditional, + half, float>::type; + + int dev_id = 0; + FI_GPU_CALL(gpuGetDevice(&dev_id)); + int max_smem_per_sm = 0; + FI_GPU_CALL(gpuDeviceGetAttribute( + &max_smem_per_sm, gpuDevAttrMaxSharedMemoryPerMultiProcessor, dev_id)); + // we expect each sm execute two threadblocks + const int num_ctas_per_sm = + max_smem_per_sm >= 2 * (CTA_TILE_Q * HEAD_DIM_QK * sizeof(DTypeQ) + + (HEAD_DIM_QK + HEAD_DIM_VO) * 16 * + NUM_WARPS_KV * sizeof(DTypeKV)) + ? 2 + : 1; + const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm; + + const uint32_t max_num_mma_kv_reg = + (HEAD_DIM_VO >= 128 && NUM_MMA_Q == 2 && + POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && + !USE_FP16_QK_REDUCTION) + ? 
2 + : (8 / NUM_MMA_Q); + const uint32_t max_num_mma_kv_smem = + (max_smem_per_threadblock - CTA_TILE_Q * HEAD_DIM_QK * sizeof(DTypeQ)) / + ((HEAD_DIM_QK + HEAD_DIM_VO) * 16 * NUM_WARPS_KV * sizeof(DTypeKV)); + + DISPATCH_NUM_MMA_KV( + min(max_num_mma_kv_smem, max_num_mma_kv_reg), NUM_MMA_KV, { + using KTraits = + KernelTraits; + if constexpr (KTraits::IsInvalid()) { + // Invalid configuration, skip + std::ostringstream err_msg; + err_msg + << "FlashInfer Internal Error: Invalid configuration : " + "NUM_MMA_Q=" + << NUM_MMA_Q << " NUM_MMA_D_QK=" << NUM_MMA_D_QK + << " NUM_MMA_D_VO=" << NUM_MMA_D_VO + << " NUM_MMA_KV=" << NUM_MMA_KV + << " NUM_WARPS_Q=" << NUM_WARPS_Q + << " NUM_WARPS_KV=" << NUM_WARPS_KV + << " please create an issue " + "(https://github.com/flashinfer-ai/flashinfer/issues)" + " and report the issue to the developers."; + FLASHINFER_ERROR(err_msg.str()); + } + else { + size_t smem_size = sizeof(typename KTraits::SharedStorage); + auto kernel = + BatchPrefillWithRaggedKVCacheKernel; + FI_GPU_CALL(gpuFuncSetAttribute( + kernel, gpuFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); + if (tmp_v == nullptr) { + // do not partition kv + params.partition_kv = false; + void *args[] = {(void *)¶ms}; + FI_GPU_CALL(gpuLaunchKernel((void *)kernel, nblks, nthrs, + args, smem_size, stream)); + } + else { + // partition kv + params.partition_kv = true; + auto o = params.o; + auto lse = params.lse; + params.o = tmp_v; + params.lse = tmp_s; + void *args[] = {(void *)¶ms}; + FI_GPU_CALL(gpuLaunchKernel((void *)kernel, nblks, nthrs, + args, smem_size, stream)); + if constexpr (AttentionVariant::use_softmax) { + FI_GPU_CALL(VariableLengthMergeStates( + tmp_v, tmp_s, params.merge_indptr, o, lse, + params.max_total_num_rows, params.total_num_rows, + num_qo_heads, HEAD_DIM_VO, stream)); + } + else { + FI_GPU_CALL(VariableLengthAttentionSum( + tmp_v, params.merge_indptr, o, + params.max_total_num_rows, params.total_num_rows, + num_qo_heads, HEAD_DIM_VO, stream)); + } 
+ } + } + }); + return gpuSuccess; +} + +template +gpuError_t +BatchPrefillWithPagedKVCacheDispatched(Params params, + typename Params::DTypeO *tmp_v, + float *tmp_s, + gpuStream_t stream) +{ + using DTypeQ = typename Params::DTypeQ; + using DTypeKV = typename Params::DTypeKV; + using DTypeO = typename Params::DTypeO; + const uint32_t padded_batch_size = params.padded_batch_size; + const uint32_t num_qo_heads = params.num_qo_heads; + const uint32_t num_kv_heads = params.paged_kv.num_heads; + constexpr uint32_t NUM_MMA_Q = get_num_mma_q(CTA_TILE_Q); + constexpr uint32_t NUM_WARPS_Q = get_num_warps_q(CTA_TILE_Q); + constexpr uint32_t NUM_WARPS_KV = get_num_warps_kv(CTA_TILE_Q); + + if (padded_batch_size == 0) { + // No request, skip + // this won't happen in CUDAGraph mode because we fixed the + // padded_batch_size + return gpuSuccess; + } + + dim3 nblks(padded_batch_size, 1, num_kv_heads); + dim3 nthrs(32, NUM_WARPS_Q, NUM_WARPS_KV); + + constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM_QK / 16; + constexpr uint32_t NUM_MMA_D_VO = HEAD_DIM_VO / 16; + using DTypeQKAccum = + typename std::conditional, + half, float>::type; + + int dev_id = 0; + FI_GPU_CALL(gpuGetDevice(&dev_id)); + int max_smem_per_sm = 0; + FI_GPU_CALL(gpuDeviceGetAttribute( + &max_smem_per_sm, gpuDevAttrMaxSharedMemoryPerMultiProcessor, dev_id)); + // we expect each sm execute two threadblocks + const int num_ctas_per_sm = + max_smem_per_sm >= 2 * (CTA_TILE_Q * HEAD_DIM_QK * sizeof(DTypeQ) + + (HEAD_DIM_QK + HEAD_DIM_VO) * 16 * + NUM_WARPS_KV * sizeof(DTypeKV)) + ? 2 + : 1; + const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm; + + const uint32_t max_num_mma_kv_reg = + (HEAD_DIM_VO >= 128 && NUM_MMA_Q == 2 && + POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && + !USE_FP16_QK_REDUCTION) + ? 
2 + : (8 / NUM_MMA_Q); + const uint32_t max_num_mma_kv_smem = + (max_smem_per_threadblock - CTA_TILE_Q * HEAD_DIM_QK * sizeof(DTypeQ)) / + ((HEAD_DIM_QK + HEAD_DIM_VO) * 16 * NUM_WARPS_KV * sizeof(DTypeKV)); + + DISPATCH_NUM_MMA_KV( + min(max_num_mma_kv_smem, max_num_mma_kv_reg), NUM_MMA_KV, { + using KTraits = + KernelTraits; + if constexpr (KTraits::IsInvalid()) { + // Invalid configuration, skip + std::ostringstream err_msg; + err_msg + << "FlashInfer Internal Error: Invalid configuration : " + "NUM_MMA_Q=" + << NUM_MMA_Q << " NUM_MMA_D_QK=" << NUM_MMA_D_QK + << " NUM_MMA_D_VO=" << NUM_MMA_D_VO + << " NUM_MMA_KV=" << NUM_MMA_KV + << " NUM_WARPS_Q=" << NUM_WARPS_Q + << " NUM_WARPS_KV=" << NUM_WARPS_KV + << " please create an issue " + "(https://github.com/flashinfer-ai/flashinfer/issues)" + " and report the issue to the developers."; + FLASHINFER_ERROR(err_msg.str()); + } + else { + size_t smem_size = sizeof(typename KTraits::SharedStorage); + auto kernel = + BatchPrefillWithPagedKVCacheKernel; + FI_GPU_CALL(gpuFuncSetAttribute( + kernel, gpuFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); + if (tmp_v == nullptr) { + // do not partition kv + params.partition_kv = false; + void *args[] = {(void *)¶ms}; + FI_GPU_CALL(gpuLaunchKernel((void *)kernel, nblks, nthrs, + args, smem_size, stream)); + } + else { + params.partition_kv = true; + auto o = params.o; + auto lse = params.lse; + params.o = tmp_v; + params.lse = tmp_s; + void *args[] = {(void *)¶ms}; + FI_GPU_CALL(gpuLaunchKernel((void *)kernel, nblks, nthrs, + args, smem_size, stream)); + if constexpr (AttentionVariant::use_softmax) { + FI_GPU_CALL(VariableLengthMergeStates( + tmp_v, tmp_s, params.merge_indptr, o, lse, + params.max_total_num_rows, params.total_num_rows, + num_qo_heads, HEAD_DIM_VO, stream)); + } + else { + FI_GPU_CALL(VariableLengthAttentionSum( + tmp_v, params.merge_indptr, o, + params.max_total_num_rows, params.total_num_rows, + num_qo_heads, HEAD_DIM_VO, stream)); + } + } + } + }); + 
return gpuSuccess; +} + +} // namespace flashinfer + +#endif // FLASHINFER_PREFILL_CUH_ From 47120392bc021cc7e66e49b3f7ff7f0249983976 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Sat, 2 Aug 2025 12:38:43 -0400 Subject: [PATCH 002/109] Add new utils.cuh --- .../flashinfer/attention/generic/utils.cuh | 152 ++++++++++++++++++ 1 file changed, 152 insertions(+) create mode 100644 libflashinfer/include/flashinfer/attention/generic/utils.cuh diff --git a/libflashinfer/include/flashinfer/attention/generic/utils.cuh b/libflashinfer/include/flashinfer/attention/generic/utils.cuh new file mode 100644 index 0000000000..acbd374141 --- /dev/null +++ b/libflashinfer/include/flashinfer/attention/generic/utils.cuh @@ -0,0 +1,152 @@ +// SPDX - FileCopyrightText : 2023-2035 FlashInfer team. +// SPDX - FileCopyrightText : 2025 Advanced Micro Devices, Inc. +// +// SPDX - License - Identifier : Apache 2.0 + +#pragma once + +#include "gpu_iface/gpu_runtime_compat.hpp" +#include +#include +#include +#include + +#define STR_HELPER(x) #x +#define STR(x) STR_HELPER(x) + +// macro to turn off fp16 qk reduction to reduce binary +#ifndef FLASHINFER_ALWAYS_DISUSE_FP16_QK_REDUCTION +#define FLASHINFER_ALWAYS_DISUSE_FP16_QK_REDUCTION 0 +#endif + +namespace flashinfer +{ + +template +__forceinline__ __device__ __host__ T1 ceil_div(const T1 x, const T2 y) +{ + return (x + y - 1) / y; +} + +#if defined(PLATFORM_CUDA_DEVICE) +inline std::pair GetCudaComputeCapability() +{ + int device_id = 0; + cudaGetDevice(&device_id); + int major = 0, minor = 0; + cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, + device_id); + cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, + device_id); + return std::make_pair(major, minor); +} +#endif + +template +inline void +DebugPrintCUDAArray(T *device_ptr, size_t size, std::string prefix = "") +{ + std::vector host_array(size); + std::cout << prefix; + gpuMemcpy(host_array.data(), device_ptr, size * sizeof(T), + 
gpuMemcpyDeviceToHost); + for (size_t i = 0; i < size; ++i) { + std::cout << host_array[i] << " "; + } + std::cout << std::endl; +} + +inline uint32_t FA2DetermineCtaTileQ(int64_t avg_packed_qo_len, + uint32_t head_dim) +{ +#if defined(PLATFORM_CUDA_DEVICE) + if (avg_packed_qo_len > 64 && head_dim < 256) { + return 128; + } + else { + auto compute_capacity = GetCudaComputeCapability(); + if (compute_capacity.first >= 8) { + // Ampere or newer + if (avg_packed_qo_len > 16) { + // avg_packed_qo_len <= 64 + return 64; + } + else { + // avg_packed_qo_len <= 16 + return 16; + } + } + else { + // NOTE(Zihao): not enough shared memory on Turing for 1x4 warp + // layout + return 64; + } + } +#elif defined(PLATFORM_HIP_DEVICE) + // Simplified version for HIP + if (avg_packed_qo_len > 64 && head_dim < 256) { + return 128; + } + else { + return avg_packed_qo_len <= 16 ? 16 : 64; + } +#endif +} + +/*! + * \brief Return x - y if x > y, otherwise return 0. + */ +__device__ __forceinline__ uint32_t sub_if_greater_or_zero(uint32_t x, + uint32_t y) +{ + return (x > y) ? 
x - y : 0U; +} + +__device__ __forceinline__ void swap(uint32_t &a, uint32_t &b) +{ + uint32_t tmp = a; + a = b; + b = tmp; +} + +__device__ __forceinline__ uint32_t dim2_offset(const uint32_t &dim_a, + const uint32_t &idx_b, + const uint32_t &idx_a) +{ + return idx_b * dim_a + idx_a; +} + +__device__ __forceinline__ uint32_t dim3_offset(const uint32_t &dim_b, + const uint32_t &dim_a, + const uint32_t &idx_c, + const uint32_t &idx_b, + const uint32_t &idx_a) +{ + return (idx_c * dim_b + idx_b) * dim_a + idx_a; +} + +__device__ __forceinline__ uint32_t dim4_offset(const uint32_t &dim_c, + const uint32_t &dim_b, + const uint32_t &dim_a, + const uint32_t &idx_d, + const uint32_t &idx_c, + const uint32_t &idx_b, + const uint32_t &idx_a) +{ + return ((idx_d * dim_c + idx_c) * dim_b + idx_b) * dim_a + idx_a; +} + +#define DEFINE_HAS_MEMBER(member) \ + template \ + struct has_##member : std::false_type \ + { \ + }; \ + template \ + struct has_##member().member)>> \ + : std::true_type \ + { \ + }; \ + template \ + inline constexpr bool has_##member##_v = has_##member::value; + +} // namespace flashinfer From 15dcca9b7310871edac97c95533764abd0be9d14 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Sat, 2 Aug 2025 13:26:45 -0400 Subject: [PATCH 003/109] Add dispatch.cuh --- .../flashinfer/attention/generic/dispatch.cuh | 256 ++++++++++++++++++ 1 file changed, 256 insertions(+) create mode 100644 libflashinfer/include/flashinfer/attention/generic/dispatch.cuh diff --git a/libflashinfer/include/flashinfer/attention/generic/dispatch.cuh b/libflashinfer/include/flashinfer/attention/generic/dispatch.cuh new file mode 100644 index 0000000000..cabadd58f8 --- /dev/null +++ b/libflashinfer/include/flashinfer/attention/generic/dispatch.cuh @@ -0,0 +1,256 @@ +// SPDX - FileCopyrightText : 2023-2035 FlashInfer team. +// SPDX - FileCopyrightText : 2025 Advanced Micro Devices, Inc. 
// SPDX-License-Identifier: Apache-2.0
\ + if (group_size == 1) { \ + constexpr size_t GROUP_SIZE = 1; \ + __VA_ARGS__ \ + } \ + else if (group_size == 2) { \ + constexpr size_t GROUP_SIZE = 2; \ + __VA_ARGS__ \ + } \ + else if (group_size == 3) { \ + constexpr size_t GROUP_SIZE = 3; \ + __VA_ARGS__ \ + } \ + else if (group_size == 4) { \ + constexpr size_t GROUP_SIZE = 4; \ + __VA_ARGS__ \ + } \ + else if (group_size == 8) { \ + constexpr size_t GROUP_SIZE = 8; \ + __VA_ARGS__ \ + } \ + else { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported group_size: " << group_size; \ + FLASHINFER_ERROR(err_msg.str()); \ + } + +#define DISPATCH_MASK_MODE(mask_mode, MASK_MODE, ...) \ + switch (mask_mode) { \ + case MaskMode::kNone: \ + { \ + constexpr MaskMode MASK_MODE = MaskMode::kNone; \ + __VA_ARGS__ \ + break; \ + } \ + case MaskMode::kCausal: \ + { \ + constexpr MaskMode MASK_MODE = MaskMode::kCausal; \ + __VA_ARGS__ \ + break; \ + } \ + case MaskMode::kCustom: \ + { \ + constexpr MaskMode MASK_MODE = MaskMode::kCustom; \ + __VA_ARGS__ \ + break; \ + } \ + default: \ + { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported mask_mode: " << int(mask_mode); \ + FLASHINFER_ERROR(err_msg.str()); \ + } \ + } + +// convert head_dim to compile-time constant +#define DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...) \ + switch (head_dim) { \ + case 64: \ + { \ + constexpr size_t HEAD_DIM = 64; \ + __VA_ARGS__ \ + break; \ + } \ + case 128: \ + { \ + constexpr size_t HEAD_DIM = 128; \ + __VA_ARGS__ \ + break; \ + } \ + case 256: \ + { \ + constexpr size_t HEAD_DIM = 256; \ + __VA_ARGS__ \ + break; \ + } \ + case 512: \ + { \ + constexpr size_t HEAD_DIM = 512; \ + __VA_ARGS__ \ + break; \ + } \ + default: \ + { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported head_dim: " << head_dim; \ + FLASHINFER_ERROR(err_msg.str()); \ + } \ + } + +#define DISPATCH_POS_ENCODING_MODE(pos_encoding_mode, POS_ENCODING_MODE, ...) 
\ + switch (pos_encoding_mode) { \ + case PosEncodingMode::kNone: \ + { \ + constexpr PosEncodingMode POS_ENCODING_MODE = PosEncodingMode::kNone; \ + __VA_ARGS__ \ + break; \ + } \ + case PosEncodingMode::kRoPELlama: \ + { \ + constexpr PosEncodingMode POS_ENCODING_MODE = \ + PosEncodingMode::kRoPELlama; \ + __VA_ARGS__ \ + break; \ + } \ + case PosEncodingMode::kALiBi: \ + { \ + constexpr PosEncodingMode POS_ENCODING_MODE = PosEncodingMode::kALiBi; \ + __VA_ARGS__ \ + break; \ + } \ + default: \ + { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported pos_encoding_mode: " \ + << int(pos_encoding_mode); \ + FLASHINFER_ERROR(err_msg.str()); \ + } \ + } + +#define DISPATCH_ALIGNED_VEC_SIZE(aligned_vec_size, ALIGNED_VEC_SIZE, ...) \ + switch (aligned_vec_size) { \ + case 16: \ + { \ + constexpr size_t ALIGNED_VEC_SIZE = 16; \ + __VA_ARGS__ \ + break; \ + } \ + case 8: \ + { \ + constexpr size_t ALIGNED_VEC_SIZE = 8; \ + __VA_ARGS__ \ + break; \ + } \ + case 4: \ + { \ + constexpr size_t ALIGNED_VEC_SIZE = 4; \ + __VA_ARGS__ \ + break; \ + } \ + case 2: \ + { \ + constexpr size_t ALIGNED_VEC_SIZE = 2; \ + __VA_ARGS__ \ + break; \ + } \ + case 1: \ + { \ + constexpr size_t ALIGNED_VEC_SIZE = 1; \ + __VA_ARGS__ \ + break; \ + } \ + default: \ + { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported aligned_vec_size: " << aligned_vec_size; \ + FLASHINFER_ERROR(err_msg.str()); \ + } \ + } + +#define DISPATCH_COMPUTE_CAP_DECODE_NUM_STAGES_SMEM(compute_capacity, \ + NUM_STAGES_SMEM, ...) 
\ + if (compute_capacity.first >= 8) { \ + constexpr uint32_t NUM_STAGES_SMEM = 2; \ + __VA_ARGS__ \ + } \ + else { \ + constexpr uint32_t NUM_STAGES_SMEM = 1; \ + __VA_ARGS__ \ + } From 632c1b665a29c5e7138c4a06f36f93e125b7b799 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Sat, 2 Aug 2025 13:48:46 -0400 Subject: [PATCH 004/109] Unit test cases for data access patterns --- .../tests/hip/test_load_q_global_smem.cpp | 193 ++++++++++++ libflashinfer/tests/hip/test_produce_kv.cpp | 282 ++++++++++++++++++ .../tests/hip/test_q_smem_read_pattern.cpp | 180 +++++++++++ 3 files changed, 655 insertions(+) create mode 100644 libflashinfer/tests/hip/test_load_q_global_smem.cpp create mode 100644 libflashinfer/tests/hip/test_produce_kv.cpp create mode 100644 libflashinfer/tests/hip/test_q_smem_read_pattern.cpp diff --git a/libflashinfer/tests/hip/test_load_q_global_smem.cpp b/libflashinfer/tests/hip/test_load_q_global_smem.cpp new file mode 100644 index 0000000000..c9d8c3840a --- /dev/null +++ b/libflashinfer/tests/hip/test_load_q_global_smem.cpp @@ -0,0 +1,193 @@ +#include +#include +#include +#include +#include + +// Constants for MI300 +constexpr uint32_t WARP_STEP_SIZE = 16; // 16 threads per warp row +constexpr uint32_t QUERY_ELEMS_PER_THREAD = + 4; // Each thread loads 4 fp16 elements +constexpr uint32_t WARP_THREAD_ROWS = 4; // 4 rows of threads in a warp + +// Simplified linear shared memory operations (CPU implementation) +template +uint32_t get_permuted_offset_linear(uint32_t row, uint32_t col) +{ + return row * stride + col; +} + +template +uint32_t advance_offset_by_column_linear(uint32_t offset, uint32_t step_idx) +{ + return offset + step_size; +} + +template +uint32_t advance_offset_by_row_linear(uint32_t offset) +{ + return offset + step_size * row_stride; +} + +// CPU-based offset pattern verification with configurable NUM_MMA_Q +template +void SimulateOffsetPattern(std::vector &thread_ids_at_offsets) +{ + // Constants derived from HEAD_DIM + constexpr 
uint32_t UPCAST_STRIDE_Q = HEAD_DIM / QUERY_ELEMS_PER_THREAD; + constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM / 16; + constexpr uint32_t COLUMN_RESET_OFFSET = + (NUM_MMA_D_QK / 4) * WARP_STEP_SIZE; + constexpr uint32_t grid_width = + (HEAD_DIM / QUERY_ELEMS_PER_THREAD); // 16 for 64, 32 for 128 + constexpr uint32_t grid_height = + 16 * NUM_MMA_Q; // 16 for NUM_MMA_Q=1, 32 for NUM_MMA_Q=2 + + // Initialize with -1 (unwritten) + thread_ids_at_offsets.assign(grid_height * grid_width, -1); + + // Simulate each thread + for (uint32_t tid = 0; tid < 64; tid++) { + uint32_t row = tid / WARP_STEP_SIZE; // 0-3 for 64 threads + uint32_t col = tid % WARP_STEP_SIZE; // 0-15 + + // Calculate initial offset using linear addressing + uint32_t q_smem_offset_w = + get_permuted_offset_linear(row, col); + + // Main loop structure from load_q_global_smem + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { + for (uint32_t j = 0; j < 4; ++j) { + // Calculate sequence index + const uint32_t seq_idx = row + mma_q * 16 + j; + + for (uint32_t mma_do = 0; mma_do < NUM_MMA_D_QK / 4; ++mma_do) { + // Record which thread wrote to this offset + if (q_smem_offset_w < grid_height * grid_width) + { // Safety check + thread_ids_at_offsets[q_smem_offset_w] = tid; + } + else { + printf("ERROR by tid: %d, offset: %d\n", tid, + q_smem_offset_w); + } + + // Advance to next column within same row + q_smem_offset_w = + advance_offset_by_column_linear( + q_smem_offset_w, mma_do); + } + + // Advance to next sequence (row) with adjustment back to first + // column + q_smem_offset_w = advance_offset_by_row_linear( + q_smem_offset_w) - + COLUMN_RESET_OFFSET; + } + } + } +} + +// Helper function to run the test with configurable NUM_MMA_Q +template void RunOffsetTest() +{ + constexpr uint32_t grid_width = + (HEAD_DIM / QUERY_ELEMS_PER_THREAD); // 16 for 64, 32 for 128 + constexpr uint32_t grid_height = + 16 * NUM_MMA_Q; // 16 for NUM_MMA_Q=1, 32 for NUM_MMA_Q=2 + + printf("\n=== Testing offset calculations 
with HEAD_DIM = %u, NUM_MMA_Q = " + "%u ===\n", + HEAD_DIM, NUM_MMA_Q); + + // Host array to store thread IDs at each offset + std::vector thread_ids(grid_height * grid_width, -1); + + // Run CPU simulation of offset pattern + SimulateOffsetPattern(thread_ids); + + // Print the grid of thread IDs (potentially truncated for readability) + printf("Thread IDs writing to each offset (%dx%d grid):\n", grid_height, + grid_width); + + // Column headers + printf(" "); + for (int c = 0; c < grid_width; c++) { + printf("%3d ", c); + if (c == 15 && grid_width > 16) + printf("| "); // Divider between first and second half + } + printf("\n +"); + for (int c = 0; c < grid_width; c++) { + printf("----"); + if (c == 15 && grid_width > 16) + printf("+"); // Divider between first and second half + } + printf("\n"); + + // Print quadrants with clear separation + for (int r = 0; r < grid_height; r++) { + printf("%2d | ", r); + for (int c = 0; c < grid_width; c++) { + int thread_id = thread_ids[r * grid_width + c]; + if (thread_id >= 0) { + printf("%3d ", thread_id); + } + else { + printf(" . 
"); // Dot for unwritten positions + } + if (c == 15 && grid_width > 16) + printf("| "); // Divider between first and second half + } + printf("\n"); + + // Add horizontal divider between first and second block of sequences + if (r == 15 && NUM_MMA_Q > 1) { + printf(" +"); + for (int c = 0; c < grid_width; c++) { + printf("----"); + if (c == 15 && grid_width > 16) + printf("+"); // Intersection divider + } + printf("\n"); + } + } + + // Check for unwritten positions + int unwritten = 0; + for (int i = 0; i < grid_height * grid_width; i++) { + if (thread_ids[i] == -1) { + unwritten++; + } + } + + // Print statistics + printf("\nStatistics:\n"); + printf("- Positions written: %d/%d (%.1f%%)\n", + grid_height * grid_width - unwritten, grid_height * grid_width, + 100.0f * (grid_height * grid_width - unwritten) / + (grid_height * grid_width)); + printf("- Unwritten positions: %d/%d (%.1f%%)\n", unwritten, + grid_height * grid_width, + 100.0f * unwritten / (grid_height * grid_width)); + + // Validate full coverage + EXPECT_EQ(unwritten, 0) << "Not all positions were written"; +} + +// Original tests with NUM_MMA_Q = 1 +TEST(MI300OffsetTest, HeadDim64_NumMmaQ1) { RunOffsetTest<64, 1>(); } + +TEST(MI300OffsetTest, HeadDim128_NumMmaQ1) { RunOffsetTest<128, 1>(); } + +// New tests with NUM_MMA_Q = 2 +TEST(MI300OffsetTest, HeadDim64_NumMmaQ2) { RunOffsetTest<64, 2>(); } + +TEST(MI300OffsetTest, HeadDim128_NumMmaQ2) { RunOffsetTest<128, 2>(); } + +int main(int argc, char **argv) +{ + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/libflashinfer/tests/hip/test_produce_kv.cpp b/libflashinfer/tests/hip/test_produce_kv.cpp new file mode 100644 index 0000000000..0494c6675f --- /dev/null +++ b/libflashinfer/tests/hip/test_produce_kv.cpp @@ -0,0 +1,282 @@ +#include +#include +#include +#include +#include + +// Constants +constexpr uint32_t WARP_SIZE_NV = 32; +constexpr uint32_t WARP_SIZE_AMD = 64; + +// SwizzleMode enum to match the original code 
+enum class SwizzleMode +{ + k64B = 0U, // Original NVIDIA mode (32 threads, 8 rows x 4 columns) + k128B = 1U, // Original pseudo-128B mode (32 threads, 4 rows x 8 columns) + kLinear = 2U // New AMD-specific mode (64 threads, 4 rows x 16 columns) +}; + +// Simplified linear shared memory operations (CPU implementation) +template +uint32_t get_permuted_offset_linear(uint32_t row, uint32_t col) +{ + return row * stride + col; +} + +template +uint32_t advance_offset_by_column_linear(uint32_t offset, uint32_t step_idx) +{ + return offset + step_size; +} + +template +uint32_t advance_offset_by_row_linear(uint32_t offset) +{ + return offset + step_size * row_stride; +} + +// CPU-based simulation of produce_kv for different SwizzleMode values +template +void SimulateProduceKV(std::vector &thread_ids_at_offsets, + uint32_t warp_size = WARP_SIZE_AMD) +{ + // Constants derived from HEAD_DIM and SwizzleMode + constexpr uint32_t ELEMS_PER_THREAD = 4; + constexpr uint32_t UPCAST_STRIDE = HEAD_DIM / ELEMS_PER_THREAD; + constexpr uint32_t NUM_MMA_D = HEAD_DIM / 16; + constexpr uint32_t grid_width = + HEAD_DIM / ELEMS_PER_THREAD; // 16 for 64, 32 for 128 + constexpr uint32_t grid_height = + 16 * NUM_MMA_KV; // 16 for NUM_MMA_KV=1, 32 for NUM_MMA_KV=2 + + // Thread layout constants based on SwizzleMode + constexpr uint32_t KV_THR_LAYOUT_ROW = + SWIZZLE_MODE == SwizzleMode::k128B ? 4 + : SWIZZLE_MODE == SwizzleMode::k64B ? 8 + : 4; // 4 for kLinear (AMD) + + constexpr uint32_t KV_THR_LAYOUT_COL = + SWIZZLE_MODE == SwizzleMode::k128B ? 8 + : SWIZZLE_MODE == SwizzleMode::k64B ? 
4 + : 16; // 16 for kLinear (AMD) + + constexpr uint32_t NUM_WARPS = 1; + constexpr uint32_t NUM_WARPS_Q = 1; + + // Initialize with -1 (unwritten) + thread_ids_at_offsets.assign(grid_height * grid_width, -1); + + // Simulate each thread's write pattern + for (uint32_t tid = 0; tid < warp_size; tid++) { + uint32_t warp_idx = tid / warp_size; // Always 0 for single warp + uint32_t lane_idx = tid; + + // Calculate the initial shared memory offset and global memory index + uint32_t kv_smem_offset_w = get_permuted_offset_linear( + warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL, + lane_idx % KV_THR_LAYOUT_COL); + + if constexpr (SWIZZLE_MODE == SwizzleMode::k128B) { + // k128B mode (original pseudo-128B mode) + uint32_t kv_idx = warp_idx * 4 + lane_idx / 8; + + static_assert(NUM_MMA_KV * 4 % NUM_WARPS_Q == 0); + for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) { + for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(uint16_t)); + ++j) + { + // Record which thread writes to this offset + if (kv_smem_offset_w < grid_height * grid_width && + kv_idx < grid_height) + { + thread_ids_at_offsets[kv_smem_offset_w] = tid; + } + + // Advance to next column within same row + kv_smem_offset_w = + advance_offset_by_column_linear<8>(kv_smem_offset_w, j); + } + + kv_idx += NUM_WARPS * 4; + kv_smem_offset_w = + advance_offset_by_row_linear( + kv_smem_offset_w) - + sizeof(uint16_t) * NUM_MMA_D; + } + } + else if constexpr (SWIZZLE_MODE == SwizzleMode::k64B) { + // k64B mode (original NVIDIA mode) + uint32_t kv_idx = warp_idx * 8 + lane_idx / 4; + + static_assert(NUM_MMA_KV * 2 % NUM_WARPS_Q == 0); + for (uint32_t i = 0; i < NUM_MMA_KV * 2 / NUM_WARPS_Q; ++i) { + // Record which thread writes to this offset + if (kv_smem_offset_w < grid_height * grid_width && + kv_idx < grid_height) + { + thread_ids_at_offsets[kv_smem_offset_w] = tid; + } + + kv_smem_offset_w = + advance_offset_by_row_linear( + kv_smem_offset_w); + kv_idx += NUM_WARPS * 8; + } + } + else if constexpr 
(SWIZZLE_MODE == SwizzleMode::kLinear) { + // kLinear mode (AMD-specific, using all 64 threads) + uint32_t kv_idx = warp_idx * 4 + lane_idx / 16; + + // For AMD's 64-thread warp, we need to process 4 rows with 16 + // threads per row + for (uint32_t i = 0; i < NUM_MMA_KV; ++i) { + for (uint32_t j = 0; j < NUM_MMA_D; ++j) { + // Record which thread writes to this offset + if (kv_smem_offset_w < grid_height * grid_width && + kv_idx < grid_height) + { + thread_ids_at_offsets[kv_smem_offset_w] = tid; + } + + // Advance to next column within same row + kv_smem_offset_w = + advance_offset_by_column_linear( + kv_smem_offset_w, j); + } + + kv_idx += 4; // Advance by 4 rows + kv_smem_offset_w = + advance_offset_by_row_linear<4, UPCAST_STRIDE>( + kv_smem_offset_w) - + NUM_MMA_D * ELEMS_PER_THREAD; + } + } + } +} + +// Helper function to run the test for different SwizzleModes +template +void RunProduceKVTest(uint32_t warp_size = WARP_SIZE_AMD) +{ + constexpr uint32_t grid_width = HEAD_DIM / 4; // 16 for 64, 32 for 128 + constexpr uint32_t grid_height = + 16 * NUM_MMA_KV; // 16 for NUM_MMA_KV=1, 32 for NUM_MMA_KV=2 + + std::string swizzle_mode_str; + switch (SWIZZLE_MODE) { + case SwizzleMode::k64B: + swizzle_mode_str = "k64B (NVIDIA)"; + break; + case SwizzleMode::k128B: + swizzle_mode_str = "k128B (NVIDIA pseudo-128B)"; + break; + case SwizzleMode::kLinear: + swizzle_mode_str = "kLinear (AMD)"; + break; + } + + printf("\n=== Testing produce_kv with HEAD_DIM = %u, NUM_MMA_KV = %u, " + "SwizzleMode = %s ===\n", + HEAD_DIM, NUM_MMA_KV, swizzle_mode_str.c_str()); + + // Host array to store thread IDs at each offset + std::vector thread_ids(grid_height * grid_width, -1); + + // Run CPU simulation of produce_kv + SimulateProduceKV(thread_ids, + warp_size); + + // Print the grid of thread IDs + printf("Thread IDs writing to each offset (%dx%d grid):\n", grid_height, + grid_width); + + // Column headers + printf(" "); + for (int c = 0; c < std::min(32, (int)grid_width); c++) { + 
printf("%3d ", c); + if (c == 15 && grid_width > 16) + printf("| "); + } + printf("\n +"); + for (int c = 0; c < std::min(32, (int)grid_width); c++) { + printf("----"); + if (c == 15 && grid_width > 16) + printf("+"); + } + printf("\n"); + + // Print grid with clear separation + for (int r = 0; r < grid_height; r++) { + printf("%2d | ", r); + for (int c = 0; c < std::min(32, (int)grid_width); c++) { + int thread_id = thread_ids[r * grid_width + c]; + if (thread_id >= 0) { + printf("%3d ", thread_id); + } + else { + printf(" . "); + } + if (c == 15 && grid_width > 16) + printf("| "); + } + printf("\n"); + + // Add horizontal divider between blocks + if (r == 15 && NUM_MMA_KV > 1) { + printf(" +"); + for (int c = 0; c < std::min(32, (int)grid_width); c++) { + printf("----"); + if (c == 15 && grid_width > 16) + printf("+"); + } + printf("\n"); + } + } + + // Check for unwritten positions + int unwritten = 0; + for (int i = 0; i < grid_height * grid_width; i++) { + if (thread_ids[i] == -1) { + unwritten++; + } + } + + // Print statistics + printf("\nStatistics:\n"); + printf("- Positions written: %d/%d (%.1f%%)\n", + grid_height * grid_width - unwritten, grid_height * grid_width, + 100.0f * (grid_height * grid_width - unwritten) / + (grid_height * grid_width)); + printf("- Unwritten positions: %d/%d (%.1f%%)\n", unwritten, + grid_height * grid_width, + 100.0f * unwritten / (grid_height * grid_width)); +} + +// Tests for different SwizzleModes +// TEST(KVCacheWritePatternTest, HeadDim64_NVIDIA_k64B) { +// RunProduceKVTest<64, 1, SwizzleMode::k64B>(WARP_SIZE_NV); +// } + +// TEST(KVCacheWritePatternTest, HeadDim64_NVIDIA_k128B) { +// RunProduceKVTest<64, 1, SwizzleMode::k128B>(WARP_SIZE_NV); +// } + +TEST(KVCacheWritePatternTest, HeadDim64_AMD_kLinear) +{ + RunProduceKVTest<64, 1, SwizzleMode::kLinear>(WARP_SIZE_AMD); +} + +TEST(KVCacheWritePatternTest, HeadDim128_AMD_kLinear) +{ + RunProduceKVTest<128, 1, SwizzleMode::kLinear>(WARP_SIZE_AMD); +} + +int main(int argc, 
char **argv) +{ + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/libflashinfer/tests/hip/test_q_smem_read_pattern.cpp b/libflashinfer/tests/hip/test_q_smem_read_pattern.cpp new file mode 100644 index 0000000000..12cddca0c0 --- /dev/null +++ b/libflashinfer/tests/hip/test_q_smem_read_pattern.cpp @@ -0,0 +1,180 @@ +#include +#include +#include +#include +#include + +// Constants for MI300 +constexpr uint32_t WARP_STEP_SIZE = 16; // 16 threads per warp row +constexpr uint32_t QUERY_ELEMS_PER_THREAD = 4; +constexpr uint32_t WARP_THREAD_ROWS = 4; // 4 rows of threads in a warp + +// Simplified linear shared memory operations (CPU implementation) +template +uint32_t get_permuted_offset_linear(uint32_t row, uint32_t col) +{ + return row * stride + col; +} + +template +uint32_t advance_offset_by_column_linear(uint32_t offset, uint32_t step_idx) +{ + return offset + step_size; +} + +template +uint32_t advance_offset_by_row_linear(uint32_t offset) +{ + return offset + step_size * row_stride; +} + +// CPU-based simulation of the read pattern in compute_qk +template +void SimulateReadPattern(std::vector &thread_ids_reading_offsets) +{ + // Constants derived from HEAD_DIM + constexpr uint32_t UPCAST_STRIDE_Q = HEAD_DIM / QUERY_ELEMS_PER_THREAD; + constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM / 16; + constexpr uint32_t grid_width = + (HEAD_DIM / QUERY_ELEMS_PER_THREAD); // 16 for 64, 32 for 128 + constexpr uint32_t grid_height = + 16 * NUM_MMA_Q; // 16 for NUM_MMA_Q=1, 32 for NUM_MMA_Q=2 + + // Initialize with -1 (unread) + thread_ids_reading_offsets.assign(grid_height * grid_width, -1); + + // Simulate each thread's read pattern + for (uint32_t tid = 0; tid < 64; tid++) { + // Map tid to kernel's lane_idx (same for a single warp) + uint32_t lane_idx = tid; + + // Get warp_idx_q (this is 0 for our single warp simulation) + uint32_t warp_idx_q = 0; + + // Exactly match the kernel's initial offset calculation + uint32_t q_smem_offset_r = 
get_permuted_offset_linear( + warp_idx_q * NUM_MMA_Q * 16 + lane_idx % 16, lane_idx / 16); + + // Follow exactly the same loop structure as in compute_qk + for (uint32_t mma_d = 0; mma_d < NUM_MMA_D_QK; ++mma_d) { + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { + // This would be a ldmatrix_m8n8x4 call in the actual code + uint32_t read_row = q_smem_offset_r / UPCAST_STRIDE_Q; + uint32_t read_col = q_smem_offset_r % UPCAST_STRIDE_Q; + + if (read_row < grid_height && read_col < grid_width) { + thread_ids_reading_offsets[read_row * grid_width + + read_col] = tid; + } + + // Advance to next row, exactly as in compute_qk + q_smem_offset_r = + advance_offset_by_row_linear<16, UPCAST_STRIDE_Q>( + q_smem_offset_r); + } + + // Reset row position and advance to next column, exactly as in + // compute_qk + q_smem_offset_r = + advance_offset_by_column_linear<4>(q_smem_offset_r, mma_d) - + NUM_MMA_Q * 16 * UPCAST_STRIDE_Q; + } + } +} + +// Helper function to run the test with configurable NUM_MMA_Q +template void RunReadPatternTest() +{ + constexpr uint32_t grid_width = + (HEAD_DIM / QUERY_ELEMS_PER_THREAD); // 16 for 64, 32 for 128 + constexpr uint32_t grid_height = + 16 * NUM_MMA_Q; // 16 for NUM_MMA_Q=1, 32 for NUM_MMA_Q=2 + + printf("\n=== Testing query read pattern with HEAD_DIM = %u, NUM_MMA_Q = " + "%u ===\n", + HEAD_DIM, NUM_MMA_Q); + + // Host array to store thread IDs at each offset + std::vector thread_ids(grid_height * grid_width, -1); + + // Run CPU simulation of read pattern + SimulateReadPattern(thread_ids); + + // Print the grid of thread IDs + printf("Thread IDs reading from each offset (%dx%d grid):\n", grid_height, + grid_width); + + // Column headers + printf(" "); + for (int c = 0; c < grid_width; c++) { + printf("%3d ", c); + if (c == 15 && grid_width > 16) + printf("| "); // Divider for HEAD_DIM=128 + } + printf("\n +"); + for (int c = 0; c < grid_width; c++) { + printf("----"); + if (c == 15 && grid_width > 16) + printf("+"); + } + 
printf("\n"); + + // Print the grid + for (int r = 0; r < grid_height; r++) { + printf("%2d | ", r); + for (int c = 0; c < grid_width; c++) { + int thread_id = thread_ids[r * grid_width + c]; + if (thread_id >= 0) { + printf("%3d ", thread_id); + } + else { + printf(" . "); // Dot for unread positions + } + if (c == 15 && grid_width > 16) + printf("| "); // Divider for HEAD_DIM=128 + } + printf("\n"); + } + + // Check for unread positions + int unread = 0; + for (int i = 0; i < grid_height * grid_width; i++) { + if (thread_ids[i] == -1) { + unread++; + } + } + + // Print statistics + printf("\nStatistics:\n"); + printf("- Positions read: %d/%d (%.1f%%)\n", + grid_height * grid_width - unread, grid_height * grid_width, + 100.0f * (grid_height * grid_width - unread) / + (grid_height * grid_width)); + printf("- Unread positions: %d/%d (%.1f%%)\n", unread, + grid_height * grid_width, + 100.0f * unread / (grid_height * grid_width)); + + // Validate full coverage + EXPECT_EQ(unread, 0) << "Not all positions were read"; +} + +// Tests for different configurations +TEST(MI300ReadPatternTest, HeadDim64_NumMmaQ1) { RunReadPatternTest<64, 1>(); } + +TEST(MI300ReadPatternTest, HeadDim128_NumMmaQ1) +{ + RunReadPatternTest<128, 1>(); +} + +TEST(MI300ReadPatternTest, HeadDim64_NumMmaQ2) { RunReadPatternTest<64, 2>(); } + +TEST(MI300ReadPatternTest, HeadDim128_NumMmaQ2) +{ + RunReadPatternTest<128, 2>(); +} + +int main(int argc, char **argv) +{ + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} From 7bb43d11253a4add7bf5450d7af35845d9f190e4 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Sat, 2 Aug 2025 23:24:01 -0400 Subject: [PATCH 005/109] Updated test_produce_kv.cpp --- libflashinfer/tests/hip/test_produce_kv.cpp | 206 ++++++++------------ 1 file changed, 78 insertions(+), 128 deletions(-) diff --git a/libflashinfer/tests/hip/test_produce_kv.cpp b/libflashinfer/tests/hip/test_produce_kv.cpp index 0494c6675f..90fcee6656 100644 --- 
a/libflashinfer/tests/hip/test_produce_kv.cpp +++ b/libflashinfer/tests/hip/test_produce_kv.cpp @@ -7,6 +7,8 @@ // Constants constexpr uint32_t WARP_SIZE_NV = 32; constexpr uint32_t WARP_SIZE_AMD = 64; +constexpr uint32_t WARP_STEP_SIZE = 16; // 16 threads per warp row for AMD +constexpr uint32_t WARP_THREAD_ROWS = 4; // 4 rows of threads in a warp for AMD // SwizzleMode enum to match the original code enum class SwizzleMode @@ -35,160 +37,117 @@ uint32_t advance_offset_by_row_linear(uint32_t offset) return offset + step_size * row_stride; } -// CPU-based simulation of produce_kv for different SwizzleMode values -template -void SimulateProduceKV(std::vector &thread_ids_at_offsets, - uint32_t warp_size = WARP_SIZE_AMD) +// CPU-based simulation of produce_kv for AMD MI300 with linear offset +// addressing +template +void SimulateProduceKV(std::vector &thread_ids_at_offsets) { - // Constants derived from HEAD_DIM and SwizzleMode - constexpr uint32_t ELEMS_PER_THREAD = 4; + // Constants for MI300 (64-thread warp, 4×16 thread layout) + constexpr uint32_t WARP_SIZE = 64; + constexpr uint32_t WARP_THREAD_ROWS = 4; // 4 rows of threads + constexpr uint32_t WARP_STEP_SIZE = 16; // 16 threads per row + constexpr uint32_t ELEMS_PER_THREAD = + 4; // Each thread loads 4 fp16 elements + + // Derived constants constexpr uint32_t UPCAST_STRIDE = HEAD_DIM / ELEMS_PER_THREAD; constexpr uint32_t NUM_MMA_D = HEAD_DIM / 16; - constexpr uint32_t grid_width = - HEAD_DIM / ELEMS_PER_THREAD; // 16 for 64, 32 for 128 - constexpr uint32_t grid_height = - 16 * NUM_MMA_KV; // 16 for NUM_MMA_KV=1, 32 for NUM_MMA_KV=2 - - // Thread layout constants based on SwizzleMode - constexpr uint32_t KV_THR_LAYOUT_ROW = - SWIZZLE_MODE == SwizzleMode::k128B ? 4 - : SWIZZLE_MODE == SwizzleMode::k64B ? 8 - : 4; // 4 for kLinear (AMD) - - constexpr uint32_t KV_THR_LAYOUT_COL = - SWIZZLE_MODE == SwizzleMode::k128B ? 8 - : SWIZZLE_MODE == SwizzleMode::k64B ? 
4 - : 16; // 16 for kLinear (AMD) - + constexpr uint32_t grid_width = HEAD_DIM / ELEMS_PER_THREAD; + constexpr uint32_t grid_height = 16 * NUM_MMA_KV; constexpr uint32_t NUM_WARPS = 1; constexpr uint32_t NUM_WARPS_Q = 1; + constexpr uint32_t COLUMN_RESET_OFFSET = (NUM_MMA_D / 4) * WARP_STEP_SIZE; + //(NUM_MMA_D / (4 / sizeof(uint16_t))) * WARP_STEP_SIZE; // Initialize with -1 (unwritten) thread_ids_at_offsets.assign(grid_height * grid_width, -1); // Simulate each thread's write pattern - for (uint32_t tid = 0; tid < warp_size; tid++) { - uint32_t warp_idx = tid / warp_size; // Always 0 for single warp + for (uint32_t tid = 0; tid < WARP_SIZE; tid++) { + uint32_t warp_idx = 0; // Always 0 for single warp uint32_t lane_idx = tid; - // Calculate the initial shared memory offset and global memory index - uint32_t kv_smem_offset_w = get_permuted_offset_linear( - warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL, - lane_idx % KV_THR_LAYOUT_COL); - - if constexpr (SWIZZLE_MODE == SwizzleMode::k128B) { - // k128B mode (original pseudo-128B mode) - uint32_t kv_idx = warp_idx * 4 + lane_idx / 8; + // Calculate thread's row and column + uint32_t row = lane_idx / WARP_STEP_SIZE; + uint32_t col = lane_idx % WARP_STEP_SIZE; - static_assert(NUM_MMA_KV * 4 % NUM_WARPS_Q == 0); - for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) { - for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(uint16_t)); - ++j) - { - // Record which thread writes to this offset - if (kv_smem_offset_w < grid_height * grid_width && - kv_idx < grid_height) - { - thread_ids_at_offsets[kv_smem_offset_w] = tid; - } - - // Advance to next column within same row - kv_smem_offset_w = - advance_offset_by_column_linear<8>(kv_smem_offset_w, j); - } - - kv_idx += NUM_WARPS * 4; - kv_smem_offset_w = - advance_offset_by_row_linear( - kv_smem_offset_w) - - sizeof(uint16_t) * NUM_MMA_D; - } - } - else if constexpr (SWIZZLE_MODE == SwizzleMode::k64B) { - // k64B mode (original NVIDIA mode) - uint32_t kv_idx 
= warp_idx * 8 + lane_idx / 4; - - static_assert(NUM_MMA_KV * 2 % NUM_WARPS_Q == 0); - for (uint32_t i = 0; i < NUM_MMA_KV * 2 / NUM_WARPS_Q; ++i) { + // Calculate initial offset + uint32_t kv_smem_offset_w = get_permuted_offset_linear( + warp_idx * WARP_THREAD_ROWS + row, col); + + // Initial kv_idx points to the first row this thread handles + uint32_t kv_idx = warp_idx * WARP_THREAD_ROWS + row; + + // Handle all blocks of rows + for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) { + // Process columns within a row (each thread loads 4 elements per + // iteration) + // for (uint32_t j = 0; j < NUM_MMA_D / (4 / sizeof(uint16_t)); ++j) + // { + for (uint32_t j = 0; j < NUM_MMA_D / 4; ++j) { // Record which thread writes to this offset + // if(tid == 0) { + // std::cout << "tid : " << tid << " kv_smem_offset_w at + // start " << kv_smem_offset_w << '\n'; + // } if (kv_smem_offset_w < grid_height * grid_width && kv_idx < grid_height) { thread_ids_at_offsets[kv_smem_offset_w] = tid; } - - kv_smem_offset_w = - advance_offset_by_row_linear( - kv_smem_offset_w); - kv_idx += NUM_WARPS * 8; - } - } - else if constexpr (SWIZZLE_MODE == SwizzleMode::kLinear) { - // kLinear mode (AMD-specific, using all 64 threads) - uint32_t kv_idx = warp_idx * 4 + lane_idx / 16; - - // For AMD's 64-thread warp, we need to process 4 rows with 16 - // threads per row - for (uint32_t i = 0; i < NUM_MMA_KV; ++i) { - for (uint32_t j = 0; j < NUM_MMA_D; ++j) { - // Record which thread writes to this offset - if (kv_smem_offset_w < grid_height * grid_width && - kv_idx < grid_height) - { - thread_ids_at_offsets[kv_smem_offset_w] = tid; - } - - // Advance to next column within same row - kv_smem_offset_w = - advance_offset_by_column_linear( - kv_smem_offset_w, j); + else { + std::cerr << "ERROR: Out of bound offset (" + << kv_smem_offset_w << ") at " << tid << '\n'; } - kv_idx += 4; // Advance by 4 rows + // Advance to next column by 16 (number of threads per row) kv_smem_offset_w = - 
advance_offset_by_row_linear<4, UPCAST_STRIDE>( - kv_smem_offset_w) - - NUM_MMA_D * ELEMS_PER_THREAD; + advance_offset_by_column_linear( + kv_smem_offset_w, j); + // if(tid == 0) { + // std::cout << "tid : " << tid << " kv_smem_offset_w after + // column inc: " << kv_smem_offset_w << '\n'; + // } } + + // Move to next set of rows + kv_idx += WARP_THREAD_ROWS; + + // if(tid == 0) { + // std::cout << "tid : " << tid << " kv_smem_offset_w before row + // inc " << kv_smem_offset_w << '\n'; + // } + // Reset column position and advance rows + kv_smem_offset_w = + advance_offset_by_row_linear(kv_smem_offset_w) - + COLUMN_RESET_OFFSET; + + // if(tid == 0) { + // std::cout << "tid : " << tid << " kv_smem_offset_w after row + // inc " << kv_smem_offset_w << '\n'; + // } } + // FIXME: Verify with original in prefill.cuh + kv_smem_offset_w -= 16 * NUM_MMA_KV * UPCAST_STRIDE; } } -// Helper function to run the test for different SwizzleModes -template -void RunProduceKVTest(uint32_t warp_size = WARP_SIZE_AMD) +// Helper function to run the test +template void RunProduceKVTest() { constexpr uint32_t grid_width = HEAD_DIM / 4; // 16 for 64, 32 for 128 constexpr uint32_t grid_height = 16 * NUM_MMA_KV; // 16 for NUM_MMA_KV=1, 32 for NUM_MMA_KV=2 - std::string swizzle_mode_str; - switch (SWIZZLE_MODE) { - case SwizzleMode::k64B: - swizzle_mode_str = "k64B (NVIDIA)"; - break; - case SwizzleMode::k128B: - swizzle_mode_str = "k128B (NVIDIA pseudo-128B)"; - break; - case SwizzleMode::kLinear: - swizzle_mode_str = "kLinear (AMD)"; - break; - } - - printf("\n=== Testing produce_kv with HEAD_DIM = %u, NUM_MMA_KV = %u, " - "SwizzleMode = %s ===\n", - HEAD_DIM, NUM_MMA_KV, swizzle_mode_str.c_str()); + printf("\n=== Testing produce_kv with HEAD_DIM = %u, NUM_MMA_KV = %u ===\n", + HEAD_DIM, NUM_MMA_KV); // Host array to store thread IDs at each offset std::vector thread_ids(grid_height * grid_width, -1); // Run CPU simulation of produce_kv - SimulateProduceKV(thread_ids, - warp_size); + 
SimulateProduceKV(thread_ids); // Print the grid of thread IDs printf("Thread IDs writing to each offset (%dx%d grid):\n", grid_height, @@ -256,23 +215,14 @@ void RunProduceKVTest(uint32_t warp_size = WARP_SIZE_AMD) 100.0f * unwritten / (grid_height * grid_width)); } -// Tests for different SwizzleModes -// TEST(KVCacheWritePatternTest, HeadDim64_NVIDIA_k64B) { -// RunProduceKVTest<64, 1, SwizzleMode::k64B>(WARP_SIZE_NV); -// } - -// TEST(KVCacheWritePatternTest, HeadDim64_NVIDIA_k128B) { -// RunProduceKVTest<64, 1, SwizzleMode::k128B>(WARP_SIZE_NV); -// } - TEST(KVCacheWritePatternTest, HeadDim64_AMD_kLinear) { - RunProduceKVTest<64, 1, SwizzleMode::kLinear>(WARP_SIZE_AMD); + RunProduceKVTest<64, 1>(); } TEST(KVCacheWritePatternTest, HeadDim128_AMD_kLinear) { - RunProduceKVTest<128, 1, SwizzleMode::kLinear>(WARP_SIZE_AMD); + RunProduceKVTest<128, 1>(); } int main(int argc, char **argv) From 466772474430ed097a475a5992ca6707ac74c1ca Mon Sep 17 00:00:00 2001 From: rtmadduri Date: Sun, 3 Aug 2025 22:25:57 +0000 Subject: [PATCH 006/109] Add mma ops --- .../include/gpu_iface/backend/cuda/mma.cuh | 759 ++++++++++++++++++ .../include/gpu_iface/backend/hip/mma_hip.h | 177 ++++ libflashinfer/include/gpu_iface/fragment.hpp | 148 ++++ libflashinfer/include/gpu_iface/mma_ops.hpp | 104 +++ 4 files changed, 1188 insertions(+) create mode 100644 libflashinfer/include/gpu_iface/backend/cuda/mma.cuh create mode 100644 libflashinfer/include/gpu_iface/backend/hip/mma_hip.h create mode 100644 libflashinfer/include/gpu_iface/fragment.hpp create mode 100644 libflashinfer/include/gpu_iface/mma_ops.hpp diff --git a/libflashinfer/include/gpu_iface/backend/cuda/mma.cuh b/libflashinfer/include/gpu_iface/backend/cuda/mma.cuh new file mode 100644 index 0000000000..fa72f6a46e --- /dev/null +++ b/libflashinfer/include/gpu_iface/backend/cuda/mma.cuh @@ -0,0 +1,759 @@ +/* + * Copyright (c) 2023 by FlashInfer team. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef FLASHINFER_MMA_CUH_ +#define FLASHINFER_MMA_CUH_ + +#include +#include +#include +#include + +#include + +namespace flashinfer +{ +namespace gpu_iface +{ +namespace mma_impl +{ +namespace cuda +{ + +#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 >= 120400) +#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 890)) +#define FLASHINFER_MMA_F8F8F32_M16N8K32_ENABLED +#endif +#endif + +#if (__CUDACC_VER_MAJOR__ >= 11) +#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 900)) +#define FLASHINFER_STMATRIX_M8N8X4_ENABLED +#endif +#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 800)) +#define FLASHINFER_MMA_F16F16F32_M16N8K16_ENABLED +#define FLASHINFER_MMA_F16F16F16_M16N8K16_ENABLED +#endif +#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 750)) +#define FLASHINFER_MMA_F16F16F32_M16N8K8_ENABLED +#define FLASHINFER_MMA_F16F16F16_M16N8K8_ENABLED +#define FLASHINFER_LDMATRIX_M8N8X4_ENABLED +#endif +#endif + +#if defined(__CUDA_ARCH__) +#define FLASHINFER_RUNTIME_ASSERT(x) __brkpt() +#else +#define FLASHINFER_RUNTIME_ASSERT(x) assert(0 && x) +#endif + +enum class MMAMode +{ + kInit = 0U, + kInplaceUpdate = 1U, +}; + +/*! 
+ * \brief Wrapper of PTX ldmatrix m8n8.x4 instruction, loads data from shared + * memory to fragment + * \tparam T data type of the fragment + * \param R pointer to the fragment + * \param smem_ptr pointer to the shared memory + */ +template +__device__ __forceinline__ void load_matrix_m8n8x4(uint32_t *R, T *smem_ptr) +{ +#ifdef FLASHINFER_LDMATRIX_M8N8X4_ENABLED + uint32_t smem_int_ptr = + static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(R[0]), "=r"(R[1]), "=r"(R[2]), "=r"(R[3]) + : "r"(smem_int_ptr)); +#else + FLASHINFER_RUNTIME_ASSERT( + "Unsupported CUDA architecture for ldmatrix instruction"); +#endif +} + +/*! + * \brief Wrapper of PTX ldmatrix m8n8.x4 instruction, loads data from shared + * memory to fragment + * \tparam T data type of the fragment + * \param R pointer to the fragment + * \param smem_ptr pointer to the shared memory + */ +template +__device__ __forceinline__ void load_matrix_m8n8x4_left_half(uint32_t *R, + T *smem_ptr) +{ +#ifdef FLASHINFER_LDMATRIX_M8N8X4_ENABLED + uint32_t smem_int_ptr = + static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, _, %1, _}, [%2];\n" + : "=r"(R[0]), "=r"(R[1]) + : "r"(smem_int_ptr)); +#else + FLASHINFER_RUNTIME_ASSERT( + "Unsupported CUDA architecture for ldmatrix instruction"); +#endif +} + +/*! 
+ * \brief Wrapper of PTX ldmatrix m8n8.x4 instruction, loads data from shared + * memory to fragment + * \tparam T data type of the fragment + * \param R pointer to the fragment + * \param smem_ptr pointer to the shared memory + */ +template +__device__ __forceinline__ void load_matrix_m8n8x4_right_half(uint32_t *R, + T *smem_ptr) +{ +#ifdef FLASHINFER_LDMATRIX_M8N8X4_ENABLED + uint32_t smem_int_ptr = + static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "ldmatrix.sync.aligned.m8n8.x4.shared.b16 {_, %0, _, %1}, [%2];\n" + : "=r"(R[0]), "=r"(R[1]) + : "r"(smem_int_ptr)); +#else + FLASHINFER_RUNTIME_ASSERT( + "Unsupported CUDA architecture for ldmatrix instruction"); +#endif +} + +/*! + * \brief Wrapper of PTX ldmatrix m8n8.x4 transposed instruction, loads data + * from shared memory to fragment and transposes the fragment + * \tparam T data type of the fragment + * \param R pointer to the fragment + * \param smem_ptr pointer to the shared memory + */ +template +__device__ __forceinline__ void load_matrix_m8n8x4_trans(uint32_t *R, + T *smem_ptr) +{ +#ifdef FLASHINFER_LDMATRIX_M8N8X4_ENABLED + uint32_t smem_int_ptr = + static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("ldmatrix.sync.aligned.trans.m8n8.x4.shared.b16 {%0, %1, %2, " + "%3}, [%4];\n" + : "=r"(R[0]), "=r"(R[1]), "=r"(R[2]), "=r"(R[3]) + : "r"(smem_int_ptr)); +#else + FLASHINFER_RUNTIME_ASSERT( + "Unsupported CUDA architecture for ldmatrix instruction"); +#endif +} + +/*! 
+ * \brief Wrapper of PTX ldmatrix m8n8.x4 transposed instruction, loads data + * from shared memory to fragment and transposes the fragment + * \tparam T data type of the fragment + * \param R pointer to the fragment + * \param smem_ptr pointer to the shared memory + */ +template +__device__ __forceinline__ void load_matrix_m8n8x4_trans_left_half(uint32_t *R, + T *smem_ptr) +{ +#ifdef FLASHINFER_LDMATRIX_M8N8X4_ENABLED + uint32_t smem_int_ptr = + static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "ldmatrix.sync.aligned.trans.m8n8.x4.shared.b16 {%0, %1, _, _}, [%2];\n" + : "=r"(R[0]), "=r"(R[1]) + : "r"(smem_int_ptr)); +#else + FLASHINFER_RUNTIME_ASSERT( + "Unsupported CUDA architecture for ldmatrix instruction"); +#endif +} + +/*! + * \brief Wrapper of PTX ldmatrix m8n8.x4 transposed instruction, loads data + * from shared memory to fragment and transposes the fragment + * \tparam T data type of the fragment + * \param R pointer to the fragment + * \param smem_ptr pointer to the shared memory + */ +template +__device__ __forceinline__ void load_matrix_m8n8x4_trans_right_half(uint32_t *R, + T *smem_ptr) +{ +#ifdef FLASHINFER_LDMATRIX_M8N8X4_ENABLED + uint32_t smem_int_ptr = + static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "ldmatrix.sync.aligned.trans.m8n8.x4.shared.b16 {_, _, %0, %1}, [%2];\n" + : "=r"(R[0]), "=r"(R[1]) + : "r"(smem_int_ptr)); +#else + FLASHINFER_RUNTIME_ASSERT( + "Unsupported CUDA architecture for ldmatrix instruction"); +#endif +} + +/*! 
+ * \brief Wrapper of PTX stmatrix m8n8.x4 instruction, stores data from fragment + * to shared memory + * \tparam T data type of the fragment + * \param R pointer to the fragment + * \param smem_ptr pointer to the shared memory + */ +template +__device__ __forceinline__ void store_matrix_m8n8x4(uint32_t *R, T *smem_ptr) +{ +#ifdef FLASHINFER_STMATRIX_M8N8X4_ENABLED + uint32_t smem_int_ptr = + static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "stmatrix.sync.aligned.m8n8.x4.shared.b16 [%0], {%1, %2, %3, %4};\n" + : + : "r"(smem_int_ptr), "r"(R[0]), "r"(R[1]), "r"(R[2]), "r"(R[3])); +#else + // Fallback implementation, slower than PTX instruction + const uint32_t tx = threadIdx.x; + uint4 word; +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 4; ++reg_id) { + word.x = __shfl_sync(0xffffffff, R[reg_id], (tx % 8) * 4); + word.y = __shfl_sync(0xffffffff, R[reg_id], (tx % 8) * 4 + 1); + word.z = __shfl_sync(0xffffffff, R[reg_id], (tx % 8) * 4 + 2); + word.w = __shfl_sync(0xffffffff, R[reg_id], (tx % 8) * 4 + 3); + if (tx / 8 == reg_id) { + *(uint4 *)smem_ptr = word; + } + } +#endif +} + +/*! + * \brief Wrapper of two mma m16n8k32 instructions for row major and column + * major f8 matrix multiplication, accumulated in f32. 
+ * \tparam T data type of the fragment + * \tparam mma_mode whether we are initializing the accumulator or updating it + * \param C pointer to the accumulator + * \param A pointer to the fragment of matrix A + * \param B pointer to the fragment of matrix B + */ +template +__device__ __forceinline__ void +mma_sync_m16n16k32_row_col_f8f8f32(float *C, uint32_t *A, uint32_t *B) +{ + static_assert(sizeof(T) == 1, "DType must be 8bit floating data type"); +#if defined(FLASHINFER_MMA_F8F8F32_M16N8K32_ENABLED) + if constexpr (mma_mode == MMAMode::kInit) { + if constexpr (std::is_same_v) { + asm volatile("mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), "f"(0.f), "f"(0.f), "f"(0.f), + "f"(0.f)); + asm volatile("mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[2]), "r"(B[3]), "f"(0.f), "f"(0.f), "f"(0.f), + "f"(0.f)); + } + else { // e5m2 + asm volatile("mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e5m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), "f"(0.f), "f"(0.f), "f"(0.f), + "f"(0.f)); + asm volatile("mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e5m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[2]), "r"(B[3]), "f"(0.f), "f"(0.f), "f"(0.f), + "f"(0.f)); + } + } + else { + if constexpr (std::is_same_v) { + asm 
volatile("mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), "f"(C[0]), "f"(C[1]), + "f"(C[2]), "f"(C[3])); + asm volatile("mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[2]), "r"(B[3]), "f"(C[4]), "f"(C[5]), + "f"(C[6]), "f"(C[7])); + } + else { // e5m2 + asm volatile("mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e5m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), "f"(C[0]), "f"(C[1]), + "f"(C[2]), "f"(C[3])); + asm volatile("mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e5m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[2]), "r"(B[3]), "f"(C[4]), "f"(C[5]), + "f"(C[6]), "f"(C[7])); + } + } +#else + FLASHINFER_RUNTIME_ASSERT("fp8 mma instruction is only available for sm89, " + "PTX 8.4+ and CUDA 12.4+"); +#endif +} + +/*! + * \brief Wrapper of two mma m16n8k16 instructions for row major and column + * major f16 matrix multiplication, accumulated in f32. 
+ * \tparam T data type of the fragment + * \tparam mma_mode whether we are initializing the accumulator or updating it + * \param C pointer to the accumulator + * \param A pointer to the fragment of matrix A + * \param B pointer to the fragment of matrix B + */ +template +__device__ __forceinline__ void +mma_sync_m16n16k16_row_col_f16f16f32(float *C, uint32_t *A, uint32_t *B) +{ +#if defined(FLASHINFER_MMA_F16F16F32_M16N8K16_ENABLED) + if constexpr (mma_mode == MMAMode::kInit) { + if constexpr (std::is_same_v) { + asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), "f"(0.f), "f"(0.f), "f"(0.f), + "f"(0.f)); + asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[2]), "r"(B[3]), "f"(0.f), "f"(0.f), "f"(0.f), + "f"(0.f)); + } + else { + asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), "f"(0.f), "f"(0.f), "f"(0.f), + "f"(0.f)); + asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[2]), "r"(B[3]), "f"(0.f), "f"(0.f), "f"(0.f), + "f"(0.f)); + } + } + else { + if constexpr (std::is_same_v) { + asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + 
"{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), "f"(C[0]), "f"(C[1]), + "f"(C[2]), "f"(C[3])); + asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[2]), "r"(B[3]), "f"(C[4]), "f"(C[5]), + "f"(C[6]), "f"(C[7])); + } + else { + asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), "f"(C[0]), "f"(C[1]), + "f"(C[2]), "f"(C[3])); + asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[2]), "r"(B[3]), "f"(C[4]), "f"(C[5]), + "f"(C[6]), "f"(C[7])); + } + } +#elif defined(FLASHINFER_MMA_F16F16F32_M16N8K8_ENABLED) + if constexpr (std::is_same_v) { + if constexpr (mma_mode == MMAMode::kInit) { + asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(B[0]), "f"(0.f), "f"(0.f), + "f"(0.f), "f"(0.f)); + asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[2]), "r"(A[3]), "r"(B[1]), "f"(C[0]), + "f"(C[1]), "f"(C[2]), "f"(C[3])); + asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, 
%9, %10};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[0]), "r"(A[1]), "r"(B[2]), "f"(0.f), "f"(0.f), + "f"(0.f), "f"(0.f)); + asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[2]), "r"(A[3]), "r"(B[3]), "f"(C[4]), + "f"(C[5]), "f"(C[6]), "f"(C[7])); + } + else { + asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(B[0]), "f"(C[0]), + "f"(C[1]), "f"(C[2]), "f"(C[3])); + asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=f"(C[0]), "=f"(C[1]), "=f"(C[2]), "=f"(C[3]) + : "r"(A[2]), "r"(A[3]), "r"(B[1]), "f"(C[0]), + "f"(C[1]), "f"(C[2]), "f"(C[3])); + asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[0]), "r"(A[1]), "r"(B[2]), "f"(C[4]), + "f"(C[5]), "f"(C[6]), "f"(C[7])); + asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=f"(C[4]), "=f"(C[5]), "=f"(C[6]), "=f"(C[7]) + : "r"(A[2]), "r"(A[3]), "r"(B[3]), "f"(C[4]), + "f"(C[5]), "f"(C[6]), "f"(C[7])); + } + } + else { + FLASHINFER_RUNTIME_ASSERT( + "Unsupported CUDA architecture for mma instruction"); + } +#else + FLASHINFER_RUNTIME_ASSERT( + "Unsupported CUDA architecture for mma instruction"); +#endif +} + +/*! + * \brief Use mma instructions to compute rowsum. 
+ */ +template +__device__ __forceinline__ void m16k32_rowsum_f8f8f32(float *d, DType *s) +{ + static_assert(sizeof(DType) == 1, "DType must be 8bit floating data type"); + uint32_t *s_u32 = (uint32_t *)(s); +#if defined(FLASHINFER_MMA_F8F8F32_M16N8K32_ENABLED) + if constexpr (std::is_same_v) { + asm volatile("{\n" + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0, _, %1, _}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, 0., %9, 0.};\n" + "}\n" + : "=f"(d[0]), "=f"(d[1]) + : "r"(s_u32[0]), "r"(s_u32[1]), "r"(s_u32[2]), + "r"(s_u32[3]), "r"(943208504), "r"(943208504), "f"(d[0]), + "f"(d[1])); + } + else { // e5m2 + asm volatile("{\n" + "mma.sync.aligned.m16n8k16.row.col.f32.e5m2.e5m2.f32 " + "{%0, _, %1, _}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, 0., %9, 0.};\n" + "}\n" + : "=f"(d[0]), "=f"(d[1]) + : "r"(s_u32[0]), "r"(s_u32[1]), "r"(s_u32[2]), + "r"(s_u32[3]), "r"(1010580540), "r"(1010580540), + "f"(d[0]), "f"(d[1])); + } +#else + FLASHINFER_RUNTIME_ASSERT("fp8 mma instruction is only available for sm89, " + "PTX 8.4+ and CUDA 12.4+"); +#endif +} + +/*! + * \brief Use mma instructions to compute rowsum. 
+ */ +template +__device__ __forceinline__ void m16k16_rowsum_f16f16f32(float *d, DType *s) +{ + static_assert(sizeof(DType) == 2, "DType must be 16bit floating data type"); + uint32_t *s_u32 = (uint32_t *)(s); +#if defined(FLASHINFER_MMA_F16F16F32_M16N8K16_ENABLED) + if constexpr (std::is_same_v) { + asm volatile("{\n" + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, _, %1, _}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, 0., %9, 0.};\n" + "}\n" + : "=f"(d[0]), "=f"(d[1]) + : "r"(s_u32[0]), "r"(s_u32[1]), "r"(s_u32[2]), + "r"(s_u32[3]), "r"(1006648320), "r"(1006648320), + "f"(d[0]), "f"(d[1])); + } + else { + asm volatile("{\n" + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0, _, %1, _}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, 0., %9, 0.};\n" + "}\n" + : "=f"(d[0]), "=f"(d[1]) + : "r"(s_u32[0]), "r"(s_u32[1]), "r"(s_u32[2]), + "r"(s_u32[3]), "r"(1065369472), "r"(1065369472), + "f"(d[0]), "f"(d[1])); + } +#elif defined(FLASHINFER_MMA_F16F16F32_M16N8K8_ENABLED) + if constexpr (std::is_same_v) { + asm volatile("{\n" + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + "{%0, _, %1, _}," + "{%2, %3}," + "{%4}," + "{%5, 0., %6, 0.};\n" + "}\n" + : "=f"(d[0]), "=f"(d[1]) + : "r"(s_u32[0]), "r"(s_u32[1]), "r"(1006648320), "f"(d[0]), + "f"(d[1])); + asm volatile("{\n" + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + "{%0, _, %1, _}," + "{%2, %3}," + "{%4}," + "{%5, 0., %6, 0.};\n" + "}\n" + : "=f"(d[0]), "=f"(d[1]) + : "r"(s_u32[2]), "r"(s_u32[3]), "r"(1006648320), "f"(d[0]), + "f"(d[1])); + } + else { + FLASHINFER_RUNTIME_ASSERT( + "Unsupported CUDA architecture for mma instruction"); + } +#else + FLASHINFER_RUNTIME_ASSERT( + "Unsupported CUDA architecture for mma instruction"); +#endif +} + +/*! + * \brief Wrapper of two mma m16n8k16 instructions for row major and column + * major f16 matrix multiplication, accumulated in f16. 
+ * \tparam mma_mode whether we are initializing the accumulator or updating it + * \param C pointer to the accumulator + * \param A pointer to the fragment of matrix A + * \param B pointer to the fragment of matrix B + */ +template +__device__ __forceinline__ void +mma_sync_m16n16k16_row_col_f16f16f16(uint32_t *C, uint32_t *A, uint32_t *B) +{ +#if defined(FLASHINFER_MMA_F16F16F16_M16N8K16_ENABLED) + if constexpr (mma_mode == MMAMode::kInit) { + asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(C[0]), "=r"(C[1]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), + "r"(B[1]), "r"(0), "r"(0)); + asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(C[2]), "=r"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), + "r"(B[3]), "r"(0), "r"(0)); + } + else { + asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(C[0]), "=r"(C[1]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), + "r"(B[1]), "r"(C[0]), "r"(C[1])); + asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(C[2]), "=r"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[2]), + "r"(B[3]), "r"(C[2]), "r"(C[3])); + } +#elif defined(FLASHINFER_MMA_F16F16F16_M16N8K8_ENABLED) + if constexpr (mma_mode == MMAMode::kInit) { + asm volatile("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " + "{%0, %1}," + "{%2, %3}," + "{%4}," + "{%5, %6};\n" + : "=r"(C[0]), "=r"(C[1]) + : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(0), "r"(0)); + asm volatile("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " + "{%0, %1}," + "{%2, %3}," + "{%4}," + "{%5, %6};\n" + : "=r"(C[0]), "=r"(C[1]) + : "r"(A[2]), "r"(A[3]), "r"(B[1]), "r"(0), "r"(0)); + asm 
volatile("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " + "{%0, %1}," + "{%2, %3}," + "{%4}," + "{%5, %6};\n" + : "=r"(C[2]), "=r"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(B[2]), "r"(0), "r"(0)); + asm volatile("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " + "{%0, %1}," + "{%2, %3}," + "{%4}," + "{%5, %6};\n" + : "=r"(C[2]), "=r"(C[3]) + : "r"(A[2]), "r"(A[3]), "r"(B[3]), "r"(0), "r"(0)); + } + else { + asm volatile("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " + "{%0, %1}," + "{%2, %3}," + "{%4}," + "{%5, %6};\n" + : "=r"(C[0]), "=r"(C[1]) + : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(C[0]), "r"(C[1])); + asm volatile("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " + "{%0, %1}," + "{%2, %3}," + "{%4}," + "{%5, %6};\n" + : "=r"(C[0]), "=r"(C[1]) + : "r"(A[2]), "r"(A[3]), "r"(B[1]), "r"(C[0]), "r"(C[1])); + asm volatile("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " + "{%0, %1}," + "{%2, %3}," + "{%4}," + "{%5, %6};\n" + : "=r"(C[2]), "=r"(C[3]) + : "r"(A[0]), "r"(A[1]), "r"(B[2]), "r"(C[2]), "r"(C[3])); + asm volatile("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " + "{%0, %1}," + "{%2, %3}," + "{%4}," + "{%5, %6};\n" + : "=r"(C[2]), "=r"(C[3]) + : "r"(A[2]), "r"(A[3]), "r"(B[3]), "r"(C[2]), "r"(C[3])); + } +#else + FLASHINFER_RUNTIME_ASSERT( + "Unsupported CUDA architecture for mma instruction"); +#endif +} + +} // namespace cuda +} // namespace mma_impl +} // namespace gpu_iface +} // namespace flashinfer + +#endif // FLASHINFER_MMA_CUH_ diff --git a/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h b/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h new file mode 100644 index 0000000000..48b0e6529e --- /dev/null +++ b/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h @@ -0,0 +1,177 @@ +// SPDX-FileCopyrightText: 2025 Advanced Micro Devices, Inc. 
+// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "gpu_iface/fragment.hpp" +#include "gpu_iface/mma_types.hpp" +#include "gpu_iface/platform.hpp" + +namespace flashinfer +{ +namespace gpu_iface +{ +namespace mma_impl +{ +namespace hip +{ + +using flashinfer::gpu_iface::mma::accumulator_fragment_m16n16k16; +using flashinfer::gpu_iface::mma::col_major_fragment_m16n16k16; +using flashinfer::gpu_iface::mma::MMAMode; +using flashinfer::gpu_iface::mma::row_major_fragment_m16n16k16; + +// Architecture detection for MI300 +#if defined(__gfx942__) +#define FLASHINFER_MMA_F16F16F32_M16N16K16_ENABLED +#define FLASHINFER_MMA_BF16BF16F32_M16N16K16_ENABLED +#define FLASHINFER_MMA_F16F16F16_M16N16K16_ENABLED +#define FLASHINFER_LDMATRIX_M8N8X4_ENABLED +#define FLASHINFER_STMATRIX_M8N8X4_ENABLED +#endif + +#define FLASHINFER_RUNTIME_ASSERT(x) assert(0 && x) +// Single unified load function for all fragment types +template +__device__ __forceinline__ void +load_fragment_m16n16(FragmentType &frag, + const typename FragmentType::value_type *ptr, + uint32_t stride) +{ +#ifdef FLASHINFER_LDMATRIX_M8N8X4_ENABLED + if constexpr (std::is_same_v>) + { + // Accumulator fragments need the layout parameter + rocwmma::load_matrix_sync(frag.frag, ptr, stride, + rocwmma::mem_row_major); + } + else { + // Row-major and col-major fragments already have layout baked in + rocwmma::load_matrix_sync(frag.frag, ptr, stride); + } +#else + FLASHINFER_RUNTIME_ASSERT("ldmatrix emulation not supported"); +#endif +} + +// Single unified store function for all fragment types +template +__device__ __forceinline__ void +store_fragment_m16n16(typename FragmentType::value_type *ptr, + const FragmentType &frag, + uint32_t stride) +{ +#ifdef FLASHINFER_STMATRIX_M8N8X4_ENABLED + if constexpr (std::is_same_v>) + { + // Accumulator fragments need the layout parameter + rocwmma::store_matrix_sync(ptr, frag.frag, stride, + rocwmma::mem_row_major); + } + else { + // Row-major and col-major fragments 
already have layout baked in + rocwmma::store_matrix_sync(ptr, frag.frag, stride); + } +#else + FLASHINFER_RUNTIME_ASSERT("stmatrix emulation not supported"); +#endif +} + +// MMA operation for FP16 inputs with FP32 accumulator +template +__device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f32( + accumulator_fragment_m16n16k16 &c_frag, + const row_major_fragment_m16n16k16 &a_frag, + const col_major_fragment_m16n16k16 &b_frag) +{ +#if defined(FLASHINFER_MMA_F16F16F32_M16N16K16_ENABLED) + // Ensure T is either __half or __hip_bfloat16 + static_assert(std::is_same_v || + std::is_same_v, + "T must be __half or __hip_bfloat16"); + + // Initialize C if requested + if constexpr (mma_mode == MMAMode::kInit) { + rocwmma::fill_fragment(c_frag.frag, 0.0f); + } + + // Perform MMA operation directly with fragments + rocwmma::mma_sync(c_frag.frag, a_frag.frag, b_frag.frag, c_frag.frag); +#else + FLASHINFER_RUNTIME_ASSERT( + "MMA f16f16f32 not supported on this architecture"); +#endif +} + +// MMA operation for FP16 inputs with FP16 accumulator +template +__device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f16( + accumulator_fragment_m16n16k16<__half> &c_frag, + const row_major_fragment_m16n16k16<__half> &a_frag, + const col_major_fragment_m16n16k16<__half> &b_frag) +{ +#if defined(FLASHINFER_MMA_F16F16F16_M16N16K16_ENABLED) + // Initialize C if requested + if constexpr (mma_mode == MMAMode::kInit) { + rocwmma::fill_fragment(c_frag.frag, __float2half(0.0f)); + } + + // Perform MMA + rocwmma::mma_sync(c_frag.frag, a_frag.frag, b_frag.frag, c_frag.frag); +#else + FLASHINFER_RUNTIME_ASSERT( + "MMA f16f16f16 not supported on this architecture"); +#endif +} + +// Rowsum operation using MMA +template +__device__ __forceinline__ void +m16k16_rowsum_f16f16f32(accumulator_fragment_m16n16k16 &d_frag, + const row_major_fragment_m16n16k16 &s_frag) +{ + static_assert(sizeof(DType) == 2, "DType must be 16bit"); + + // Create a ones fragment + 
col_major_fragment_m16n16k16 ones_frag; + + // Fill with ones + if constexpr (std::is_same_v) { + ones_frag.fill(__float2half(1.0f)); + } + else if constexpr (std::is_same_v) { + ones_frag.fill(__float2bfloat16(1.0f)); + } + + // Use MMA to compute rowsum + mma_sync_m16n16k16_row_col_f16f16f32( + d_frag, s_frag, ones_frag); +} + +// FP8 operations - not implemented for MI300 yet +template +__device__ __forceinline__ void mma_sync_m16n16k32_row_col_f8f8f32( + accumulator_fragment_m16n16k16 &c_frag, + const row_major_fragment_m16n16k16 &a_frag, + const col_major_fragment_m16n16k16 &b_frag) +{ + FLASHINFER_RUNTIME_ASSERT("FP8 MMA not implemented for AMD"); +} + +template +__device__ __forceinline__ void +m16k32_rowsum_f8f8f32(accumulator_fragment_m16n16k16 &d_frag, + const row_major_fragment_m16n16k16 &s_frag) +{ + FLASHINFER_RUNTIME_ASSERT("FP8 rowsum not implemented for AMD"); +} + +} // namespace hip +} // namespace mma_impl +} // namespace gpu_iface +} // namespace flashinfer diff --git a/libflashinfer/include/gpu_iface/fragment.hpp b/libflashinfer/include/gpu_iface/fragment.hpp new file mode 100644 index 0000000000..c997d304b6 --- /dev/null +++ b/libflashinfer/include/gpu_iface/fragment.hpp @@ -0,0 +1,148 @@ +// SPDX-FileCopyrightText: 2025 Advanced Micro Devices, Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "gpu_iface/mma_types.hpp" +#include "gpu_iface/platform.hpp" + +#ifdef PLATFORM_HIP_DEVICE +#include +#endif + +namespace flashinfer +{ +namespace gpu_iface +{ +namespace mma +{ + +enum class FragmentType +{ + row_major, // Row-major matrix layout + col_major, // Column-major matrix layout + accumulator // Accumulator (no layout) +}; + +template +struct fragment_t +{ + using value_type = T; +#ifdef PLATFORM_CUDA_DEVICE + // flashinfer's generic CUDA implementation uses raw arrays for matrix + // fragments and the interface is designed to accomodate use of raw arrays + // for such use cases. 
+ static constexpr int elements_per_thread = + (frag_type == FragmentType::accumulator) ? 8 + : (sizeof(T) == 1) ? 8 + : 4; + + // Number of 32-bit registers needed + static constexpr int num_regs = (elements_per_thread * sizeof(T) + 3) / 4; + + uint32_t data[num_regs]; + + // Provide array-like access + __device__ __forceinline__ T &operator[](int i) + { + return reinterpret_cast(data)[i]; + } + __device__ __forceinline__ const T &operator[](int i) const + { + return reinterpret_cast(data)[i]; + } + + // Get number of elements this thread holds + __device__ __forceinline__ constexpr int size() const + { + return elements_per_thread; + } + + // Get raw pointer for MMA operations + __device__ __forceinline__ uint32_t *raw_ptr() { return data; } + __device__ __forceinline__ const uint32_t *raw_ptr() const { return data; } + +#elif defined(PLATFORM_HIP_DEVICE) + // AMD: Use rocWMMA fragments + using rocwmma_layout = typename std::conditional< + frag_type == FragmentType::row_major, + rocwmma::row_major, + typename std::conditional::type>::type; + + using rocwmma_matrix_t = typename std::conditional< + frag_type == FragmentType::row_major, + rocwmma::matrix_a, + typename std::conditional::type>::type; + + // Select appropriate fragment type based on whether it's accumulator or not + using rocwmma_frag_t = typename std::conditional< + frag_type == FragmentType::accumulator, + rocwmma::fragment, + rocwmma::fragment>::type; + + rocwmma_frag_t frag; + + // Provide array-like access that maps to rocWMMA fragment + __device__ __forceinline__ T operator[](int i) const { return frag.x[i]; } + + // For non-const access, we need to provide a setter since we can't return a + // reference + __device__ __forceinline__ void set(int i, T value) { frag.x[i] = value; } + + // Get number of elements this thread holds + __device__ __forceinline__ int size() const { return frag.num_elements; } + + // Get raw pointer for operations that need it + __device__ __forceinline__ rocwmma_frag_t 
*raw_ptr() { return &frag; } + __device__ __forceinline__ const rocwmma_frag_t *raw_ptr() const + { + return &frag; + } +#endif + + // Common interface - update fill method to use setter for HIP + __device__ __forceinline__ void fill(T value) + { +#ifdef PLATFORM_CUDA_DEVICE +#pragma unroll + for (int i = 0; i < elements_per_thread; ++i) { + (*this)[i] = value; + } +#elif defined(PLATFORM_HIP_DEVICE) + rocwmma::fill_fragment(frag, value); +#endif + } +}; + +// Convenience typedefs for common fragment types +template +using row_major_fragment_m16n16k16 = + fragment_t; + +template +using col_major_fragment_m16n16k16 = + fragment_t; + +template +using accumulator_fragment_m16n16k16 = + fragment_t; + +// Helper to get compile-time fragment size +template struct fragment_traits +{ +#ifdef PLATFORM_CUDA_DEVICE + static constexpr int size = Fragment::elements_per_thread; +#elif defined(PLATFORM_HIP_DEVICE) + // For HIP, we can't make this constexpr, so provide a device function + __device__ static int get_size(const Fragment &f) { return f.size(); } +#endif +}; + +} // namespace mma +} // namespace gpu_iface +} // namespace flashinfer diff --git a/libflashinfer/include/gpu_iface/mma_ops.hpp b/libflashinfer/include/gpu_iface/mma_ops.hpp new file mode 100644 index 0000000000..a4bca018a0 --- /dev/null +++ b/libflashinfer/include/gpu_iface/mma_ops.hpp @@ -0,0 +1,104 @@ +// SPDX - FileCopyrightText : 2025 Advanced Micro Devices, Inc. +// +// SPDX - License - Identifier : Apache - 2.0 + +#pragma once + +#include "gpu_iface/fragment.hpp" +#include "gpu_iface/mma_types.hpp" + +// Include platform-specific implementations +#if defined(PLATFORM_CUDA_DEVICE) +#include "backend/cuda/mma.cuh" +namespace detail = flashinfer::gpu_iface::mma_impl::cuda; +#elif defined(PLATFORM_HIP_DEVICE) +#include "backend/hip/mma_hip.h" +namespace detail = flashinfer::gpu_iface::mma_impl::hip; +#endif + +namespace flashinfer +{ +namespace gpu_iface +{ +namespace mma +{ + +/*! 
+ * \brief Loads data from shared memory to fragment + * \tparam T data type of the fragment + * \param R pointer to the fragment + * \param smem_ptr pointer to the shared memory + */ +template +__device__ __forceinline__ void +load_fragment_m16n16(FragmentType &frag, + const typename FragmentType::value_type *ptr, + uint32_t stride) +{ + detail::load_fragment_m16n16(frag, ptr, stride); +} + +/*! + * \brief Stores data from fragment to shared memory + * \tparam T data type of the fragment + * \param R pointer to the fragment + * \param smem_ptr pointer to the shared memory + */ +template +__device__ __forceinline__ void +store_fragment_m16n16(typename FragmentType::value_type *ptr, + const FragmentType &frag, + uint32_t stride) +{ + detail::store_fragment_m16n16(ptr, frag, stride); +} + +/*! + * \brief Wrapper of two mma m16n16k16 instructions for row major and column + * major f16 matrix multiplication, accumulated in f32. + * \tparam T data type of the fragment + * \tparam mma_mode whether we are initializing the accumulator or updating it + * \param C pointer to the accumulator + * \param A pointer to the fragment of matrix A + * \param B pointer to the fragment of matrix B + */ +template +__device__ __forceinline__ void +mma_sync_m16n16k16_row_col_f16f16f32(accumulator_fragment_m16n16k16 &d, + const row_major_fragment_m16n16k16 &a, + const col_major_fragment_m16n16k16 &b) +{ + detail::mma_sync_m16n16k16_row_col_f16f16f32(d, a, b); +} + +/*! + * \brief Use mma instructions to compute rowsum. + */ +template +__device__ __forceinline__ void +m16k16_rowsum_f16f16f32(accumulator_fragment_m16n16k16 &d, + const row_major_fragment_m16n16k16 &s) +{ + detail::m16k16_rowsum_f16f16f32(d, s); +} + +/*! + * \brief Wrapper of two mma m16n16k16 instructions for row major and column + * major f16 matrix multiplication, accumulated in f16. 
+ * \tparam mma_mode whether we are initializing the accumulator or updating it + * \param C pointer to the accumulator + * \param A pointer to the fragment of matrix A + * \param B pointer to the fragment of matrix B + */ +template +__device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f16( + accumulator_fragment_m16n16k16<__half> &d, + const row_major_fragment_m16n16k16<__half> &a, + const col_major_fragment_m16n16k16<__half> &b) +{ + detail::mma_sync_m16n16k16_row_col_f16f16f16(d, a, b); +} + +} // namespace mma +} // namespace gpu_iface +} // namespace flashinfer From 2c8e6d654f3b3e0550ffe47809a37b5efe5791dc Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Mon, 4 Aug 2025 04:38:24 -0400 Subject: [PATCH 007/109] Various initial changes to fix build issues for generic/prefill.cuh --- .../flashinfer/attention/generic/dispatch.cuh | 2 +- .../flashinfer/attention/generic/prefill.cuh | 26 ++- .../flashinfer/attention/generic/utils.cuh | 152 ------------------ .../gpu_iface/backend/hip/vec_dtypes_hip.h | 8 +- .../include/gpu_iface/gpu_runtime_compat.hpp | 26 +++ libflashinfer/include/gpu_iface/macros.hpp | 3 +- 6 files changed, 54 insertions(+), 163 deletions(-) delete mode 100644 libflashinfer/include/flashinfer/attention/generic/utils.cuh diff --git a/libflashinfer/include/flashinfer/attention/generic/dispatch.cuh b/libflashinfer/include/flashinfer/attention/generic/dispatch.cuh index cabadd58f8..d36536a6f8 100644 --- a/libflashinfer/include/flashinfer/attention/generic/dispatch.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/dispatch.cuh @@ -5,7 +5,7 @@ #pragma once -#include "enums.hpp" +#include "gpu_iface/enums.hpp" #include "gpu_iface/exception.h" #define DISPATCH_USE_FP16_QK_REDUCTION(use_fp16_qk_reduction, \ diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index 578484221c..a4964c8714 100644 --- 
a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -9,8 +9,9 @@ #include "gpu_iface/fastdiv.cuh" #include "gpu_iface/math_ops.hpp" #include "gpu_iface/memory_ops.hpp" -#include "gpu_iface/mma_ops.hpp" +// #include "gpu_iface/mma_ops.hpp" #include "gpu_iface/platform.hpp" +#include "gpu_iface/utils.cuh" #ifdef FP16_QK_REDUCTION_SUPPORTED #include "../../fp16.h" @@ -22,7 +23,6 @@ #include "page.cuh" #include "permuted_smem.cuh" #include "pos_enc.cuh" -#include "utils.cuh" #include "variants.cuh" namespace flashinfer @@ -33,10 +33,10 @@ DEFINE_HAS_MEMBER(maybe_k_rope_offset) namespace cg = flashinfer::gpu_iface::cg; namespace memory = flashinfer::gpu_iface::memory; -namespace mma = gpu_iface::mma; +// namespace mma = gpu_iface::mma; using gpu_iface::vec_dtypes::vec_cast; -using mma::MMAMode; +// using mma::MMAMode; constexpr uint32_t WARP_SIZE = gpu_iface::kWarpSize; #if defined(PLATFORM_HIP_DEVICE) @@ -856,6 +856,8 @@ compute_qk(smem_t *q_smem, if constexpr (std::is_same_v) { +#warning "TODO: mma_sync_m16n16k16_row_col_f16f16f32 ...." +#if 0 if (mma_d == 0) { mma::mma_sync_m16n16k16_row_col_f16f16f32< typename KTraits::DTypeQ, MMAMode::kInit>( @@ -866,8 +868,11 @@ compute_qk(smem_t *q_smem, typename KTraits::DTypeQ>(s_frag[mma_q][mma_kv], a_frag[mma_q], b_frag); } +#endif } else if (std::is_same_v) { +#warning "Not yet implemented" +#if 0 if (mma_d == 0) { mma::mma_sync_m16n16k16_row_col_f16f16f16< MMAMode::kInit>((uint32_t *)s_frag[mma_q][mma_kv], @@ -878,6 +883,7 @@ compute_qk(smem_t *q_smem, (uint32_t *)s_frag[mma_q][mma_kv], a_frag[mma_q], b_frag); } +#endif } } } @@ -1195,6 +1201,8 @@ compute_sfm_v(smem_t *v_smem, for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { +#warning "TODO: m16k16_rowsum_f16f16f32 ..........." 
+#if 0 if constexpr (std::is_same_v) { @@ -1205,6 +1213,7 @@ compute_sfm_v(smem_t *v_smem, mma::m16k16_rowsum_f16f16f32(d[mma_q], s_frag[mma_q][mma_kv]); } +#endif } } } @@ -1215,6 +1224,8 @@ compute_sfm_v(smem_t *v_smem, for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { uint32_t b_frag[4]; if constexpr (sizeof(typename KTraits::DTypeKV) == 1) { +#warning "Not yet implemented......" +#if 0 uint32_t b_frag_f8[2]; if (mma_d % 2 == 0) { v_smem->ldmatrix_m8n8x4_trans_left_half(*v_smem_offset_r, @@ -1232,12 +1243,16 @@ compute_sfm_v(smem_t *v_smem, cast<8>((typename KTraits::DTypeQ *)b_frag, (typename KTraits::DTypeKV *)b_frag_f8); swap(b_frag[1], b_frag[2]); +#endif } else { - v_smem->ldmatrix_m8n8x4_trans(*v_smem_offset_r, b_frag); +#warning "TODO ldmatrix_m8n8x4_trans ............" + // v_smem->ldmatrix_m8n8x4_trans(*v_smem_offset_r, b_frag); } #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#warning "TODO mma_sync_m16n16k16_row_col_f16f16f32 ............" +#if 0 if constexpr (std::is_same_v) { @@ -1252,6 +1267,7 @@ compute_sfm_v(smem_t *v_smem, o_frag[mma_q][mma_d], (uint32_t *)s_frag[mma_q][mma_kv], b_frag); } +#endif } if constexpr (sizeof(typename KTraits::DTypeKV) == 1) { if (mma_d % 2 == 1) { diff --git a/libflashinfer/include/flashinfer/attention/generic/utils.cuh b/libflashinfer/include/flashinfer/attention/generic/utils.cuh deleted file mode 100644 index acbd374141..0000000000 --- a/libflashinfer/include/flashinfer/attention/generic/utils.cuh +++ /dev/null @@ -1,152 +0,0 @@ -// SPDX - FileCopyrightText : 2023-2035 FlashInfer team. -// SPDX - FileCopyrightText : 2025 Advanced Micro Devices, Inc. 
-// -// SPDX - License - Identifier : Apache 2.0 - -#pragma once - -#include "gpu_iface/gpu_runtime_compat.hpp" -#include -#include -#include -#include - -#define STR_HELPER(x) #x -#define STR(x) STR_HELPER(x) - -// macro to turn off fp16 qk reduction to reduce binary -#ifndef FLASHINFER_ALWAYS_DISUSE_FP16_QK_REDUCTION -#define FLASHINFER_ALWAYS_DISUSE_FP16_QK_REDUCTION 0 -#endif - -namespace flashinfer -{ - -template -__forceinline__ __device__ __host__ T1 ceil_div(const T1 x, const T2 y) -{ - return (x + y - 1) / y; -} - -#if defined(PLATFORM_CUDA_DEVICE) -inline std::pair GetCudaComputeCapability() -{ - int device_id = 0; - cudaGetDevice(&device_id); - int major = 0, minor = 0; - cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, - device_id); - cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, - device_id); - return std::make_pair(major, minor); -} -#endif - -template -inline void -DebugPrintCUDAArray(T *device_ptr, size_t size, std::string prefix = "") -{ - std::vector host_array(size); - std::cout << prefix; - gpuMemcpy(host_array.data(), device_ptr, size * sizeof(T), - gpuMemcpyDeviceToHost); - for (size_t i = 0; i < size; ++i) { - std::cout << host_array[i] << " "; - } - std::cout << std::endl; -} - -inline uint32_t FA2DetermineCtaTileQ(int64_t avg_packed_qo_len, - uint32_t head_dim) -{ -#if defined(PLATFORM_CUDA_DEVICE) - if (avg_packed_qo_len > 64 && head_dim < 256) { - return 128; - } - else { - auto compute_capacity = GetCudaComputeCapability(); - if (compute_capacity.first >= 8) { - // Ampere or newer - if (avg_packed_qo_len > 16) { - // avg_packed_qo_len <= 64 - return 64; - } - else { - // avg_packed_qo_len <= 16 - return 16; - } - } - else { - // NOTE(Zihao): not enough shared memory on Turing for 1x4 warp - // layout - return 64; - } - } -#elif defined(PLATFORM_HIP_DEVICE) - // Simplified version for HIP - if (avg_packed_qo_len > 64 && head_dim < 256) { - return 128; - } - else { - return avg_packed_qo_len <= 16 ? 
16 : 64; - } -#endif -} - -/*! - * \brief Return x - y if x > y, otherwise return 0. - */ -__device__ __forceinline__ uint32_t sub_if_greater_or_zero(uint32_t x, - uint32_t y) -{ - return (x > y) ? x - y : 0U; -} - -__device__ __forceinline__ void swap(uint32_t &a, uint32_t &b) -{ - uint32_t tmp = a; - a = b; - b = tmp; -} - -__device__ __forceinline__ uint32_t dim2_offset(const uint32_t &dim_a, - const uint32_t &idx_b, - const uint32_t &idx_a) -{ - return idx_b * dim_a + idx_a; -} - -__device__ __forceinline__ uint32_t dim3_offset(const uint32_t &dim_b, - const uint32_t &dim_a, - const uint32_t &idx_c, - const uint32_t &idx_b, - const uint32_t &idx_a) -{ - return (idx_c * dim_b + idx_b) * dim_a + idx_a; -} - -__device__ __forceinline__ uint32_t dim4_offset(const uint32_t &dim_c, - const uint32_t &dim_b, - const uint32_t &dim_a, - const uint32_t &idx_d, - const uint32_t &idx_c, - const uint32_t &idx_b, - const uint32_t &idx_a) -{ - return ((idx_d * dim_c + idx_c) * dim_b + idx_b) * dim_a + idx_a; -} - -#define DEFINE_HAS_MEMBER(member) \ - template \ - struct has_##member : std::false_type \ - { \ - }; \ - template \ - struct has_##member().member)>> \ - : std::true_type \ - { \ - }; \ - template \ - inline constexpr bool has_##member##_v = has_##member::value; - -} // namespace flashinfer diff --git a/libflashinfer/include/gpu_iface/backend/hip/vec_dtypes_hip.h b/libflashinfer/include/gpu_iface/backend/hip/vec_dtypes_hip.h index fde75ab60e..cc903a753f 100644 --- a/libflashinfer/include/gpu_iface/backend/hip/vec_dtypes_hip.h +++ b/libflashinfer/include/gpu_iface/backend/hip/vec_dtypes_hip.h @@ -1805,7 +1805,7 @@ template struct vec_t<__hip_bfloat16, vec_size> } FLASHINFER_INLINE void fill(__hip_bfloat16 val) { -#pragma unoll +#pragma unroll for (size_t i = 0; i < vec_size / 8; ++i) { *(__hip_bfloat162 *)(&(data[i].x)) = make_bfloat162(val, val); *(__hip_bfloat162 *)(&(data[i].y)) = make_bfloat162(val, val); @@ -1815,14 +1815,14 @@ template struct 
vec_t<__hip_bfloat16, vec_size> } FLASHINFER_INLINE void load(const __hip_bfloat16 *ptr) { -#pragma unoll +#pragma unroll for (size_t i = 0; i < vec_size / 8; ++i) { data[i] = ((uint4 *)ptr)[i]; } } FLASHINFER_INLINE void store(__hip_bfloat16 *ptr) const { -#pragma unoll +#pragma unroll for (size_t i = 0; i < vec_size / 8; ++i) { ((uint4 *)ptr)[i] = data[i]; } @@ -1843,7 +1843,7 @@ template struct vec_t<__hip_bfloat16, vec_size> FLASHINFER_INLINE static void memcpy(__hip_bfloat16 *dst, const __hip_bfloat16 *src) { -#pragma unoll +#pragma unroll for (size_t i = 0; i < vec_size / 8; ++i) { ((uint4 *)dst)[i] = ((uint4 *)src)[i]; } diff --git a/libflashinfer/include/gpu_iface/gpu_runtime_compat.hpp b/libflashinfer/include/gpu_iface/gpu_runtime_compat.hpp index 2a0f8907c1..9afb9964fa 100644 --- a/libflashinfer/include/gpu_iface/gpu_runtime_compat.hpp +++ b/libflashinfer/include/gpu_iface/gpu_runtime_compat.hpp @@ -21,9 +21,11 @@ // Basic type mappings #if defined(PLATFORM_CUDA_DEVICE) +#define gpuEvent_t cudaEvent_t #define gpuError_t cudaError_t #define gpuStream_t cudaStream_t #elif defined(PLATFORM_HIP_DEVICE) +#define gpuEvent_t hipEvent_t #define gpuError_t hipError_t #define gpuStream_t hipStream_t #endif @@ -87,6 +89,30 @@ hipOccupancyMaxActiveBlocksPerMultiprocessor #endif +// Event iface +#if defined(PLATFORM_CUDA_DEVICE) +#define gpuEventCreate cudaEventCreate +#define gpuEventDestroy cudaEventDestroy +#define gpuEventRecord cudaEventRecord +#define gpuEventSynchronize cudaEventSynchronize +#define gpuEventElapsedTime cudaEventElapsedTime +#elif defined(PLATFORM_HIP_DEVICE) +#define gpuEventCreate hipEventCreate +#define gpuEventDestroy hipEventDestroy +#define gpuEventRecord hipEventRecord +#define gpuEventSynchronize hipEventSynchronize +#define gpuEventElapsedTime hipEventElapsedTime +#endif + +// Stream iface +#if defined(PLATFORM_CUDA_DEVICE) +#define gpuStreamCreate cudaStreamCreate +#define gpuStreamDestroy cudaStreamDestroy +#elif 
defined(PLATFORM_HIP_DEVICE) +#define gpuStreamCreate hipStreamCreate +#define gpuStreamDestroy hipStreamDestroy +#endif + // Error handling (for FI_GPU_CALL) #if defined(PLATFORM_CUDA_DEVICE) #define gpuGetErrorString cudaGetErrorString diff --git a/libflashinfer/include/gpu_iface/macros.hpp b/libflashinfer/include/gpu_iface/macros.hpp index 391cc88ceb..91295d35e1 100644 --- a/libflashinfer/include/gpu_iface/macros.hpp +++ b/libflashinfer/include/gpu_iface/macros.hpp @@ -12,7 +12,8 @@ #define PLATFORM_HIP_DEVICE // FIXME: Temporarily setting __forceinline__ to inline as amdclang++ 6.4 throws // an error when __forceinline__ is used. -#define __forceinline__ inline +// #define __forceinline__ inline +#define __grid_constant__ #elif defined(__CUDACC__) || defined(__CUDA_ARCH__) #define PLATFORM_CUDA_DEVICE #endif From 1d5bf96dc20fbf8993b002d24483aba120d0b5ed Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Mon, 4 Aug 2025 04:39:01 -0400 Subject: [PATCH 008/109] A standalone driver for singleprefill --- examples/cpp/standalone_single_prefill.cu | 699 ++++++++++++++++++++++ 1 file changed, 699 insertions(+) create mode 100644 examples/cpp/standalone_single_prefill.cu diff --git a/examples/cpp/standalone_single_prefill.cu b/examples/cpp/standalone_single_prefill.cu new file mode 100644 index 0000000000..d35c290ab6 --- /dev/null +++ b/examples/cpp/standalone_single_prefill.cu @@ -0,0 +1,699 @@ +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace flashinfer +{ + +// Parameter struct for SinglePrefill +template struct SinglePrefillParams +{ + using DTypeQ = half; + using DTypeKV = half; + using DTypeO = DTypeOs; + using IdType = IdTypes; + + half *q; + half *k; + half *v; + DTypeO *o; + float *lse; + uint_fastdiv group_size; + + uint8_t *maybe_custom_mask; + float *maybe_alibi_slopes; + 
double logits_soft_cap; + double sm_scale; + double rope_rcp_scale; + double rope_rcp_theta; + + uint32_t qo_len; + uint32_t kv_len; + uint32_t num_qo_heads; + uint32_t num_kv_heads; + uint32_t q_stride_n; + uint32_t q_stride_h; + uint32_t k_stride_n; + uint32_t k_stride_h; + uint32_t v_stride_n; + uint32_t v_stride_h; + uint32_t head_dim; + int32_t window_left; + + bool partition_kv; + + __host__ __device__ __forceinline__ uint32_t + get_qo_len(uint32_t batch_idx) const + { + return qo_len; + } + + __host__ __device__ __forceinline__ uint32_t + get_kv_len(uint32_t batch_idx) const + { + return kv_len; + } +}; + +} // namespace flashinfer + +// CPU reference implementation for validation +namespace reference +{ + +template +std::vector single_mha(const std::vector &q, + const std::vector &k, + const std::vector &v, + size_t qo_len, + size_t kv_len, + size_t num_qo_heads, + size_t num_kv_heads, + size_t head_dim, + bool causal, + flashinfer::QKVLayout kv_layout, + flashinfer::PosEncodingMode pos_encoding_mode, + float rope_scale = 1.0f, + float rope_theta = 10000.0f) +{ + float sm_scale = 1.0f / std::sqrt(static_cast(head_dim)); + std::vector o(qo_len * num_qo_heads * head_dim, static_cast(0.0f)); + std::vector att(kv_len); + size_t group_size = num_qo_heads / num_kv_heads; + + for (size_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) { + size_t kv_head_idx = qo_head_idx / group_size; + + for (size_t q_idx = 0; q_idx < qo_len; ++q_idx) { + // 1. 
Compute attention scores + float max_val = -5e4f; + + for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { + if (causal && kv_idx > kv_len + q_idx - qo_len) { + att[kv_idx] = -5e4f; + continue; + } + + // Compute dot product between Q and K + float score = 0.0f; + for (size_t d = 0; d < head_dim; ++d) { + float q_val = 0.0f; + float k_val = 0.0f; + + // Get Q value - always NHD layout + size_t q_offset = q_idx * num_qo_heads * head_dim + + qo_head_idx * head_dim + d; + q_val = static_cast(q[q_offset]); + + // Get K value - depends on layout + if (kv_layout == flashinfer::QKVLayout::kNHD) { + size_t k_offset = kv_idx * num_kv_heads * head_dim + + kv_head_idx * head_dim + d; + k_val = static_cast(k[k_offset]); + } + else { + size_t k_offset = kv_head_idx * kv_len * head_dim + + kv_idx * head_dim + d; + k_val = static_cast(k[k_offset]); + } + + score += q_val * k_val; + } + score *= sm_scale; + + att[kv_idx] = score; + max_val = std::max(max_val, score); + } + + // 2. Apply softmax + float sum_exp = 0.0f; + for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { + if (causal && kv_idx > kv_len + q_idx - qo_len) { + att[kv_idx] = 0.0f; + } + else { + att[kv_idx] = std::exp(att[kv_idx] - max_val); + sum_exp += att[kv_idx]; + } + } + + // Normalize + for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { + if (sum_exp > 0.0f) { + att[kv_idx] /= sum_exp; + } + } + + // 3. 
Compute weighted sum of values + for (size_t d = 0; d < head_dim; ++d) { + float weighted_sum = 0.0f; + + for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { + float v_val = 0.0f; + + // Get V value - depends on layout + if (kv_layout == flashinfer::QKVLayout::kNHD) { + size_t v_offset = kv_idx * num_kv_heads * head_dim + + kv_head_idx * head_dim + d; + v_val = static_cast(v[v_offset]); + } + else { + size_t v_offset = kv_head_idx * kv_len * head_dim + + kv_idx * head_dim + d; + v_val = static_cast(v[v_offset]); + } + + weighted_sum += att[kv_idx] * v_val; + } + + // Store result in output + size_t o_offset = q_idx * num_qo_heads * head_dim + + qo_head_idx * head_dim + d; + o[o_offset] = static_cast(weighted_sum); + } + } + } + + return o; +} + +} // namespace reference + +// Function to validate GPU results against CPU reference +bool validate_results(const thrust::host_vector &gpu_output, + const std::vector &cpu_output, + float rtol = 1e-3f, + float atol = 1e-3f) +{ + if (gpu_output.size() != cpu_output.size()) { + std::cerr << "Size mismatch: GPU=" << gpu_output.size() + << " vs CPU=" << cpu_output.size() << std::endl; + return false; + } + + int errors = 0; + float max_diff = 0.0f; + float max_rel_diff = 0.0f; + + for (size_t i = 0; i < gpu_output.size(); ++i) { + float gpu_val = static_cast(gpu_output[i]); + float cpu_val = static_cast(cpu_output[i]); + float abs_diff = std::abs(gpu_val - cpu_val); + float rel_diff = + (cpu_val != 0.0f) ? 
abs_diff / std::abs(cpu_val) : abs_diff; + + max_diff = std::max(max_diff, abs_diff); + max_rel_diff = std::max(max_rel_diff, rel_diff); + + bool close = (abs_diff <= atol + rtol * std::abs(cpu_val)); + if (!close) { + errors++; + if (errors <= 10) { // Print just a few examples + std::cerr << "Mismatch at " << i << ": GPU=" << gpu_val + << " CPU=" << cpu_val << " (diff=" << abs_diff << ")" + << std::endl; + } + } + } + + float error_rate = static_cast(errors) / gpu_output.size(); + std::cout << "\nValidation Results:" << std::endl; + std::cout << " Max absolute difference: " << max_diff << std::endl; + std::cout << " Max relative difference: " << max_rel_diff << std::endl; + std::cout << " Error rate: " << (error_rate * 100) << "% (" << errors + << " / " << gpu_output.size() << " elements)" << std::endl; + std::cout << " Status: " << (error_rate < 0.05 ? "PASSED" : "FAILED") + << std::endl; + + // Allow up to 5% error rate (similar to the threshold used in the unit + // tests) + return error_rate < 0.05; +} + +using namespace flashinfer; + +// Helper class to convert strings to parameters +class ArgParser +{ +public: + static bool get_bool(const char *arg, bool default_val) + { + return arg == nullptr + ? default_val + : (std::string(arg) == "1" || std::string(arg) == "true"); + } + + static int get_int(const char *arg, int default_val) + { + return arg == nullptr ? default_val : std::atoi(arg); + } + + static float get_float(const char *arg, float default_val) + { + return arg == nullptr ? 
default_val : std::atof(arg); + } + + static PosEncodingMode get_pos_encoding_mode(const char *arg) + { + if (arg == nullptr) + return PosEncodingMode::kNone; + std::string str_val = arg; + if (str_val == "none") + return PosEncodingMode::kNone; + if (str_val == "rope") + return PosEncodingMode::kRoPELlama; + if (str_val == "alibi") + return PosEncodingMode::kALiBi; + return PosEncodingMode::kNone; + } + + static QKVLayout get_layout(const char *arg) + { + if (arg == nullptr) + return QKVLayout::kNHD; + std::string str_val = arg; + if (str_val == "nhd") + return QKVLayout::kNHD; + if (str_val == "hnd") + return QKVLayout::kHND; + return QKVLayout::kNHD; + } +}; + +// Helper function to generate random data on device +void generate_random_data(thrust::device_vector &data, + float min_val = -1.0f, + float max_val = 1.0f) +{ + thrust::host_vector host_data(data.size()); + + thrust::default_random_engine rng(42); // Fixed seed for reproducibility + thrust::uniform_real_distribution dist(min_val, max_val); + + for (size_t i = 0; i < host_data.size(); ++i) { + host_data[i] = static_cast(dist(rng)); + } + + data = host_data; +} + +// Dispatch function for half precision +gpuError_t dispatch_single_prefill(half *q_ptr, + half *k_ptr, + half *v_ptr, + half *o_ptr, + half *tmp_ptr, + float *lse_ptr, + uint32_t num_qo_heads, + uint32_t num_kv_heads, + uint32_t qo_len, + uint32_t kv_len, + uint32_t head_dim, + QKVLayout kv_layout, + PosEncodingMode pos_encoding_mode, + bool causal, + bool use_fp16_qk_reduction, + double sm_scale, + int32_t window_left, + double rope_scale, + double rope_theta, + gpuStream_t stream) +{ + // Compute strides based on layout + uint32_t q_stride_n = num_qo_heads * head_dim; + uint32_t q_stride_h = head_dim; + uint32_t k_stride_n, k_stride_h, v_stride_n, v_stride_h; + + if (kv_layout == QKVLayout::kNHD) { + k_stride_n = num_kv_heads * head_dim; + k_stride_h = head_dim; + v_stride_n = num_kv_heads * head_dim; + v_stride_h = head_dim; + } + else { + 
k_stride_h = kv_len * head_dim; + k_stride_n = head_dim; + v_stride_h = kv_len * head_dim; + v_stride_n = head_dim; + } + + // Configure mask mode + const MaskMode mask_mode = causal ? MaskMode::kCausal : MaskMode::kNone; + + // Constants for prefill kernel + constexpr uint32_t HEAD_DIM_QK = 128; + constexpr uint32_t HEAD_DIM_VO = 128; + constexpr PosEncodingMode POS_ENCODING_MODE = PosEncodingMode::kRoPELlama; + constexpr bool USE_FP16_QK_REDUCTION = false; + + gpuError_t status = gpuSuccess; + + if (causal) { + // Causal attention + using AttentionVariantType = + DefaultAttention; + using Params = SinglePrefillParams; + + Params params; + params.q = q_ptr; + params.k = k_ptr; + params.v = v_ptr; + params.o = o_ptr; + params.lse = lse_ptr; + params.num_qo_heads = num_qo_heads; + params.num_kv_heads = num_kv_heads; + params.group_size = uint_fastdiv(num_qo_heads / num_kv_heads); + params.qo_len = qo_len; + params.kv_len = kv_len; + params.q_stride_n = q_stride_n; + params.q_stride_h = q_stride_h; + params.k_stride_n = k_stride_n; + params.k_stride_h = k_stride_h; + params.v_stride_n = v_stride_n; + params.v_stride_h = v_stride_h; + params.head_dim = head_dim; + params.window_left = window_left; + params.partition_kv = false; + params.maybe_custom_mask = nullptr; + params.maybe_alibi_slopes = nullptr; + params.logits_soft_cap = 0.0; + params.sm_scale = sm_scale; + params.rope_rcp_scale = 1.0 / rope_scale; + params.rope_rcp_theta = 1.0 / rope_theta; + + status = SinglePrefillWithKVCacheDispatched< + HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, USE_FP16_QK_REDUCTION, + MaskMode::kCausal, AttentionVariantType>(params, tmp_ptr, stream); + } + else { + // Non-causal attention + using AttentionVariantType = + DefaultAttention; + using Params = SinglePrefillParams; + + Params params; + params.q = q_ptr; + params.k = k_ptr; + params.v = v_ptr; + params.o = o_ptr; + params.lse = lse_ptr; + params.num_qo_heads = num_qo_heads; + params.num_kv_heads = num_kv_heads; + 
params.group_size = uint_fastdiv(num_qo_heads / num_kv_heads); + params.qo_len = qo_len; + params.kv_len = kv_len; + params.q_stride_n = q_stride_n; + params.q_stride_h = q_stride_h; + params.k_stride_n = k_stride_n; + params.k_stride_h = k_stride_h; + params.v_stride_n = v_stride_n; + params.v_stride_h = v_stride_h; + params.head_dim = head_dim; + params.window_left = window_left; + params.partition_kv = false; + params.maybe_custom_mask = nullptr; + params.maybe_alibi_slopes = nullptr; + params.logits_soft_cap = 0.0; + params.sm_scale = sm_scale; + params.rope_rcp_scale = 1.0 / rope_scale; + params.rope_rcp_theta = 1.0 / rope_theta; + + status = SinglePrefillWithKVCacheDispatched< + HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, USE_FP16_QK_REDUCTION, + MaskMode::kNone, AttentionVariantType>(params, tmp_ptr, stream); + } + + return status; +} + +// Function to calculate FLOPs for single_prefill +double calculate_flops(uint32_t qo_len, + uint32_t kv_len, + uint32_t num_qo_heads, + uint32_t head_dim, + bool causal) +{ + double flops; + if (causal) { + // For causal attention: qo_len * (2 * kv_len - qo_len) * 2 * + // num_qo_heads * head_dim + flops = static_cast(qo_len) * (2.0 * kv_len - qo_len) * 2.0 * + num_qo_heads * head_dim; + } + else { + // For non-causal attention: qo_len * kv_len * 4 * num_qo_heads * + // head_dim + flops = static_cast(qo_len) * kv_len * 4.0 * num_qo_heads * + head_dim; + } + return flops; +} + +void print_usage(const char *program_name) +{ + std::cerr + << "Usage: " << program_name << " [options]\n" + << "Options:\n" + << " --qo_len : Query sequence length (default: " + "512)\n" + << " --kv_len : Key/value sequence length (default: " + "512)\n" + << " --num_qo_heads : Number of query heads (default: 32)\n" + << " --num_kv_heads : Number of key/value heads (default: " + "32)\n" + << " --head_dim : Head dimension (default: 128)\n" + << " --layout : KV tensor layout (default: nhd)\n" + << " --pos_encoding : Position encoding mode " + 
"(default: none)\n" + << " --causal <0|1> : Use causal mask (default: 1)\n" + << " --use_fp16_qk <0|1> : Use FP16 for QK reduction (default: " + "0)\n" + << " --window_left : Window left size (default: -1)\n" + << " --rope_scale : RoPE scale factor (default: 1.0)\n" + << " --rope_theta : RoPE theta (default: 10000.0)\n" + << " --iterations : Number of iterations for timing " + "(default: 10)\n" + << " --warmup : Number of warmup iterations " + "(default: 5)\n" + << " --validate <0|1> : Validate against CPU reference " + "(default: 0)\n"; +} + +int main(int argc, char *argv[]) +{ + // Default parameter values + uint32_t qo_len = 512; + uint32_t kv_len = 512; + uint32_t num_qo_heads = 32; + uint32_t num_kv_heads = 32; + uint32_t head_dim = 128; + bool causal = true; + bool use_fp16_qk_reduction = false; + QKVLayout kv_layout = QKVLayout::kNHD; + PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone; + int32_t window_left = -1; + float rope_scale = 1.0f; + float rope_theta = 10000.0f; + int iterations = 10; + int warmup = 5; + bool validate = false; + + // Parse command-line arguments + for (int i = 1; i < argc; i++) { + std::string arg = argv[i]; + if (arg == "--qo_len" && i + 1 < argc) + qo_len = std::atoi(argv[++i]); + else if (arg == "--kv_len" && i + 1 < argc) + kv_len = std::atoi(argv[++i]); + else if (arg == "--num_qo_heads" && i + 1 < argc) + num_qo_heads = std::atoi(argv[++i]); + else if (arg == "--num_kv_heads" && i + 1 < argc) + num_kv_heads = std::atoi(argv[++i]); + else if (arg == "--head_dim" && i + 1 < argc) + head_dim = std::atoi(argv[++i]); + else if (arg == "--causal" && i + 1 < argc) + causal = ArgParser::get_bool(argv[++i], true); + else if (arg == "--use_fp16_qk" && i + 1 < argc) + use_fp16_qk_reduction = ArgParser::get_bool(argv[++i], false); + else if (arg == "--layout" && i + 1 < argc) + kv_layout = ArgParser::get_layout(argv[++i]); + else if (arg == "--pos_encoding" && i + 1 < argc) + pos_encoding_mode = 
ArgParser::get_pos_encoding_mode(argv[++i]); + else if (arg == "--window_left" && i + 1 < argc) + window_left = std::atoi(argv[++i]); + else if (arg == "--rope_scale" && i + 1 < argc) + rope_scale = std::atof(argv[++i]); + else if (arg == "--rope_theta" && i + 1 < argc) + rope_theta = std::atof(argv[++i]); + else if (arg == "--iterations" && i + 1 < argc) + iterations = std::atoi(argv[++i]); + else if (arg == "--warmup" && i + 1 < argc) + warmup = std::atoi(argv[++i]); + else if (arg == "--validate" && i + 1 < argc) + validate = ArgParser::get_bool(argv[++i], false); + else if (arg == "--help") { + print_usage(argv[0]); + return 0; + } + } + + // Verify that num_qo_heads is divisible by num_kv_heads + if (num_qo_heads % num_kv_heads != 0) { + std::cerr << "Error: num_qo_heads must be divisible by num_kv_heads" + << std::endl; + return 1; + } + + // Display configuration + std::cout << "Configuration:" << std::endl; + std::cout << " qo_len = " << qo_len << std::endl; + std::cout << " kv_len = " << kv_len << std::endl; + std::cout << " num_qo_heads = " << num_qo_heads << std::endl; + std::cout << " num_kv_heads = " << num_kv_heads << std::endl; + std::cout << " head_dim = " << head_dim << std::endl; + std::cout << " kv_layout = " + << (kv_layout == QKVLayout::kNHD ? "NHD" : "HND") << std::endl; + std::cout << " causal = " << (causal ? "true" : "false") << std::endl; + std::cout << " data_type = half" << std::endl; + std::cout << " use_fp16_qk_reduction = " + << (use_fp16_qk_reduction ? "true" : "false") << std::endl; + std::cout << " validate = " << (validate ? 
"true" : "false") << std::endl; + + // Initialize and create stream + gpuStream_t stream; + gpuStreamCreate(&stream); + + // Allocate device memory using Thrust - only for half precision + thrust::device_vector q(qo_len * num_qo_heads * head_dim); + thrust::device_vector k(kv_len * num_kv_heads * head_dim); + thrust::device_vector v(kv_len * num_kv_heads * head_dim); + thrust::device_vector o(qo_len * num_qo_heads * head_dim); + thrust::device_vector tmp(qo_len * num_qo_heads * head_dim); + thrust::device_vector lse(qo_len * num_qo_heads); + + // Generate random data + generate_random_data(q); + generate_random_data(k); + generate_random_data(v); + thrust::fill(o.begin(), o.end(), half(0.0f)); + thrust::fill(tmp.begin(), tmp.end(), half(0.0f)); + thrust::fill(lse.begin(), lse.end(), 0.0f); + + // Calculate SM scale if not provided + float sm_scale = 1.0f / std::sqrt(static_cast(head_dim)); + + // Warm-up runs + for (int i = 0; i < warmup; ++i) { + gpuError_t status = dispatch_single_prefill( + thrust::raw_pointer_cast(q.data()), + thrust::raw_pointer_cast(k.data()), + thrust::raw_pointer_cast(v.data()), + thrust::raw_pointer_cast(o.data()), + thrust::raw_pointer_cast(tmp.data()), + thrust::raw_pointer_cast(lse.data()), num_qo_heads, num_kv_heads, + qo_len, kv_len, head_dim, kv_layout, pos_encoding_mode, causal, + use_fp16_qk_reduction, sm_scale, window_left, rope_scale, + rope_theta, stream); + + if (status != gpuSuccess) { + std::cerr << "Error during warmup: " << gpuGetErrorString(status) + << std::endl; + return 1; + } + } + + // Timing runs + gpuEvent_t start, stop; + gpuEventCreate(&start); + gpuEventCreate(&stop); + + gpuEventRecord(start, stream); + + for (int i = 0; i < iterations; ++i) { + gpuError_t status = dispatch_single_prefill( + thrust::raw_pointer_cast(q.data()), + thrust::raw_pointer_cast(k.data()), + thrust::raw_pointer_cast(v.data()), + thrust::raw_pointer_cast(o.data()), + thrust::raw_pointer_cast(tmp.data()), + 
thrust::raw_pointer_cast(lse.data()), num_qo_heads, num_kv_heads, + qo_len, kv_len, head_dim, kv_layout, pos_encoding_mode, causal, + use_fp16_qk_reduction, sm_scale, window_left, rope_scale, + rope_theta, stream); + + if (status != gpuSuccess) { + std::cerr << "Error during benchmark: " << gpuGetErrorString(status) + << std::endl; + return 1; + } + } + + gpuEventRecord(stop, stream); + gpuEventSynchronize(stop); + + float elapsed_ms; + gpuEventElapsedTime(&elapsed_ms, start, stop); + float avg_ms = elapsed_ms / iterations; + + // Calculate FLOPS + double flops = + calculate_flops(qo_len, kv_len, num_qo_heads, head_dim, causal); + double tflops = flops / (avg_ms * 1e-3) / 1e12; + + // Report results + std::cout << std::fixed << std::setprecision(4); + std::cout << "Performance Results:" << std::endl; + std::cout << " Average time: " << avg_ms << " ms" << std::endl; + std::cout << " Performance: " << tflops << " TFLOPS" << std::endl; + + // Run validation if requested + if (validate) { + std::cout << "\nRunning validation..." << std::endl; + + // Copy output from GPU to host for validation + thrust::host_vector h_output = o; + + // Create input data on host for CPU reference + std::vector h_q(q.begin(), q.end()); + std::vector h_k(k.begin(), k.end()); + std::vector h_v(v.begin(), v.end()); + + // Compute reference output on CPU + std::vector ref_output = reference::single_mha( + h_q, h_k, h_v, qo_len, kv_len, num_qo_heads, num_kv_heads, head_dim, + causal, kv_layout, pos_encoding_mode, rope_scale, rope_theta); + + // Validate results + bool validation_passed = validate_results(h_output, ref_output); + + // Report validation status + std::cout << "Validation " << (validation_passed ? 
"PASSED" : "FAILED") + << std::endl; + } + + gpuEventDestroy(start); + gpuEventDestroy(stop); + gpuStreamDestroy(stream); + + return 0; +} From ea3791ebaf8f5447595db5cf515c65d902748613 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Mon, 4 Aug 2025 14:32:51 -0400 Subject: [PATCH 009/109] Completed Mechanical HIPification changes. --- .../attention/generic/frag_layout_swizzle.cuh | 18 +++-- .../flashinfer/attention/generic/prefill.cuh | 78 +++++++++---------- .../include/gpu_iface/gpu_runtime_compat.hpp | 21 ++++- libflashinfer/include/gpu_iface/macros.hpp | 2 +- .../include/gpu_iface/vec_dtypes.hpp | 1 + 5 files changed, 74 insertions(+), 46 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/frag_layout_swizzle.cuh b/libflashinfer/include/flashinfer/attention/generic/frag_layout_swizzle.cuh index 9f146a1341..23a1351a62 100644 --- a/libflashinfer/include/flashinfer/attention/generic/frag_layout_swizzle.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/frag_layout_swizzle.cuh @@ -10,11 +10,19 @@ #include +// Define platform-specific full mask for warp/wavefront operations +#if defined(PLATFORM_CUDA_DEVICE) +constexpr uint32_t WARP_FULL_MASK = 0xffffffff; // 32-bit mask for CUDA +#elif defined(PLATFORM_HIP_DEVICE) +constexpr uint64_t WARP_FULL_MASK = + 0xffffffffffffffffULL; // 64-bit mask for HIP +#endif + __device__ __forceinline__ uint32_t frag_layout_swizzle_16b_to_8b(uint32_t x) { - uint32_t tmp = __shfl_xor_sync(0xffffffff, x, 0x1); + uint32_t tmp = __shfl_xor_sync(WARP_FULL_MASK, x, 0x1); x = __byte_perm(x, tmp, ((threadIdx.x & 0x1) == 0) ? 0x5410 : 0x3276); - tmp = __shfl_xor_sync(0xffffffff, x, 0x2); + tmp = __shfl_xor_sync(WARP_FULL_MASK, x, 0x2); x = __byte_perm(x, tmp, ((threadIdx.x & 0x2) == 0) ? 
0x5410 : 0x3276); return x; } @@ -22,11 +30,11 @@ __device__ __forceinline__ uint32_t frag_layout_swizzle_16b_to_8b(uint32_t x) __device__ __forceinline__ uint32_t frag_layout_swizzle_16b_to_8b_trans(uint32_t x) { - uint32_t tmp = __shfl_xor_sync(0xffffffff, x, 0x4); + uint32_t tmp = __shfl_xor_sync(WARP_FULL_MASK, x, 0x4); x = __byte_perm(x, tmp, ((threadIdx.x & 0x4) == 0) ? 0x6420 : 0x3175); - tmp = __shfl_xor_sync(0xffffffff, x, 0x8); + tmp = __shfl_xor_sync(WARP_FULL_MASK, x, 0x8); x = __byte_perm(x, tmp, ((threadIdx.x & 0x8) == 0) ? 0x5410 : 0x3276); - tmp = __shfl_xor_sync(0xffffffff, x, 0x10); + tmp = __shfl_xor_sync(WARP_FULL_MASK, x, 0x10); x = __byte_perm(x, tmp, ((threadIdx.x & 0x10) == 0) ? 0x5410 : 0x3276); return x; } diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index a4964c8714..6cfec1375d 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -374,8 +374,8 @@ produce_kv(smem_t smem, for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) { #pragma unroll for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(DTypeKV)); ++j) { - smem.load_128b_async(*smem_offset, *gptr, - kv_idx < kv_len); + smem.template load_128b_async(*smem_offset, *gptr, + kv_idx < kv_len); *smem_offset = smem.template advance_offset_by_column( *smem_offset, j); @@ -402,8 +402,8 @@ produce_kv(smem_t smem, static_assert(NUM_MMA_KV * 2 % NUM_WARPS_Q == 0); #pragma unroll for (uint32_t i = 0; i < NUM_MMA_KV * 2 / NUM_WARPS_Q; ++i) { - smem.load_128b_async(*smem_offset, *gptr, - kv_idx < kv_len); + smem.template load_128b_async(*smem_offset, *gptr, + kv_idx < kv_len); *smem_offset = smem.template advance_offset_by_row( *smem_offset); @@ -451,8 +451,8 @@ page_produce_kv(smem_t smem, : paged_kv.k_data + thr_local_kv_offset[i]; #pragma unroll for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(DType)); 
++j) { - smem.load_128b_async(*smem_offset, gptr, - kv_idx < kv_len); + smem.template load_128b_async(*smem_offset, gptr, + kv_idx < kv_len); *smem_offset = smem.template advance_offset_by_column<8>(*smem_offset, j); gptr += 8 * upcast_size(); @@ -474,8 +474,8 @@ page_produce_kv(smem_t smem, for (uint32_t i = 0; i < NUM_MMA_KV * 2 / NUM_WARPS_Q; ++i) { DType *gptr = produce_v ? paged_kv.v_data + thr_local_kv_offset[i] : paged_kv.k_data + thr_local_kv_offset[i]; - smem.load_128b_async(*smem_offset, gptr, - kv_idx < kv_len); + smem.template load_128b_async(*smem_offset, gptr, + kv_idx < kv_len); kv_idx += NUM_WARPS * 8; *smem_offset = smem.template advance_offset_by_row( @@ -561,7 +561,7 @@ load_q_global_smem(uint32_t packed_offset, lane_idx / WARP_STEP_SIZE; uint32_t col = lane_idx % WARP_STEP_SIZE; uint32_t q_smem_offset_w = - q_smem->get_permuted_offset(row, col); + q_smem->template get_permuted_offset(row, col); #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { @@ -580,8 +580,9 @@ load_q_global_smem(uint32_t packed_offset, ++mma_do) { // load q fragment from gmem to smem - q_smem->load_128b_async( - q_smem_offset_w, q_ptr, q_idx < qo_upper_bound); + q_smem + ->template load_128b_async( + q_smem_offset_w, q_ptr, q_idx < qo_upper_bound); q_smem_offset_w = q_smem ->template advance_offset_by_column( @@ -841,8 +842,8 @@ compute_qk(smem_t *q_smem, b_frag_f8[0] = frag_layout_swizzle_16b_to_8b(b_frag_f8[0]); b_frag_f8[1] = frag_layout_swizzle_16b_to_8b(b_frag_f8[1]); vec_cast:: - cast<8>((typename KTraits::DTypeQ *)b_frag, - (typename KTraits::DTypeKV *)b_frag_f8); + template cast<8>((typename KTraits::DTypeQ *)b_frag, + (typename KTraits::DTypeKV *)b_frag_f8); } else { k_smem->ldmatrix_m8n8x4(*k_smem_offset_r, b_frag); @@ -1113,6 +1114,8 @@ __device__ __forceinline__ void update_mdo_states( } } else if constexpr (std::is_same_v) { +#warning "Not implemented yet ...." 
+#if 0 const half2 sm_scale = __float2half2_rn(variant.sm_scale_log2); #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { @@ -1169,6 +1172,7 @@ __device__ __forceinline__ void update_mdo_states( } } } +#endif } } } @@ -1190,7 +1194,7 @@ compute_sfm_v(smem_t *v_smem, for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { - vec_cast::cast<8>( + vec_cast::template cast<8>( s_frag_f16[mma_q][mma_kv], s_frag[mma_q][mma_kv]); } } @@ -1239,7 +1243,7 @@ compute_sfm_v(smem_t *v_smem, frag_layout_swizzle_16b_to_8b_trans(b_frag_f8[0]); b_frag_f8[1] = frag_layout_swizzle_16b_to_8b_trans(b_frag_f8[1]); - vec_cast:: + vec_cast::template cast<8>((typename KTraits::DTypeQ *)b_frag, (typename KTraits::DTypeKV *)b_frag_f8); swap(b_frag[1], b_frag[2]); @@ -1552,19 +1556,19 @@ write_o_reg_gmem(float (*o_frag)[KTraits::NUM_MMA_D_VO][8], for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { uint32_t o_frag_f16[8 / 2]; - vec_cast::cast<8>((DTypeO *)o_frag_f16, - o_frag[mma_q][mma_d]); + vec_cast::template cast<8>( + (DTypeO *)o_frag_f16, o_frag[mma_q][mma_d]); #ifdef FLASHINFER_STMATRIX_M8N8X4_ENABLED uint32_t o_smem_offset_w = - o_smem->get_permuted_offset( + o_smem->template get_permuted_offset( (warp_idx_x * KTraits::NUM_MMA_Q + mma_q) * 16 + lane_idx % 16, mma_d * 2 + lane_idx / 16); o_smem->stmatrix_m8n8x4(o_smem_offset_w, o_frag_f16); #else uint32_t o_smem_offset_w = - o_smem->get_permuted_offset( + o_smem->template get_permuted_offset( (warp_idx_x * KTraits::NUM_MMA_Q + mma_q) * 16 + lane_idx / 4, mma_d * 2); @@ -1585,7 +1589,7 @@ write_o_reg_gmem(float (*o_frag)[KTraits::NUM_MMA_D_VO][8], } uint32_t o_smem_offset_w = - o_smem->get_permuted_offset( + o_smem->template get_permuted_offset( warp_idx_x * KTraits::NUM_MMA_Q * 16 + lane_idx / 8, lane_idx % 8); @@ -1763,9 +1767,10 @@ SinglePrefillWithKVCacheDevice(const Params params, (kv_head_idx * 
group_size) * o_stride_h : o + (kv_head_idx * group_size) * o_stride_h; - uint32_t q_smem_offset_r = qo_smem.get_permuted_offset( - get_warp_idx_q(tid.y) * NUM_MMA_Q * 16 + lane_idx % 16, - lane_idx / 16); + uint32_t q_smem_offset_r = + qo_smem.template get_permuted_offset( + get_warp_idx_q(tid.y) * NUM_MMA_Q * 16 + lane_idx % 16, + lane_idx / 16); load_q_global_smem(qo_packed_idx_base, qo_len, q_ptr_base, q_stride_n, q_stride_h, group_size, @@ -2023,10 +2028,7 @@ gpuError_t SinglePrefillWithKVCacheDispatched(Params params, int dev_id = 0; FI_GPU_CALL(gpuGetDevice(&dev_id)); - int max_smem_per_sm = 0; - FI_GPU_CALL(gpuDeviceGetAttribute( - &max_smem_per_sm, gpuDevAttrMaxSharedMemoryPerMultiProcessor, - dev_id)); + int max_smem_per_sm = getMaxSharedMemPerMultiprocessor(dev_id); // we expect each sm execute two threadblocks const int num_ctas_per_sm = max_smem_per_sm >= 2 * (CTA_TILE_Q * HEAD_DIM_QK * sizeof(DTypeQ) + @@ -2281,9 +2283,10 @@ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKVCacheKernel : o + o_indptr[request_idx] * o_stride_n + (kv_head_idx * group_size) * o_stride_h; - uint32_t q_smem_offset_r = qo_smem.get_permuted_offset( - get_warp_idx_q(tid.y) * NUM_MMA_Q * 16 + lane_idx % 16, - lane_idx / 16); + uint32_t q_smem_offset_r = + qo_smem.template get_permuted_offset( + get_warp_idx_q(tid.y) * NUM_MMA_Q * 16 + lane_idx % 16, + lane_idx / 16); load_q_global_smem(qo_packed_idx_base, qo_upper_bound, q_ptr_base, q_stride_n, q_stride_h, @@ -2642,9 +2645,10 @@ __device__ __forceinline__ void BatchPrefillWithPagedKVCacheDevice( (kv_head_idx * group_size) * o_stride_h : o + o_indptr[request_idx] * o_stride_n + (kv_head_idx * group_size) * o_stride_h; - uint32_t q_smem_offset_r = qo_smem.get_permuted_offset( - get_warp_idx_q(tid.y) * NUM_MMA_Q * 16 + lane_idx % 16, - lane_idx / 16); + uint32_t q_smem_offset_r = + qo_smem.template get_permuted_offset( + get_warp_idx_q(tid.y) * NUM_MMA_Q * 16 + lane_idx % 16, + lane_idx / 16); 
load_q_global_smem(qo_packed_idx_base, qo_upper_bound, q_ptr_base, q_stride_n, q_stride_h, @@ -2952,9 +2956,7 @@ BatchPrefillWithRaggedKVCacheDispatched(Params params, int dev_id = 0; FI_GPU_CALL(gpuGetDevice(&dev_id)); - int max_smem_per_sm = 0; - FI_GPU_CALL(gpuDeviceGetAttribute( - &max_smem_per_sm, gpuDevAttrMaxSharedMemoryPerMultiProcessor, dev_id)); + int max_smem_per_sm = getMaxSharedMemPerMultiprocessor(dev_id); // we expect each sm execute two threadblocks const int num_ctas_per_sm = max_smem_per_sm >= 2 * (CTA_TILE_Q * HEAD_DIM_QK * sizeof(DTypeQ) + @@ -3083,9 +3085,7 @@ BatchPrefillWithPagedKVCacheDispatched(Params params, int dev_id = 0; FI_GPU_CALL(gpuGetDevice(&dev_id)); - int max_smem_per_sm = 0; - FI_GPU_CALL(gpuDeviceGetAttribute( - &max_smem_per_sm, gpuDevAttrMaxSharedMemoryPerMultiProcessor, dev_id)); + int max_smem_per_sm = getMaxSharedMemPerMultiprocessor(dev_id); // we expect each sm execute two threadblocks const int num_ctas_per_sm = max_smem_per_sm >= 2 * (CTA_TILE_Q * HEAD_DIM_QK * sizeof(DTypeQ) + diff --git a/libflashinfer/include/gpu_iface/gpu_runtime_compat.hpp b/libflashinfer/include/gpu_iface/gpu_runtime_compat.hpp index 9afb9964fa..5c46f62602 100644 --- a/libflashinfer/include/gpu_iface/gpu_runtime_compat.hpp +++ b/libflashinfer/include/gpu_iface/gpu_runtime_compat.hpp @@ -40,7 +40,8 @@ #elif defined(PLATFORM_HIP_DEVICE) #define gpuGetDevice hipGetDevice #define gpuLaunchKernel hipLaunchKernel -#define gpuFuncSetAttribute hipFuncSetAttribute +#define gpuFuncSetAttribute(func, attr, val) \ + hipFuncSetAttribute(reinterpret_cast(func), attr, val) #define gpuDeviceGetAttribute hipDeviceGetAttribute #define gpuDeviceSynchronize hipDeviceSynchronize #endif @@ -48,6 +49,7 @@ #if defined(PLATFORM_CUDA_DEVICE) #define gpuMemcpy cudaMemcpy #define gpuMalloc cudaMalloc +#define gpuMemset cudaMemset #define gpFree cudaFree #define gpuMemCpyAsync cudaMemcpyAsync #define gpuMemcpyHostToDevice cudaMemcpyHostToDevice @@ -55,6 +57,7 @@ #elif 
defined(PLATFORM_HIP_DEVICE) #define gpuMemcpy hipMemcpy #define gpuMalloc hipMalloc +#define gpuMemset hipMemset #define gpuFree hipFree #define gpuMemcpyAsync hipMemcpyAsync #define gpuMemcpyHostToDevice hipMemcpyHostToDevice @@ -129,6 +132,7 @@ #define gpuLaunchConfig_t hipLaunchConfig_t #define gpuLaunchAttribute hipLaunchAttribute #endif + // CUDA error checking macro (replaces FLASHINFER_CUDA_CALL) #define FI_GPU_CALL(call) \ do { \ @@ -140,3 +144,18 @@ throw std::runtime_error(err_msg.str()); \ } \ } while (0) + +inline int getMaxSharedMemPerMultiprocessor(int dev_id) +{ + int max_smem_per_sm = 0; +#if defined(PLATFORM_CUDA_DEVICE) + FI_GPU_CALL(gpuDeviceGetAttribute( + &max_smem_per_sm, gpuDevAttrMaxSharedMemoryPerMultiProcessor, dev_id)); +#elif defined(PLATFORM_HIP_DEVICE) + hipDeviceProp_t deviceProp; + FI_GPU_CALL(hipGetDeviceProperties(&deviceProp, dev_id)); + max_smem_per_sm = deviceProp.sharedMemPerMultiprocessor; +#endif + + return max_smem_per_sm; +} diff --git a/libflashinfer/include/gpu_iface/macros.hpp b/libflashinfer/include/gpu_iface/macros.hpp index 91295d35e1..5c26acab77 100644 --- a/libflashinfer/include/gpu_iface/macros.hpp +++ b/libflashinfer/include/gpu_iface/macros.hpp @@ -12,7 +12,7 @@ #define PLATFORM_HIP_DEVICE // FIXME: Temporarily setting __forceinline__ to inline as amdclang++ 6.4 throws // an error when __forceinline__ is used. 
-// #define __forceinline__ inline +#define __forceinline__ inline #define __grid_constant__ #elif defined(__CUDACC__) || defined(__CUDA_ARCH__) #define PLATFORM_CUDA_DEVICE diff --git a/libflashinfer/include/gpu_iface/vec_dtypes.hpp b/libflashinfer/include/gpu_iface/vec_dtypes.hpp index 302f42f9ad..3a92de2c05 100644 --- a/libflashinfer/include/gpu_iface/vec_dtypes.hpp +++ b/libflashinfer/include/gpu_iface/vec_dtypes.hpp @@ -26,6 +26,7 @@ namespace detail_t = flashinfer::gpu_iface::vec_dtypes::detail::hip; // Re-export types and functions from the appropriate backend // This allows code to use flashinfer::gpu_iface::vec_dtypes::vec_t +using detail_t::vec_cast; using detail_t::vec_t; } // namespace vec_dtypes From 3e76b2c6bf60b18874197fa7205c806b173f0fbc Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Mon, 4 Aug 2025 14:34:42 -0400 Subject: [PATCH 010/109] Updated standalone example --- examples/cpp/standalone_single_prefill.cu | 357 ++++++++++++---------- 1 file changed, 201 insertions(+), 156 deletions(-) diff --git a/examples/cpp/standalone_single_prefill.cu b/examples/cpp/standalone_single_prefill.cu index d35c290ab6..345f87f401 100644 --- a/examples/cpp/standalone_single_prefill.cu +++ b/examples/cpp/standalone_single_prefill.cu @@ -1,27 +1,25 @@ -#include -#include -#include -#include - #include #include #include #include -#include -#include -#include -#include -#include -#include -#include -#include -#include #include #include #include +#include #include #include +// GPU interface headers +#include +#include +#include +#include +#include +#include +#include +#include +#include + namespace flashinfer { @@ -199,24 +197,49 @@ std::vector single_mha(const std::vector &q, } // namespace reference -// Function to validate GPU results against CPU reference -bool validate_results(const thrust::host_vector &gpu_output, +// Helper function to generate random data (without Thrust) +void generate_random_data(half *data, + size_t size, + float min_val = -1.0f, + float 
max_val = 1.0f) +{ + std::vector host_data(size); + std::mt19937 rng(42); // Fixed seed for reproducibility + std::uniform_real_distribution dist(min_val, max_val); + + for (size_t i = 0; i < size; ++i) { + host_data[i] = static_cast(dist(rng)); + } + + // Copy to device + FI_GPU_CALL(gpuMemcpy(data, host_data.data(), size * sizeof(half), + gpuMemcpyHostToDevice)); +} + +// Function to validate GPU results against CPU reference (simplified) +bool validate_results(const half *gpu_output, + size_t gpu_size, const std::vector &cpu_output, float rtol = 1e-3f, float atol = 1e-3f) { - if (gpu_output.size() != cpu_output.size()) { - std::cerr << "Size mismatch: GPU=" << gpu_output.size() + if (gpu_size != cpu_output.size()) { + std::cerr << "Size mismatch: GPU=" << gpu_size << " vs CPU=" << cpu_output.size() << std::endl; return false; } + // Copy GPU data to host for comparison + std::vector host_output(gpu_size); + FI_GPU_CALL(gpuMemcpy(host_output.data(), gpu_output, + gpu_size * sizeof(half), gpuMemcpyDeviceToHost)); + int errors = 0; float max_diff = 0.0f; float max_rel_diff = 0.0f; - for (size_t i = 0; i < gpu_output.size(); ++i) { - float gpu_val = static_cast(gpu_output[i]); + for (size_t i = 0; i < gpu_size; ++i) { + float gpu_val = static_cast(host_output[i]); float cpu_val = static_cast(cpu_output[i]); float abs_diff = std::abs(gpu_val - cpu_val); float rel_diff = @@ -236,17 +259,16 @@ bool validate_results(const thrust::host_vector &gpu_output, } } - float error_rate = static_cast(errors) / gpu_output.size(); + float error_rate = static_cast(errors) / gpu_size; std::cout << "\nValidation Results:" << std::endl; std::cout << " Max absolute difference: " << max_diff << std::endl; std::cout << " Max relative difference: " << max_rel_diff << std::endl; std::cout << " Error rate: " << (error_rate * 100) << "% (" << errors - << " / " << gpu_output.size() << " elements)" << std::endl; + << " / " << gpu_size << " elements)" << std::endl; std::cout << " Status: " << 
(error_rate < 0.05 ? "PASSED" : "FAILED") << std::endl; - // Allow up to 5% error rate (similar to the threshold used in the unit - // tests) + // Allow up to 5% error rate return error_rate < 0.05; } @@ -300,23 +322,6 @@ public: } }; -// Helper function to generate random data on device -void generate_random_data(thrust::device_vector &data, - float min_val = -1.0f, - float max_val = 1.0f) -{ - thrust::host_vector host_data(data.size()); - - thrust::default_random_engine rng(42); // Fixed seed for reproducibility - thrust::uniform_real_distribution dist(min_val, max_val); - - for (size_t i = 0; i < host_data.size(); ++i) { - host_data[i] = static_cast(dist(rng)); - } - - data = host_data; -} - // Dispatch function for half precision gpuError_t dispatch_single_prefill(half *q_ptr, half *k_ptr, @@ -499,18 +504,26 @@ void print_usage(const char *program_name) "(default: 0)\n"; } +// Main function with simplified memory management int main(int argc, char *argv[]) { - // Default parameter values + if (argc > 1 && + (std::string(argv[1]) == "--help" || std::string(argv[1]) == "-h")) + { + print_usage(argv[0]); + return 0; + } + + // Process parameter pairs (--param value) uint32_t qo_len = 512; uint32_t kv_len = 512; uint32_t num_qo_heads = 32; uint32_t num_kv_heads = 32; uint32_t head_dim = 128; - bool causal = true; - bool use_fp16_qk_reduction = false; QKVLayout kv_layout = QKVLayout::kNHD; PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone; + bool causal = true; + bool use_fp16_qk_reduction = false; int32_t window_left = -1; float rope_scale = 1.0f; float rope_theta = 10000.0f; @@ -518,102 +531,131 @@ int main(int argc, char *argv[]) int warmup = 5; bool validate = false; - // Parse command-line arguments - for (int i = 1; i < argc; i++) { + for (int i = 1; i < argc; i += 2) { std::string arg = argv[i]; - if (arg == "--qo_len" && i + 1 < argc) - qo_len = std::atoi(argv[++i]); - else if (arg == "--kv_len" && i + 1 < argc) - kv_len = std::atoi(argv[++i]); - 
else if (arg == "--num_qo_heads" && i + 1 < argc) - num_qo_heads = std::atoi(argv[++i]); - else if (arg == "--num_kv_heads" && i + 1 < argc) - num_kv_heads = std::atoi(argv[++i]); - else if (arg == "--head_dim" && i + 1 < argc) - head_dim = std::atoi(argv[++i]); - else if (arg == "--causal" && i + 1 < argc) - causal = ArgParser::get_bool(argv[++i], true); - else if (arg == "--use_fp16_qk" && i + 1 < argc) - use_fp16_qk_reduction = ArgParser::get_bool(argv[++i], false); - else if (arg == "--layout" && i + 1 < argc) - kv_layout = ArgParser::get_layout(argv[++i]); - else if (arg == "--pos_encoding" && i + 1 < argc) - pos_encoding_mode = ArgParser::get_pos_encoding_mode(argv[++i]); - else if (arg == "--window_left" && i + 1 < argc) - window_left = std::atoi(argv[++i]); - else if (arg == "--rope_scale" && i + 1 < argc) - rope_scale = std::atof(argv[++i]); - else if (arg == "--rope_theta" && i + 1 < argc) - rope_theta = std::atof(argv[++i]); - else if (arg == "--iterations" && i + 1 < argc) - iterations = std::atoi(argv[++i]); - else if (arg == "--warmup" && i + 1 < argc) - warmup = std::atoi(argv[++i]); - else if (arg == "--validate" && i + 1 < argc) - validate = ArgParser::get_bool(argv[++i], false); - else if (arg == "--help") { + if (i + 1 >= argc && arg != "--help" && arg != "-h") { + std::cerr << "Missing value for parameter " << arg << std::endl; print_usage(argv[0]); - return 0; + return 1; } - } - // Verify that num_qo_heads is divisible by num_kv_heads - if (num_qo_heads % num_kv_heads != 0) { - std::cerr << "Error: num_qo_heads must be divisible by num_kv_heads" - << std::endl; - return 1; + if (arg == "--qo_len") { + qo_len = ArgParser::get_int(argv[i + 1], 512); + } + else if (arg == "--kv_len") { + kv_len = ArgParser::get_int(argv[i + 1], 512); + } + else if (arg == "--num_qo_heads") { + num_qo_heads = ArgParser::get_int(argv[i + 1], 32); + } + else if (arg == "--num_kv_heads") { + num_kv_heads = ArgParser::get_int(argv[i + 1], 32); + } + else if (arg == 
"--head_dim") { + head_dim = ArgParser::get_int(argv[i + 1], 128); + } + else if (arg == "--layout") { + kv_layout = ArgParser::get_layout(argv[i + 1]); + } + else if (arg == "--pos_encoding") { + pos_encoding_mode = ArgParser::get_pos_encoding_mode(argv[i + 1]); + } + else if (arg == "--causal") { + causal = ArgParser::get_bool(argv[i + 1], true); + } + else if (arg == "--use_fp16_qk") { + use_fp16_qk_reduction = ArgParser::get_bool(argv[i + 1], false); + } + else if (arg == "--window_left") { + window_left = ArgParser::get_int(argv[i + 1], -1); + } + else if (arg == "--rope_scale") { + rope_scale = ArgParser::get_float(argv[i + 1], 1.0f); + } + else if (arg == "--rope_theta") { + rope_theta = ArgParser::get_float(argv[i + 1], 10000.0f); + } + else if (arg == "--iterations") { + iterations = ArgParser::get_int(argv[i + 1], 10); + } + else if (arg == "--warmup") { + warmup = ArgParser::get_int(argv[i + 1], 5); + } + else if (arg == "--validate") { + validate = ArgParser::get_bool(argv[i + 1], false); + } + else { + std::cerr << "Unknown parameter: " << arg << std::endl; + print_usage(argv[0]); + return 1; + } } - // Display configuration - std::cout << "Configuration:" << std::endl; - std::cout << " qo_len = " << qo_len << std::endl; - std::cout << " kv_len = " << kv_len << std::endl; - std::cout << " num_qo_heads = " << num_qo_heads << std::endl; - std::cout << " num_kv_heads = " << num_kv_heads << std::endl; - std::cout << " head_dim = " << head_dim << std::endl; - std::cout << " kv_layout = " - << (kv_layout == QKVLayout::kNHD ? "NHD" : "HND") << std::endl; - std::cout << " causal = " << (causal ? "true" : "false") << std::endl; - std::cout << " data_type = half" << std::endl; - std::cout << " use_fp16_qk_reduction = " - << (use_fp16_qk_reduction ? "true" : "false") << std::endl; - std::cout << " validate = " << (validate ? 
"true" : "false") << std::endl; - - // Initialize and create stream + // Print configuration + std::cout << "Configuration:" << std::endl + << " QO Length: " << qo_len << std::endl + << " KV Length: " << kv_len << std::endl + << " QO Heads: " << num_qo_heads << std::endl + << " KV Heads: " << num_kv_heads << std::endl + << " Head Dimension: " << head_dim << std::endl + << " KV Layout: " + << (kv_layout == QKVLayout::kNHD ? "NHD" : "HND") << std::endl + << " Position Encoding: " + << (pos_encoding_mode == PosEncodingMode::kNone ? "None" + : pos_encoding_mode == PosEncodingMode::kRoPELlama ? "RoPE" + : "ALiBi") + << std::endl + << " Causal: " << (causal ? "Yes" : "No") << std::endl + << " Use FP16 QK Reduction: " + << (use_fp16_qk_reduction ? "Yes" : "No") << std::endl + << " Window Left: " << window_left << std::endl + << " RoPE Scale: " << rope_scale << std::endl + << " RoPE Theta: " << rope_theta << std::endl + << " Iterations: " << iterations << std::endl + << " Warmup: " << warmup << std::endl + << " Validation: " << (validate ? 
"Yes" : "No") << std::endl; + + // Create stream gpuStream_t stream; - gpuStreamCreate(&stream); - - // Allocate device memory using Thrust - only for half precision - thrust::device_vector q(qo_len * num_qo_heads * head_dim); - thrust::device_vector k(kv_len * num_kv_heads * head_dim); - thrust::device_vector v(kv_len * num_kv_heads * head_dim); - thrust::device_vector o(qo_len * num_qo_heads * head_dim); - thrust::device_vector tmp(qo_len * num_qo_heads * head_dim); - thrust::device_vector lse(qo_len * num_qo_heads); - - // Generate random data - generate_random_data(q); - generate_random_data(k); - generate_random_data(v); - thrust::fill(o.begin(), o.end(), half(0.0f)); - thrust::fill(tmp.begin(), tmp.end(), half(0.0f)); - thrust::fill(lse.begin(), lse.end(), 0.0f); - - // Calculate SM scale if not provided + FI_GPU_CALL(gpuStreamCreate(&stream)); + + // Allocate device memory using gpuMalloc instead of Thrust + half *q_dev, *k_dev, *v_dev, *o_dev, *tmp_dev; + float *lse_dev; + + size_t q_size = qo_len * num_qo_heads * head_dim; + size_t k_size = kv_len * num_kv_heads * head_dim; + size_t v_size = kv_len * num_kv_heads * head_dim; + size_t o_size = qo_len * num_qo_heads * head_dim; + size_t lse_size = qo_len * num_qo_heads; + + FI_GPU_CALL(gpuMalloc(&q_dev, q_size * sizeof(half))); + FI_GPU_CALL(gpuMalloc(&k_dev, k_size * sizeof(half))); + FI_GPU_CALL(gpuMalloc(&v_dev, v_size * sizeof(half))); + FI_GPU_CALL(gpuMalloc(&o_dev, o_size * sizeof(half))); + FI_GPU_CALL(gpuMalloc(&tmp_dev, o_size * sizeof(half))); + FI_GPU_CALL(gpuMalloc(&lse_dev, lse_size * sizeof(float))); + + // Initialize data + generate_random_data(q_dev, q_size); + generate_random_data(k_dev, k_size); + generate_random_data(v_dev, v_size); + + // Zero out output arrays + FI_GPU_CALL(gpuMemset(o_dev, 0, o_size * sizeof(half))); + FI_GPU_CALL(gpuMemset(tmp_dev, 0, o_size * sizeof(half))); + FI_GPU_CALL(gpuMemset(lse_dev, 0, lse_size * sizeof(float))); + + // Calculate SM scale float sm_scale = 1.0f 
/ std::sqrt(static_cast(head_dim)); - // Warm-up runs + // Warmup runs for (int i = 0; i < warmup; ++i) { gpuError_t status = dispatch_single_prefill( - thrust::raw_pointer_cast(q.data()), - thrust::raw_pointer_cast(k.data()), - thrust::raw_pointer_cast(v.data()), - thrust::raw_pointer_cast(o.data()), - thrust::raw_pointer_cast(tmp.data()), - thrust::raw_pointer_cast(lse.data()), num_qo_heads, num_kv_heads, - qo_len, kv_len, head_dim, kv_layout, pos_encoding_mode, causal, - use_fp16_qk_reduction, sm_scale, window_left, rope_scale, - rope_theta, stream); + q_dev, k_dev, v_dev, o_dev, tmp_dev, lse_dev, num_qo_heads, + num_kv_heads, qo_len, kv_len, head_dim, kv_layout, + pos_encoding_mode, causal, use_fp16_qk_reduction, sm_scale, + window_left, rope_scale, rope_theta, stream); if (status != gpuSuccess) { std::cerr << "Error during warmup: " << gpuGetErrorString(status) @@ -624,22 +666,17 @@ int main(int argc, char *argv[]) // Timing runs gpuEvent_t start, stop; - gpuEventCreate(&start); - gpuEventCreate(&stop); + FI_GPU_CALL(gpuEventCreate(&start)); + FI_GPU_CALL(gpuEventCreate(&stop)); - gpuEventRecord(start, stream); + FI_GPU_CALL(gpuEventRecord(start, stream)); for (int i = 0; i < iterations; ++i) { gpuError_t status = dispatch_single_prefill( - thrust::raw_pointer_cast(q.data()), - thrust::raw_pointer_cast(k.data()), - thrust::raw_pointer_cast(v.data()), - thrust::raw_pointer_cast(o.data()), - thrust::raw_pointer_cast(tmp.data()), - thrust::raw_pointer_cast(lse.data()), num_qo_heads, num_kv_heads, - qo_len, kv_len, head_dim, kv_layout, pos_encoding_mode, causal, - use_fp16_qk_reduction, sm_scale, window_left, rope_scale, - rope_theta, stream); + q_dev, k_dev, v_dev, o_dev, tmp_dev, lse_dev, num_qo_heads, + num_kv_heads, qo_len, kv_len, head_dim, kv_layout, + pos_encoding_mode, causal, use_fp16_qk_reduction, sm_scale, + window_left, rope_scale, rope_theta, stream); if (status != gpuSuccess) { std::cerr << "Error during benchmark: " << gpuGetErrorString(status) @@ 
-648,14 +685,14 @@ int main(int argc, char *argv[]) } } - gpuEventRecord(stop, stream); - gpuEventSynchronize(stop); + FI_GPU_CALL(gpuEventRecord(stop, stream)); + FI_GPU_CALL(gpuEventSynchronize(stop)); float elapsed_ms; - gpuEventElapsedTime(&elapsed_ms, start, stop); + FI_GPU_CALL(gpuEventElapsedTime(&elapsed_ms, start, stop)); float avg_ms = elapsed_ms / iterations; - // Calculate FLOPS + // Calculate and report performance double flops = calculate_flops(qo_len, kv_len, num_qo_heads, head_dim, causal); double tflops = flops / (avg_ms * 1e-3) / 1e12; @@ -670,13 +707,14 @@ int main(int argc, char *argv[]) if (validate) { std::cout << "\nRunning validation..." << std::endl; - // Copy output from GPU to host for validation - thrust::host_vector h_output = o; - - // Create input data on host for CPU reference - std::vector h_q(q.begin(), q.end()); - std::vector h_k(k.begin(), k.end()); - std::vector h_v(v.begin(), v.end()); + // Copy input data to host for CPU reference + std::vector h_q(q_size), h_k(k_size), h_v(v_size); + FI_GPU_CALL(gpuMemcpy(h_q.data(), q_dev, q_size * sizeof(half), + gpuMemcpyDeviceToHost)); + FI_GPU_CALL(gpuMemcpy(h_k.data(), k_dev, k_size * sizeof(half), + gpuMemcpyDeviceToHost)); + FI_GPU_CALL(gpuMemcpy(h_v.data(), v_dev, v_size * sizeof(half), + gpuMemcpyDeviceToHost)); // Compute reference output on CPU std::vector ref_output = reference::single_mha( @@ -684,16 +722,23 @@ int main(int argc, char *argv[]) causal, kv_layout, pos_encoding_mode, rope_scale, rope_theta); // Validate results - bool validation_passed = validate_results(h_output, ref_output); + bool validation_passed = validate_results(o_dev, o_size, ref_output); // Report validation status std::cout << "Validation " << (validation_passed ? 
"PASSED" : "FAILED") << std::endl; } - gpuEventDestroy(start); - gpuEventDestroy(stop); - gpuStreamDestroy(stream); + // Cleanup + FI_GPU_CALL(gpuEventDestroy(start)); + FI_GPU_CALL(gpuEventDestroy(stop)); + FI_GPU_CALL(gpuStreamDestroy(stream)); + FI_GPU_CALL(gpuFree(q_dev)); + FI_GPU_CALL(gpuFree(k_dev)); + FI_GPU_CALL(gpuFree(v_dev)); + FI_GPU_CALL(gpuFree(o_dev)); + FI_GPU_CALL(gpuFree(tmp_dev)); + FI_GPU_CALL(gpuFree(lse_dev)); return 0; } From 0ed875a82a09ee3dba519d6007cac6da6db3e726 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Mon, 4 Aug 2025 21:36:32 -0400 Subject: [PATCH 011/109] Updated load_q_global_smem. --- .../flashinfer/attention/generic/prefill.cuh | 68 +++++++++++-------- 1 file changed, 39 insertions(+), 29 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index 6cfec1375d..78a3035ec7 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -39,10 +39,16 @@ using gpu_iface::vec_dtypes::vec_cast; // using mma::MMAMode; constexpr uint32_t WARP_SIZE = gpu_iface::kWarpSize; + +// Defines thread layouts that are specific to AMD CDNA3 or Nvidia warp sizes. 
#if defined(PLATFORM_HIP_DEVICE) -constexpr uint32_t WARP_STEP_SIZE = 16; +constexpr uint32_t WARP_THREAD_COLS = 16; +constexpr uint32_t WARP_THREAD_ROWS = 4; +constexpr uint32_t HP_QUERY_ELEMS_PER_THREAD = 4; #else -constexpr uint32_t WARP_STEP_SIZE = 8; // NVIDIA +constexpr uint32_t WARP_THREAD_COLS = 8; +constexpr uint32_t WARP_THREAD_ROWS = 4; +constexpr uint32_t HP_QUERY_ELEMS_PER_THREAD = 8; #endif constexpr uint32_t get_num_warps_q(const uint32_t cta_tile_q) @@ -365,8 +371,7 @@ produce_kv(smem_t smem, lane_idx = tid.x; if constexpr (KTraits::SWIZZLE_MODE_KV == SwizzleMode::k128B) { - uint32_t kv_idx = - kv_idx_base + warp_idx * 4 + lane_idx / WARP_STEP_SIZE; + uint32_t kv_idx = kv_idx_base + warp_idx * 4 + lane_idx / 8; // NOTE: NUM_MMA_KV * 4 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 4 / // num_warps static_assert(NUM_MMA_KV * 4 % NUM_WARPS_Q == 0); @@ -377,9 +382,8 @@ produce_kv(smem_t smem, smem.template load_128b_async(*smem_offset, *gptr, kv_idx < kv_len); *smem_offset = - smem.template advance_offset_by_column( - *smem_offset, j); - *gptr += WARP_STEP_SIZE * upcast_size(); + smem.template advance_offset_by_column<8>(*smem_offset, j); + *gptr += 8 * upcast_size(); } kv_idx += NUM_WARPS * 4; *smem_offset = smem.template advance_offset_by_row smem, *smem_offset -= CTA_TILE_KV * UPCAST_STRIDE; } else { -#if defined(PLATFORM_HIP_DEVICE) - static_assert(false, - "SwizzleMode::k64B is not supported on AMD/CDNA3."); -#else uint32_t kv_idx = kv_idx_base + warp_idx * 8 + lane_idx / 4; // NOTE: NUM_MMA_KV * 2 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 2 / // num_warps @@ -411,7 +411,6 @@ produce_kv(smem_t smem, *gptr += NUM_WARPS * 8 * stride_n; } *smem_offset -= KTraits::CTA_TILE_KV * UPCAST_STRIDE; -#endif } } @@ -552,47 +551,58 @@ load_q_global_smem(uint32_t packed_offset, const dim3 tid = threadIdx) { using DTypeQ = typename KTraits::DTypeQ; +#if defined(PLATFORM_HIP_DEVICE) + constexpr uint32_t UPCAST_STRIDE_Q = + KTraits::HEAD_DIM_QK / 
HP_QUERY_ELEMS_PER_THREAD; + constexpr uint32_t COLUMN_RESET_OFFSET = + (NUM_MMA_D_QK / 4) * WARP_THREAD_COLS; +#else constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q; + constexpr uint32_t COLUMN_RESET_OFFSET = 2 * KTraits::NUM_MMA_D_QK; +#endif + const uint32_t lane_idx = tid.x, warp_idx_x = get_warp_idx_q(tid.y); + uint32_t row = tid / WARP_THREAD_COLS; + uint32_t col = tid % WARP_THREAD_COLS; if (get_warp_idx_kv(tid.z) == 0) { - uint32_t row = warp_idx_x * KTraits::NUM_MMA_Q * WARP_STEP_SIZE + - lane_idx / WARP_STEP_SIZE; - uint32_t col = lane_idx % WARP_STEP_SIZE; uint32_t q_smem_offset_w = - q_smem->template get_permuted_offset(row, col); + q_smem->template get_permuted_offset( + warp_idx_x * KTraits::NUM_MMA_Q * 16 + row, col); #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t j = 0; j < 2 * 2; ++j) { uint32_t q, r; - group_size.divmod(packed_offset + lane_idx / WARP_STEP_SIZE + - mma_q * 16 + j * 4, - q, r); + group_size.divmod(packed_offset + row + mma_q * 16 + j * 4, q, + r); const uint32_t q_idx = q; - DTypeQ *q_ptr = - q_ptr_base + q * q_stride_n + r * q_stride_h + - (lane_idx % WARP_STEP_SIZE) * upcast_size(); + DTypeQ *q_ptr = q_ptr_base + q * q_stride_n + r * q_stride_h + + col * upcast_size(); #pragma unroll for (uint32_t mma_do = 0; mma_do < KTraits::NUM_MMA_D_QK / 4; ++mma_do) { +#if defined(PLATFORM_HIP_DEVICE) // load q fragment from gmem to smem q_smem ->template load_128b_async( q_smem_offset_w, q_ptr, q_idx < qo_upper_bound); - q_smem_offset_w = - q_smem - ->template advance_offset_by_column( - q_smem_offset_w, mma_do); - q_ptr += WARP_STEP_SIZE * upcast_size(); +#else + q_smem->template load_64b_async( + q_smem_offset_w, q_ptr, q_idx < qo_upper_bound); +#endif + q_smem_offset_w = q_smem->template advance_offset_by_column< + WARP_THREAD_COLS>(q_smem_offset_w, mma_do); + q_ptr += HP_QUERY_ELEMS_PER_THREAD * upcast_size(); } q_smem_offset_w = - q_smem->template 
advance_offset_by_row<4, UPCAST_STRIDE_Q>( + q_smem->template advance_offset_by_row( q_smem_offset_w) - - 2 * KTraits::NUM_MMA_D_QK; + COLUMN_RESET_OFFSET; } } } From b7621c6e48f1c9a0cff2181e952e4b68f58a5547 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Tue, 5 Aug 2025 00:19:13 -0400 Subject: [PATCH 012/109] Port produce_kv to HIP --- .../flashinfer/attention/generic/prefill.cuh | 84 ++++++++++++++----- 1 file changed, 61 insertions(+), 23 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index 78a3035ec7..3181ba1a6a 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -335,6 +335,59 @@ q_frag_apply_llama_rope_with_pos(T *x_first_half, } } +template +__device__ __forceinline__ void produce_kv_helper_(uint32_t warp_idx, + uint32_t lane_idx) +{ + using DTypeKV = typename KTraits::DTypeKV; + constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; + constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q; + constexpr uint32_t NUM_MMA_D = + produce_v ? KTraits::NUM_MMA_D_VO : KTraits::NUM_MMA_D_QK; + +#if defined(PLATFORM_HIP_DEVICE) + constexpr uint32_t HEAD_DIM = + produce_v ? KTraits::HEAD_DIM_QK : KTraits::HEAD_DIM_VO; + constexpr uint32_t UPCAST_STRIDE = HEAD_DIM / HP_QUERY_ELEMS_PER_THREAD; + constexpr uint32_t COLUMN_RESET_OFFSET = (NUM_MMA_D / 4) * WARP_THREAD_COLS; +#else + constexpr uint32_t UPCAST_STRIDE = + produce_v ? 
KTraits::UPCAST_STRIDE_V : KTraits::UPCAST_STRIDE_K; + constexpr uint32_t COLUMN_RESET_OFFSET = sizeof(DTypeKV) * NUM_MMA_D; +#endif + + uint32_t row = lane_idx / WARP_THREAD_COLS; + uint32_t col = lane_idx % WARP_THREAD_COLS; + uint32_t kv_idx = kv_idx_base + warp_idx * WARP_THREAD_ROWS + row; + // NOTE: NUM_MMA_KV*4/NUM_WARPS_Q = NUM_WARPS_KV*NUM_MMA_KV*4/num_warps + static_assert(NUM_MMA_KV * 4 % NUM_WARPS_Q == 0); +#pragma unroll + for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) { +#pragma unroll + for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(DTypeKV)); ++j) { +#if defined(PLATFORM_HIP_DEVICE) + smem.template load_64b_async(*smem_offset, *gptr, + kv_idx < kv_len); +#else + smem.template load_128b_async(*smem_offset, *gptr, + kv_idx < kv_len); +#endif + *smem_offset = + smem.template advance_offset_by_column( + *smem_offset, j); + *gptr += 8 * upcast_size(); + } + kv_idx += NUM_WARPS * WARP_THREAD_ROWS; + *smem_offset = + smem.template advance_offset_by_row(*smem_offset) - + COLUMN_RESET_OFFSET; + *gptr += NUM_WARPS * WARP_THREAD_ROWS * stride_n - + sizeof(DTypeKV) * NUM_MMA_D * upcast_size(); + } + *smem_offset -= KTraits::CTA_TILE_KV * UPCAST_STRIDE; +} + /*! * \brief Produce k/v fragments from global memory to shared memory. * \tparam fill_mode The fill mode of the shared memory. 
@@ -371,30 +424,15 @@ produce_kv(smem_t smem, lane_idx = tid.x; if constexpr (KTraits::SWIZZLE_MODE_KV == SwizzleMode::k128B) { - uint32_t kv_idx = kv_idx_base + warp_idx * 4 + lane_idx / 8; - // NOTE: NUM_MMA_KV * 4 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 4 / - // num_warps - static_assert(NUM_MMA_KV * 4 % NUM_WARPS_Q == 0); -#pragma unroll - for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) { -#pragma unroll - for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(DTypeKV)); ++j) { - smem.template load_128b_async(*smem_offset, *gptr, - kv_idx < kv_len); - *smem_offset = - smem.template advance_offset_by_column<8>(*smem_offset, j); - *gptr += 8 * upcast_size(); - } - kv_idx += NUM_WARPS * 4; - *smem_offset = smem.template advance_offset_by_row( - *smem_offset) - - sizeof(DTypeKV) * NUM_MMA_D; - *gptr += NUM_WARPS * 4 * stride_n - - sizeof(DTypeKV) * NUM_MMA_D * upcast_size(); - } - *smem_offset -= CTA_TILE_KV * UPCAST_STRIDE; + produce_kv_helper_(warp_idx, + lane_idx); } +#if defined(PLATFORM_HIP_DEVICE) + else if constexpr (KTraits::SWIZZLE_MODE_KV == SwizzleMode::kLinear) { + produce_kv_helper_(warp_idx, + lane_idx); + } +#endif else { uint32_t kv_idx = kv_idx_base + warp_idx * 8 + lane_idx / 4; // NOTE: NUM_MMA_KV * 2 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 2 / From a29fe258c704c5e42754ffb24b4e1af164496956 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Tue, 5 Aug 2025 02:30:26 -0400 Subject: [PATCH 013/109] Update KernelTraits --- .../attention/generic/permuted_smem.cuh | 8 +- .../flashinfer/attention/generic/prefill.cuh | 89 +++++++++++-------- 2 files changed, 55 insertions(+), 42 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh b/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh index d5531b4333..ec94717da4 100644 --- a/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh +++ 
b/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh @@ -37,15 +37,15 @@ using b64_t = uint2; * \brief Compute the number of elements that can be stored in a b128_t. * \tparam T The data type of the elements. */ -template +template constexpr __host__ __device__ __forceinline__ uint32_t upcast_size() { - static_assert(NumBits == 128 || NumBits == 64, + static_assert(VectorWidthBits == 128 || VectorWidthBits == 64, "Only 64 and 128 bits are supported"); - if constexpr (NumBits == 128) { + if constexpr (VectorWidthBits == 128) { return sizeof(b128_t) / sizeof(T); } - else if constexpr (NumBits == 64) { + else if constexpr (VectorWidthBits == 64) { return sizeof(b64_t) / sizeof(T); } } diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index 3181ba1a6a..b5590f0ce3 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -40,17 +40,6 @@ using gpu_iface::vec_dtypes::vec_cast; constexpr uint32_t WARP_SIZE = gpu_iface::kWarpSize; -// Defines thread layouts that are specific to AMD CDNA3 or Nvidia warp sizes. 
-#if defined(PLATFORM_HIP_DEVICE) -constexpr uint32_t WARP_THREAD_COLS = 16; -constexpr uint32_t WARP_THREAD_ROWS = 4; -constexpr uint32_t HP_QUERY_ELEMS_PER_THREAD = 4; -#else -constexpr uint32_t WARP_THREAD_COLS = 8; -constexpr uint32_t WARP_THREAD_ROWS = 4; -constexpr uint32_t HP_QUERY_ELEMS_PER_THREAD = 8; -#endif - constexpr uint32_t get_num_warps_q(const uint32_t cta_tile_q) { if (cta_tile_q > 16) { @@ -133,37 +122,61 @@ struct KernelTraits static constexpr uint32_t NUM_MMA_D_VO = NUM_MMA_D_VO_; static constexpr uint32_t NUM_WARPS_Q = NUM_WARPS_Q_; static constexpr uint32_t NUM_WARPS_KV = NUM_WARPS_KV_; - static constexpr uint32_t NUM_THREADS = - NUM_WARPS_Q * NUM_WARPS_KV * WARP_SIZE; static constexpr uint32_t NUM_WARPS = NUM_WARPS_Q * NUM_WARPS_KV; static constexpr uint32_t HEAD_DIM_QK = NUM_MMA_D_QK * 16; static constexpr uint32_t HEAD_DIM_VO = NUM_MMA_D_VO * 16; - static constexpr uint32_t UPCAST_STRIDE_Q = - HEAD_DIM_QK / upcast_size(); - static constexpr uint32_t UPCAST_STRIDE_K = - HEAD_DIM_QK / upcast_size(); - static constexpr uint32_t UPCAST_STRIDE_V = - HEAD_DIM_VO / upcast_size(); - static constexpr uint32_t UPCAST_STRIDE_O = - HEAD_DIM_VO / upcast_size(); static constexpr uint32_t CTA_TILE_Q = CTA_TILE_Q_; static constexpr uint32_t CTA_TILE_KV = NUM_MMA_KV * NUM_WARPS_KV * 16; + static constexpr PosEncodingMode POS_ENCODING_MODE = POS_ENCODING_MODE_; + + using DTypeQ = DTypeQ_; + using DTypeKV = DTypeKV_; + using DTypeO = DTypeO_; + using DTypeQKAccum = DTypeQKAccum_; + using IdType = IdType_; + using AttentionVariant = AttentionVariant_; + +#if defined(PLATFORM_HIP_DEVICE) + static_assert( + sizeof(DTypeKV_) != 1, + "8-bit types not supported for CDNA3"); + static constexpr uint32_t NUM_THREADS = NUM_WARPS_Q * NUM_WARPS_KV * 64; + constexpr uint32_t WARP_THREAD_COLS = 16; + constexpr uint32_t WARP_THREAD_ROWS = 4; + constexpr uint32_t HALF_ELEMS_PER_THREAD = 4; + constexpr uint32_t VECTOR_BIT_WIDTH = HALF_ELEMS_PER_THREAD * 16; + // FIXME: Update 
with a proper swizzle pattern. Linear is used primarily + // for intial testing. + static constexpr SwizzleMode SWIZZLE_MODE_Q = SwizzleMode::kLinear; + static constexpr SwizzleMode SWIZZLE_MODE_KV = SwizzleMode::kLinear; + // Presently we use 16x4 thread layout for all cases. + static constexpr uint32_t KV_THR_LAYOUT_ROW = WARP_THREAD_ROWS; + static constexpr uint32_t KV_THR_LAYOUT_COL = WARP_THREAD_COLS; +#else + static constexpr uint32_t NUM_THREADS = NUM_WARPS_Q * NUM_WARPS_KV * 32; + constexpr uint32_t WARP_THREAD_COLS = 8; + constexpr uint32_t WARP_THREAD_ROWS = 4; + constexpr uint32_t HALF_ELEMS_PER_THREAD = 8; + constexpr uint32_t VECTOR_BIT_WIDTH = HALF_ELEMS_PER_THREAD * 16; static constexpr SwizzleMode SWIZZLE_MODE_Q = SwizzleMode::k128B; static constexpr SwizzleMode SWIZZLE_MODE_KV = (sizeof(DTypeKV_) == 1 && HEAD_DIM_VO == 64) ? SwizzleMode::k64B : SwizzleMode::k128B; static constexpr uint32_t KV_THR_LAYOUT_ROW = - SWIZZLE_MODE_KV == SwizzleMode::k128B ? 4 : 8; + SWIZZLE_MODE_KV == SwizzleMode::k128B ? WARP_THREAD_ROWS + : WARP_THREAD_COLS; static constexpr uint32_t KV_THR_LAYOUT_COL = SWIZZLE_MODE_KV == SwizzleMode::k128B ? 
8 : 4; - static constexpr PosEncodingMode POS_ENCODING_MODE = POS_ENCODING_MODE_; - using DTypeQ = DTypeQ_; - using DTypeKV = DTypeKV_; - using DTypeO = DTypeO_; - using DTypeQKAccum = DTypeQKAccum_; - using IdType = IdType_; - using AttentionVariant = AttentionVariant_; +#endif + static constexpr uint32_t UPCAST_STRIDE_Q = + HEAD_DIM_QK / upcast_size(); + static constexpr uint32_t UPCAST_STRIDE_K = + HEAD_DIM_QK / upcast_size(); + static constexpr uint32_t UPCAST_STRIDE_V = + HEAD_DIM_VO / upcast_size(); + static constexpr uint32_t UPCAST_STRIDE_O = + HEAD_DIM_VO / upcast_size(); static constexpr bool IsInvalid() { @@ -340,19 +353,18 @@ __device__ __forceinline__ void produce_kv_helper_(uint32_t warp_idx, uint32_t lane_idx) { using DTypeKV = typename KTraits::DTypeKV; + constexpr uint32_t WARP_THREAD_COLS = KTraits::WARP_THREAD_COLS; + constexpr uint32_t WARP_THREAD_ROWS = KTraits::WARP_THREAD_ROWS; constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q; constexpr uint32_t NUM_MMA_D = produce_v ? KTraits::NUM_MMA_D_VO : KTraits::NUM_MMA_D_QK; + constexpr uint32_t UPCAST_STRIDE = + produce_v ? KTraits::UPCAST_STRIDE_V : KTraits::UPCAST_STRIDE_K; #if defined(PLATFORM_HIP_DEVICE) - constexpr uint32_t HEAD_DIM = - produce_v ? KTraits::HEAD_DIM_QK : KTraits::HEAD_DIM_VO; - constexpr uint32_t UPCAST_STRIDE = HEAD_DIM / HP_QUERY_ELEMS_PER_THREAD; constexpr uint32_t COLUMN_RESET_OFFSET = (NUM_MMA_D / 4) * WARP_THREAD_COLS; #else - constexpr uint32_t UPCAST_STRIDE = - produce_v ? 
KTraits::UPCAST_STRIDE_V : KTraits::UPCAST_STRIDE_K; constexpr uint32_t COLUMN_RESET_OFFSET = sizeof(DTypeKV) * NUM_MMA_D; #endif @@ -589,13 +601,14 @@ load_q_global_smem(uint32_t packed_offset, const dim3 tid = threadIdx) { using DTypeQ = typename KTraits::DTypeQ; + constexpr uint32_t WARP_THREAD_COLS = KTraits::WARP_THREAD_COLS; + constexpr uint32_t WARP_THREAD_ROWS = KTraits::WARP_THREAD_ROWS; + constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; + #if defined(PLATFORM_HIP_DEVICE) - constexpr uint32_t UPCAST_STRIDE_Q = - KTraits::HEAD_DIM_QK / HP_QUERY_ELEMS_PER_THREAD; constexpr uint32_t COLUMN_RESET_OFFSET = (NUM_MMA_D_QK / 4) * WARP_THREAD_COLS; #else - constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q; constexpr uint32_t COLUMN_RESET_OFFSET = 2 * KTraits::NUM_MMA_D_QK; #endif @@ -634,7 +647,7 @@ load_q_global_smem(uint32_t packed_offset, #endif q_smem_offset_w = q_smem->template advance_offset_by_column< WARP_THREAD_COLS>(q_smem_offset_w, mma_do); - q_ptr += HP_QUERY_ELEMS_PER_THREAD * upcast_size(); + q_ptr += HALF_ELEMS_PER_THREAD * upcast_size(); } q_smem_offset_w = q_smem->template advance_offset_by_row Date: Tue, 5 Aug 2025 13:42:34 -0400 Subject: [PATCH 014/109] Ported query rope transformation to MI300. 
--- .../attention/generic/permuted_smem.cuh | 29 +++++ .../flashinfer/attention/generic/prefill.cuh | 100 ++++++++++++------ 2 files changed, 97 insertions(+), 32 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh b/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh index ec94717da4..e9ab8b6292 100644 --- a/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh @@ -151,6 +151,35 @@ template struct smem_t } } + template + __device__ __forceinline__ void load_fragment(uint32_t offset, T *frag) + { +#if defined(PLATFORM_HIP_DEVICE) + static_assert(sizeof(T) == 4, "Only 32-bit fragment loading supported"); + reinterpret_cast(frag)[0] = + *reinterpret_cast(base + offset); + reinterpret_cast(&frag[2])[0] = + *reinterpret_cast(base + (offset ^ 0x1)); +#else + ldmatrix_m8n8x4(offset, frag); +#endif + } + + template + __device__ __forceinline__ void store_fragment(uint32_t offset, + const T *frag) + { +#if defined(PLATFORM_HIP_DEVICE) + static_assert(sizeof(T) == 4, "Only 32-bit fragment storing supported"); + *reinterpret_cast(base + offset) = + reinterpret_cast(frag)[0]; + *reinterpret_cast(base + (offset ^ 0x1)) = + reinterpret_cast(&frag[2])[0]; +#else + stmatrix_m8n8x4(offset, frag); +#endif + } + __device__ __forceinline__ void ldmatrix_m8n8x4(uint32_t offset, uint32_t *R) { diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index b5590f0ce3..36ff1ff87f 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -144,11 +144,13 @@ struct KernelTraits constexpr uint32_t WARP_THREAD_COLS = 16; constexpr uint32_t WARP_THREAD_ROWS = 4; constexpr uint32_t HALF_ELEMS_PER_THREAD = 4; + constexpr uint32_t INT32_ELEMS_PER_THREAD = 2; constexpr uint32_t 
VECTOR_BIT_WIDTH = HALF_ELEMS_PER_THREAD * 16; // FIXME: Update with a proper swizzle pattern. Linear is used primarily // for intial testing. static constexpr SwizzleMode SWIZZLE_MODE_Q = SwizzleMode::kLinear; static constexpr SwizzleMode SWIZZLE_MODE_KV = SwizzleMode::kLinear; + using SmemBasePtrTy = uint2; // Presently we use 16x4 thread layout for all cases. static constexpr uint32_t KV_THR_LAYOUT_ROW = WARP_THREAD_ROWS; static constexpr uint32_t KV_THR_LAYOUT_COL = WARP_THREAD_COLS; @@ -157,12 +159,14 @@ struct KernelTraits constexpr uint32_t WARP_THREAD_COLS = 8; constexpr uint32_t WARP_THREAD_ROWS = 4; constexpr uint32_t HALF_ELEMS_PER_THREAD = 8; + constexpr uint32_t INT32_ELEMS_PER_THREAD = 4; constexpr uint32_t VECTOR_BIT_WIDTH = HALF_ELEMS_PER_THREAD * 16; static constexpr SwizzleMode SWIZZLE_MODE_Q = SwizzleMode::k128B; static constexpr SwizzleMode SWIZZLE_MODE_KV = (sizeof(DTypeKV_) == 1 && HEAD_DIM_VO == 64) ? SwizzleMode::k64B : SwizzleMode::k128B; + using SmemBasePtrTy = uint4; static constexpr uint32_t KV_THR_LAYOUT_ROW = SWIZZLE_MODE_KV == SwizzleMode::k128B ? 
WARP_THREAD_ROWS : WARP_THREAD_COLS; @@ -310,7 +314,11 @@ q_frag_apply_llama_rope(T *x_first_half, // 0 1 | 4 5 // --------- // 2 3 | 6 7 +#if defined(PLATFORM_HIP_DEVICE) + uint32_t i = reg_id / 2, j = reg_id % 2; +#else uint32_t i = ((reg_id % 4) / 2), j = (reg_id / 4); +#endif __sincosf(float((qo_packed_offset + 8 * i) / group_size) * rope_freq[2 * j + reg_id % 2], &sin, &cos); @@ -543,6 +551,17 @@ init_rope_freq(float (*rope_freq)[4], { constexpr uint32_t HEAD_DIM = KTraits::NUM_MMA_D_QK * 16; const uint32_t lane_idx = tid_x; + +#if defined(PLATFORM_HIP_DEVICE) + // MI300: 8 threads handle 8 elements (1 element per thread) + constexpr uint32_t THREADS_PER_ROW = 8; + constexpr uint32_t ELEMS_PER_THREAD = 1; +#else + // NVIDIA: 4 threads handle 8 elements (2 elements per thread) + constexpr uint32_t THREADS_PER_ROW = 4; + constexpr uint32_t ELEMS_PER_THREAD = 2; +#endif + #pragma unroll for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO / 2; ++mma_d) { #pragma unroll @@ -550,27 +569,31 @@ init_rope_freq(float (*rope_freq)[4], rope_freq[mma_d][j] = rope_rcp_scale * __powf(rope_rcp_theta, - float(2 * ((mma_d * 16 + (j / 2) * 8 + - (lane_idx % 4) * 2 + (j % 2)) % - (HEAD_DIM / 2))) / + float(2 * + ((mma_d * 16 + (j / 2) * 8 + + (lane_idx % THREADS_PER_ROW) * ELEMS_PER_THREAD + + (j % 2)) % + (HEAD_DIM / 2))) / float(HEAD_DIM)); } } } template -__device__ __forceinline__ void -init_states(typename KTraits::AttentionVariant variant, - float (*o_frag)[KTraits::NUM_MMA_D_VO][8], - typename KTraits::DTypeQKAccum (*m)[2], - float (*d)[2]) +__device__ __forceinline__ void init_states( + typename KTraits::AttentionVariant variant, + float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD], + typename KTraits::DTypeQKAccum (*m)[2], + float (*d)[2]) { #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { #pragma unroll - for (uint32_t reg_id = 0; reg_id 
< 8; ++reg_id) { + for (uint32_t reg_id = 0; reg_id < KTraits::HALF_ELEMS_PER_THREAD; + ++reg_id) + { o_frag[mma_q][mma_d][reg_id] = 0.f; } } @@ -660,22 +683,31 @@ load_q_global_smem(uint32_t packed_offset, } template -__device__ __forceinline__ void -q_smem_inplace_apply_rotary(const uint32_t q_packed_idx, - const uint32_t qo_len, - const uint32_t kv_len, - const uint_fastdiv group_size, - smem_t *q_smem, - uint32_t *q_smem_offset_r, - float (*rope_freq)[4], - const dim3 tid = threadIdx) +__device__ __forceinline__ void q_smem_inplace_apply_rotary( + const uint32_t q_packed_idx, + const uint32_t qo_len, + const uint32_t kv_len, + const uint_fastdiv group_size, + smem_t *q_smem, + uint32_t *q_smem_offset_r, + float (*rope_freq)[4], + const dim3 tid = threadIdx) { if (get_warp_idx_kv(tid.z) == 0) { constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q; const uint32_t lane_idx = tid.x; - uint32_t q_frag_local[2][4]; + uint32_t q_frag_local[2][KTraits::INT32_ELEMS_PER_THREAD]; static_assert(KTraits::NUM_MMA_D_QK % 4 == 0, "NUM_MMA_D_QK must be a multiple of 4"); + +#if defined(PLATFORM_HIP_DEVICE) + // MI300: 8 threads handle a row of 8 elements + const uint32_t pos_group_idx = lane_idx / 8; +#else + // NVIDIA: 4 threads handle a row of 8 elements + const uint32_t pos_group_idx = lane_idx / 4; +#endif + #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { uint32_t q_smem_offset_r_first_half = *q_smem_offset_r; @@ -683,24 +715,24 @@ q_smem_inplace_apply_rotary(const uint32_t q_packed_idx, for (uint32_t mma_di = 0; mma_di < KTraits::NUM_MMA_D_QK / 2; ++mma_di) { - q_smem->ldmatrix_m8n8x4(q_smem_offset_r_first_half, - q_frag_local[0]); + q_smem->template load_fragment(q_smem_offset_r_first_half, + q_frag_local[0]); uint32_t q_smem_offset_r_last_half = q_smem->template advance_offset_by_column< KTraits::NUM_MMA_D_QK>(q_smem_offset_r_first_half, 0); - q_smem->ldmatrix_m8n8x4(q_smem_offset_r_last_half, - q_frag_local[1]); + q_smem->template 
load_fragment(q_smem_offset_r_last_half, + q_frag_local[1]); q_frag_apply_llama_rope( (typename KTraits::DTypeQ *)q_frag_local[0], (typename KTraits::DTypeQ *)q_frag_local[1], rope_freq[mma_di], q_packed_idx + kv_len * group_size - qo_len * group_size + - mma_q * 16 + lane_idx / 4, + mma_q * 16 + pos_group_idx, group_size); - q_smem->stmatrix_m8n8x4(q_smem_offset_r_last_half, - q_frag_local[1]); - q_smem->stmatrix_m8n8x4(q_smem_offset_r_first_half, - q_frag_local[0]); + q_smem->template store_fragment(q_smem_offset_r_last_half, + q_frag_local[1]); + q_smem->template store_fragment(q_smem_offset_r_first_half, + q_frag_local[0]); q_smem_offset_r_first_half = q_smem->template advance_offset_by_column<2>( q_smem_offset_r_first_half, mma_di); @@ -1760,11 +1792,14 @@ SinglePrefillWithKVCacheDevice(const Params params, KTraits::SWIZZLE_MODE_Q; [[maybe_unused]] constexpr SwizzleMode SWIZZLE_MODE_KV = KTraits::SWIZZLE_MODE_KV; + [[maybe_unused]] constexpr SmemBasePtrTy = KTraits::SmemBasePtrTy; [[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_ROW = KTraits::KV_THR_LAYOUT_ROW; [[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_COL = KTraits::KV_THR_LAYOUT_COL; [[maybe_unused]] constexpr MaskMode MASK_MODE = KTraits::MASK_MODE; + [[maybe_unused]] constexpr HALF_ELEMS_PER_THREAD = + KTraits::HALF_ELEMS_PER_THREAD; DTypeQ *q = params.q; DTypeKV *k = params.k; @@ -1801,8 +1836,9 @@ SinglePrefillWithKVCacheDevice(const Params params, AttentionVariant variant(params, /*batch_idx=*/0, smem); const uint32_t window_left = variant.window_left; - DTypeQKAccum s_frag[NUM_MMA_Q][NUM_MMA_KV][8]; - alignas(16) float o_frag[NUM_MMA_Q][NUM_MMA_D_VO][8]; + DTypeQKAccum s_frag[NUM_MMA_Q][NUM_MMA_KV][HALF_ELEMS_PER_THREAD]; + alignas( + 16) float o_frag[NUM_MMA_Q][NUM_MMA_D_VO][HALF_ELEMS_PER_THREAD]; DTypeQKAccum m[NUM_MMA_Q][2]; float d[NUM_MMA_Q][2]; float rope_freq[NUM_MMA_D_QK / 2][4]; @@ -1819,7 +1855,7 @@ SinglePrefillWithKVCacheDevice(const Params params, const uint32_t 
qo_packed_idx_base = (bx * NUM_WARPS_Q + get_warp_idx_q(tid.y)) * NUM_MMA_Q * 16; - smem_t qo_smem(smem_storage.q_smem); + smem_t qo_smem(smem_storage.q_smem); const uint32_t o_stride_n = num_qo_heads * HEAD_DIM_VO, o_stride_h = HEAD_DIM_VO; DTypeQ *q_ptr_base = q + (kv_head_idx * group_size) * q_stride_h; @@ -1848,7 +1884,7 @@ SinglePrefillWithKVCacheDevice(const Params params, block.sync(); } - smem_t k_smem(smem_storage.k_smem), + smem_t k_smem(smem_storage.k_smem), v_smem(smem_storage.v_smem); const uint32_t num_iterations = ceil_div( From de13cd06f34258806bcb5c08cb107bfe567c2bee Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Wed, 6 Aug 2025 04:42:01 -0400 Subject: [PATCH 015/109] WIP changes to compute_qk --- .../flashinfer/attention/generic/prefill.cuh | 21 +++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index 36ff1ff87f..95749c24f9 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -805,7 +805,7 @@ k_smem_inplace_apply_rotary(const uint32_t kv_idx_base, using DTypeKV = typename KTraits::DTypeKV; static_assert(sizeof(DTypeKV) == 2); constexpr uint32_t UPCAST_STRIDE_K = KTraits::UPCAST_STRIDE_K; - uint32_t k_frag_local[2][4]; + uint32_t k_frag_local[2][KTraits::INT32_ELEMS_PER_THREAD]; const uint32_t lane_idx = tid.x; if constexpr (KTraits::NUM_MMA_D_QK == 4 && KTraits::NUM_WARPS_Q == 4) { static_assert(KTraits::NUM_WARPS_KV == 1); @@ -904,7 +904,11 @@ compute_qk(smem_t *q_smem, { constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q; constexpr uint32_t UPCAST_STRIDE_K = KTraits::UPCAST_STRIDE_K; - uint32_t a_frag[KTraits::NUM_MMA_Q][4], b_frag[4]; + constexpr uint32_t Q_SMEM_COLUMN_ADVANCE = + 16 / KTraits::HALF_ELEMS_PER_THREAD; + + uint32_t 
a_frag[KTraits::NUM_MMA_Q][KTraits::INT32_ELEMS_PER_THREAD], + b_frag[KTraits::INT32_ELEMS_PER_THREAD]; // compute q*k^T #pragma unroll for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_QK; ++mma_d) { @@ -916,13 +920,18 @@ compute_qk(smem_t *q_smem, *q_smem_offset_r); } - *q_smem_offset_r = q_smem->template advance_offset_by_column<2>( - *q_smem_offset_r, mma_d) - - KTraits::NUM_MMA_Q * 16 * UPCAST_STRIDE_Q; + *q_smem_offset_r = + q_smem->template advance_offset_by_column( + *q_smem_offset_r, mma_d) - + KTraits::NUM_MMA_Q * 16 * UPCAST_STRIDE_Q; #pragma unroll for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { if constexpr (sizeof(typename KTraits::DTypeKV) == 1) { +#if defined(PLATFORM_HIP_DEVICE) + static_assert(false, + "FP8 support not yet implemented for CDNA3"); +#endif uint32_t b_frag_f8[2]; if (mma_d % 2 == 0) { k_smem->ldmatrix_m8n8x4_left_half(*k_smem_offset_r, @@ -994,7 +1003,7 @@ compute_qk(smem_t *q_smem, KTraits::NUM_MMA_KV * 16 * UPCAST_STRIDE_K; } } - *q_smem_offset_r -= KTraits::NUM_MMA_D_QK * 2; + *q_smem_offset_r -= KTraits::NUM_MMA_D_QK * Q_SMEM_COLUMN_ADVANCE; *k_smem_offset_r -= KTraits::NUM_MMA_D_QK * sizeof(typename KTraits::DTypeKV); } From cf8bd74635934493b5f0e0e515815b4daa94b104 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Thu, 7 Aug 2025 01:18:26 -0400 Subject: [PATCH 016/109] WIP --- .../flashinfer/attention/generic/prefill.cuh | 2 +- .../tests/hip/test_k_smem_read_pattern.cpp | 204 ++++++++++++ .../hip/test_transpose_4x4_half_registers.cpp | 307 ++++++++++++++++++ 3 files changed, 512 insertions(+), 1 deletion(-) create mode 100644 libflashinfer/tests/hip/test_k_smem_read_pattern.cpp create mode 100644 libflashinfer/tests/hip/test_transpose_4x4_half_registers.cpp diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index 95749c24f9..089f9531f3 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ 
b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -1353,7 +1353,7 @@ compute_sfm_v(smem_t *v_smem, } else { #warning "TODO ldmatrix_m8n8x4_trans ............" - // v_smem->ldmatrix_m8n8x4_trans(*v_smem_offset_r, b_frag); + v_smem->ldmatrix_m8n8x4_trans(*v_smem_offset_r, b_frag); } #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { diff --git a/libflashinfer/tests/hip/test_k_smem_read_pattern.cpp b/libflashinfer/tests/hip/test_k_smem_read_pattern.cpp new file mode 100644 index 0000000000..1676c88285 --- /dev/null +++ b/libflashinfer/tests/hip/test_k_smem_read_pattern.cpp @@ -0,0 +1,204 @@ +#include +#include +#include +#include +#include + +// Constants for MI300 +constexpr uint32_t WARP_SIZE = 64; // 64 threads per wavefront +constexpr uint32_t HALF_ELEMS_PER_THREAD = + 4; // Each thread processes 4 half elements +constexpr uint32_t INT32_ELEMS_PER_THREAD = 2; // 2 int32 registers per thread + +// Simplified linear shared memory operations (CPU implementation) +template +uint32_t get_permuted_offset_linear(uint32_t row, uint32_t col) +{ + return row * stride + col; +} + +template +uint32_t advance_offset_by_column_linear(uint32_t offset, uint32_t step_idx) +{ + return offset + step_size; +} + +template +uint32_t advance_offset_by_row_linear(uint32_t offset) +{ + return offset + step_size * row_stride; +} + +// CPU-based simulation of k-matrix access pattern in compute_qk +template +void SimulateKReadPattern(std::vector &thread_ids_reading_offsets) +{ + // Constants derived from HEAD_DIM + constexpr uint32_t UPCAST_STRIDE_K = HEAD_DIM / HALF_ELEMS_PER_THREAD; + constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM / 16; + constexpr uint32_t grid_width = HEAD_DIM / HALF_ELEMS_PER_THREAD; + constexpr uint32_t grid_height = 16 * NUM_MMA_KV; + + constexpr uint32_t K_SMEM_COLUMN_ADVANCE = + 16 / HALF_ELEMS_PER_THREAD; // = 4 for MI300 + + // Initialize with -1 (unread) + thread_ids_reading_offsets.assign(grid_height * grid_width, -1); + 
+    // Simulate each thread's read pattern
+    for (uint32_t tid = 0; tid < WARP_SIZE; tid++) {
+        // Map tid to kernel's lane_idx
+        uint32_t lane_idx = tid;
+        uint32_t warp_idx_kv = 0; // For simplicity, assuming one warp group
+
+        // Exactly match the kernel's initial offset calculation - MI300 version
+        uint32_t k_smem_offset_r = get_permuted_offset_linear<UPCAST_STRIDE_K>(
+            warp_idx_kv * NUM_MMA_KV * 16 + 4 * (lane_idx / 16) + lane_idx % 4,
+            (lane_idx % 16) / 4);
+
+        // uint32_t k_smem_offset_r =
+        //     get_permuted_offset_linear<UPCAST_STRIDE_K>(
+        //         warp_idx_kv * NUM_MMA_KV * 16 +
+        //         4 * (lane_idx / 16),
+        //         (lane_idx % 16));
+
+        // Follow the same loop structure as in compute_qk
+        for (uint32_t mma_d = 0; mma_d < NUM_MMA_D_QK; ++mma_d) {
+            for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) {
+                // Mark grid positions accessed by ldmatrix_m8n8x4 /
+                // load_fragment
+                uint32_t read_row = k_smem_offset_r / UPCAST_STRIDE_K;
+                uint32_t read_col = k_smem_offset_r % UPCAST_STRIDE_K;
+
+                if (tid == 0) {
+                    std::cout << "Thread " << tid << " k_smem_offset_r "
+                              << k_smem_offset_r << '\n';
+                }
+
+                // Simulate loading a matrix fragment
+                for (uint32_t reg_id = 0; reg_id < INT32_ELEMS_PER_THREAD;
+                     reg_id++)
+                {
+                    if (read_row < grid_height && read_col < grid_width) {
+                        thread_ids_reading_offsets[read_row * grid_width +
+                                                   read_col] = tid;
+                    }
+
+                    // Each INT32_ELEMS_PER_THREAD register holds 2 half
+                    // elements For simplicity, we're just recording the base
+                    // offset
+                }
+
+                // Advance to next row, exactly as in compute_qk
+                k_smem_offset_r =
+                    advance_offset_by_row_linear<16, UPCAST_STRIDE_K>(
+                        k_smem_offset_r);
+            }
+
+            // Reset row position and advance to next column section, exactly as
+            // in compute_qk For MI300, advance by 4 columns (vs 2 for NVIDIA)
+            k_smem_offset_r =
+                advance_offset_by_column_linear<K_SMEM_COLUMN_ADVANCE>(
+                    k_smem_offset_r, mma_d) -
+                NUM_MMA_KV * 16 * UPCAST_STRIDE_K;
+        }
+    }
+}
+
+// Helper function to run the test with configurable parameters
+template <uint32_t HEAD_DIM, uint32_t NUM_MMA_KV>
+void RunKReadPatternTest()
+{
+    constexpr uint32_t grid_width = HEAD_DIM / HALF_ELEMS_PER_THREAD;
+    constexpr uint32_t grid_height = 16 * NUM_MMA_KV;
+
+    printf("\n=== Testing key read pattern with HEAD_DIM = %u, NUM_MMA_KV = %u "
+           "===\n",
+           HEAD_DIM, NUM_MMA_KV);
+
+    // Host array to store thread IDs at each offset
+    std::vector<int> thread_ids(grid_height * grid_width, -1);
+
+    // Run CPU simulation of read pattern
+    SimulateKReadPattern<HEAD_DIM, NUM_MMA_KV>(thread_ids);
+
+    // Print the grid of thread IDs
+    printf("Thread IDs reading from each offset (%dx%d grid):\n", grid_height,
+           grid_width);
+
+    // Column headers
+    printf(" ");
+    for (int c = 0; c < grid_width; c++) {
+        printf("%3d ", c);
+        if (c == 15 && grid_width > 16)
+            printf("| "); // Divider for HEAD_DIM=128
+    }
+    printf("\n +");
+    for (int c = 0; c < grid_width; c++) {
+        printf("----");
+        if (c == 15 && grid_width > 16)
+            printf("+");
+    }
+    printf("\n");
+
+    // Print the grid
+    for (int r = 0; r < grid_height; r++) {
+        printf("%2d | ", r);
+        for (int c = 0; c < grid_width; c++) {
+            int thread_id = thread_ids[r * grid_width + c];
+            if (thread_id >= 0) {
+                printf("%3d ", thread_id);
+            }
+            else {
+                printf(" . 
"); // Dot for unread positions + } + if (c == 15 && grid_width > 16) + printf("| "); // Divider for HEAD_DIM=128 + } + printf("\n"); + } + + // Check for unread positions + int unread = 0; + for (int i = 0; i < grid_height * grid_width; i++) { + if (thread_ids[i] == -1) { + unread++; + } + } + + // Print statistics + printf("\nStatistics:\n"); + printf("- Positions read: %d/%d (%.1f%%)\n", + grid_height * grid_width - unread, grid_height * grid_width, + 100.0f * (grid_height * grid_width - unread) / + (grid_height * grid_width)); + printf("- Unread positions: %d/%d (%.1f%%)\n", unread, + grid_height * grid_width, + 100.0f * unread / (grid_height * grid_width)); + + // Validate full coverage + EXPECT_EQ(unread, 0) << "Not all positions were read"; +} + +// Tests for different configurations +TEST(MI300KReadPatternTest, HeadDim64_NumMmaKV1) +{ + RunKReadPatternTest<64, 1>(); +} + +// TEST(MI300KReadPatternTest, HeadDim128_NumMmaKV1) { +// RunKReadPatternTest<128, 1>(); +// } + +// TEST(MI300KReadPatternTest, HeadDim64_NumMmaKV2) { +// RunKReadPatternTest<64, 2>(); +// } + +// TEST(MI300KReadPatternTest, HeadDim128_NumMmaKV2) { +// RunKReadPatternTest<128, 2>(); +// } + +int main(int argc, char **argv) +{ + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/libflashinfer/tests/hip/test_transpose_4x4_half_registers.cpp b/libflashinfer/tests/hip/test_transpose_4x4_half_registers.cpp new file mode 100644 index 0000000000..4a4354fb3c --- /dev/null +++ b/libflashinfer/tests/hip/test_transpose_4x4_half_registers.cpp @@ -0,0 +1,307 @@ +// test_transpose_4x4_half_registers.cpp +#include +#include +#include + +// Define WARP_FULL_MASK for HIP +constexpr uint64_t WARP_FULL_MASK = + 0xffffffffffffffffULL; // 64-bit mask for HIP + +__device__ __forceinline__ void debug_print_registers(const char *stage, + uint32_t lane_id, + uint32_t lane_in_group, + uint32_t *regs, + int num_regs, + uint32_t debug_group = 0) +{ + + // Only debug a specific group 
to avoid excessive output + if (lane_id / 4 != debug_group) + return; + + // Print identification info + printf("STAGE: %s | Thread %d (lane_in_group=%d): ", stage, lane_id, + lane_in_group); + + // Print raw 32-bit values + printf("RAW=["); + for (int i = 0; i < num_regs; i++) { + printf("0x%08x", regs[i]); + if (i < num_regs - 1) + printf(", "); + } + printf("] | "); + + // Print unpacked 16-bit values + printf("UNPACKED=["); + for (int i = 0; i < num_regs; i++) { + uint16_t hi = (regs[i] >> 16) & 0xFFFF; + uint16_t lo = regs[i] & 0xFFFF; + printf("%d,%d", hi, lo); + if (i < num_regs - 1) + printf(", "); + } + printf("]\n"); +} + +__device__ __forceinline__ void transpose_4x4_half_registers(uint32_t *R, + uint32_t *out) +{ + // Each thread has 4 half-precision values in 2 registers: + // R[0] = [B[lane_id][0], B[lane_id][1]] + // R[1] = [B[lane_id][2], B[lane_id][3]] + + // Calculate lane within 4-thread group + uint32_t lane_id = threadIdx.x % 64; + uint32_t lane_in_group = lane_id % 4; + uint32_t temp_regs[2]; + + if (lane_id == 0) { + debug_print_registers("Initial", lane_id, lane_in_group, R, 2, 0); + } + + // === ROUND 1: Exchange with neighbor (XOR with 1) === + // T0↔T1, T2↔T3 partial exchange + uint32_t r0_exchanged = __shfl_xor(R[0], 0x1); + uint32_t r1_exchanged = __shfl_xor(R[1], 0x1); + + // Debug first exchange + if (lane_id == 0) { + uint32_t debug_regs[2] = {r0_exchanged, r1_exchanged}; + debug_print_registers("Round1-Exchange", lane_id, lane_in_group, + debug_regs, 2, 0); + } + + // Update based on thread position + if (lane_in_group < 2) { + // Top half (T0, T1) update R[0] + if (lane_in_group & 1) { // T1 + temp_regs[0] = (R[0] & 0xFFFF0000) | (r0_exchanged & 0x0000FFFF); + } + else { // T0 + temp_regs[0] = (R[0] & 0x0000FFFF) | (r0_exchanged & 0xFFFF0000); + } + // Keep R[1] unchanged + temp_regs[1] = R[1]; + } + else { + // Bottom half (T2, T3) update R[1] + if (lane_in_group & 1) { // T3 + temp_regs[1] = (R[1] & 0xFFFF0000) | (r1_exchanged & 
0x0000FFFF); + } + else { // T2 + temp_regs[1] = (R[1] & 0x0000FFFF) | (r1_exchanged & 0xFFFF0000); + } + // Keep R[0] unchanged + temp_regs[0] = R[0]; + } + + // Debug after first recombination + if (lane_id == 0) { + debug_print_registers("Round1-Exchange", lane_id, lane_in_group, + temp_regs, 2, 0); + } + + // === ROUND 2: Exchange with one hop (XOR with 2) === + // T0↔T2, T1↔T3 exchange R[0] and R[1] + uint32_t temp0_exchanged = __shfl_xor(temp_regs[0], 2); + uint32_t temp1_exchanged = __shfl_xor(temp_regs[1], 2); + + // Debug second exchange + if (lane_id < 4) { + uint32_t debug_regs[2] = {temp0_exchanged, temp1_exchanged}; + debug_print_registers("Round2-Exchange", lane_id, lane_in_group, + debug_regs, 2, 0); + } + + // Swap entire registers based on thread position + if (lane_in_group < 2) { + // Top threads (T0, T1) get R[0] from partner, keep own R[1] + temp_regs[0] = temp0_exchanged; + // Keep R[1] unchanged + } + else { + // Bottom threads (T2, T3) get R[1] from partner, keep own R[0] + temp_regs[1] = temp1_exchanged; + // Keep R[0] unchanged + } + + // Debug after second recombination + if (lane_id < 4) { + debug_print_registers("Round2-Result", lane_id, lane_in_group, + temp_regs, 2, 0); + } + + // === ROUND 3: Exchange with neighbor again (XOR with 1) === + // T0↔T1, T2↔T3 exchange remaining parts + uint32_t final0_exchanged = __shfl_xor(temp_regs[0], 1); + uint32_t final1_exchanged = __shfl_xor(temp_regs[1], 1); + + // Debug third exchange + if (lane_id < 4) { + uint32_t debug_regs[2] = {final0_exchanged, final1_exchanged}; + debug_print_registers("Round3-Exchange", lane_id, lane_in_group, + debug_regs, 2, 0); + } + + // Final combination based on thread position + if (lane_in_group < 2) { + // Top half (T0, T1) update R[1] + if (lane_in_group & 1) { // T1 + out[1] = + (temp_regs[1] & 0xFFFF0000) | (final1_exchanged & 0x0000FFFF); + } + else { // T0 + out[1] = + (temp_regs[1] & 0x0000FFFF) | (final1_exchanged & 0xFFFF0000); + } + // Keep R[0] 
unchanged + out[0] = temp_regs[0]; + } + else { + // Bottom half (T2, T3) update R[0] + if (lane_in_group & 1) { // T3 + out[0] = + (temp_regs[0] & 0xFFFF0000) | (final0_exchanged & 0x0000FFFF); + } + else { // T2 + out[0] = + (temp_regs[0] & 0x0000FFFF) | (final0_exchanged & 0xFFFF0000); + } + // Keep R[1] unchanged + out[1] = temp_regs[1]; + } + + // Debug final result + if (lane_id < 4) { + debug_print_registers("Final-Result", lane_id, lane_in_group, out, 2, + 0); + } +} + +// Helper function to convert two uint16_t values to a single uint32_t +__host__ __device__ uint32_t pack_half2(uint16_t a, uint16_t b) +{ + return ((uint32_t)a << 16) | (uint32_t)b; +} + +// Helper function to extract two uint16_t values from a single uint32_t +__host__ __device__ void unpack_half2(uint32_t packed, uint16_t &a, uint16_t &b) +{ + a = (packed >> 16) & 0xFFFF; + b = packed & 0xFFFF; +} + +// Kernel to test the transpose function +__global__ void test_transpose_kernel(uint16_t *output) +{ + uint32_t thread_id = threadIdx.x + blockIdx.x * blockDim.x; + uint32_t lane_id = thread_id % 64; + + // Calculate the thread's position in the logical 4x4 grid + uint32_t group_id = lane_id / 4; // Which 4-thread group + uint32_t lane_in_group = lane_id % 4; // Position within group + + // Initialize test data - each thread creates a row of the matrix B + // Values are designed for easy verification: lane_in_group * 100 + column + uint16_t row_elements[4]; + for (int i = 0; i < 4; i++) { + row_elements[i] = lane_in_group * 100 + i; // B[lane_in_group][i] + } + + // Pack the 4 half-precision values into 2 registers + uint32_t R[2]; + R[0] = pack_half2(row_elements[0], row_elements[1]); + R[1] = pack_half2(row_elements[2], row_elements[3]); + + // Call the transpose function + uint32_t out[2]; + transpose_4x4_half_registers(R, out); + + // Unpack the transposed results + uint16_t transposed[4]; + unpack_half2(out[0], transposed[0], transposed[1]); + unpack_half2(out[1], transposed[2], 
transposed[3]); + + // Write output - store both original and transposed values for verification + for (int i = 0; i < 4; i++) { + // Original values (row-major layout) + output[thread_id * 8 + i] = row_elements[i]; + // Transposed values (column-major layout) + output[thread_id * 8 + 4 + i] = transposed[i]; + } +} + +int main() +{ + // Allocate memory for output (both original and transposed data) + const int num_threads = 64; // One wavefront + const int values_per_thread = + 8; // Each thread stores 4 original + 4 transposed values + const int total_values = num_threads * values_per_thread; + + std::vector h_output(total_values); + uint16_t *d_output; + + hipMalloc(&d_output, total_values * sizeof(uint16_t)); + + // Launch the kernel + test_transpose_kernel<<<1, num_threads>>>(d_output); + + // Copy results back to host + hipMemcpy(h_output.data(), d_output, total_values * sizeof(uint16_t), + hipMemcpyDeviceToHost); + + // Verify the results + bool success = true; + std::cout << "Testing matrix transposition with shuffle operations..." 
+ << std::endl; + + // for (int group = 0; group < num_threads / 4; group++) { + // std::cout << "\nGroup " << group << " results:" << std::endl; + + // for (int lane = 0; lane < 4; lane++) { + // int thread_idx = group * 4 + lane; + + // // Print original values + // std::cout << "Thread " << thread_idx << " original: "; + // for (int i = 0; i < 4; i++) { + // std::cout << h_output[thread_idx * 8 + i] << " "; + // } + // std::cout << std::endl; + + // // Print and verify transposed values + // std::cout << "Thread " << thread_idx << " transposed: "; + // for (int i = 0; i < 4; i++) { + // uint16_t actual = h_output[thread_idx * 8 + 4 + i]; + // std::cout << actual << " "; + + // // Expected after transpose: Thread N gets column N + // // Thread 0 should have [0*100+0, 1*100+0, 2*100+0, 3*100+0] + // // Thread 1 should have [0*100+1, 1*100+1, 2*100+1, 3*100+1] + // uint16_t expected = i * 100 + lane; + + // if (actual != expected) { + // success = false; + // std::cout << "(Expected: " << expected << ") "; + // } + // } + // std::cout << std::endl; + // } + // } + + if (success) { + std::cout << "\nTranspose test PASSED! All values correctly transposed." + << std::endl; + } + else { + std::cout << "\nTranspose test FAILED! Some values were not correctly " + "transposed." + << std::endl; + } + + // Clean up + hipFree(d_output); + + return success ? 
0 : 1; +} From 28aa7183e83de2ba96a45eef9cae1720e7868e02 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Thu, 7 Aug 2025 01:51:04 -0400 Subject: [PATCH 017/109] WIP2 --- .gitignore | 1 + .../hip/test_transpose_4x4_half_registers.cpp | 41 ++++++++----------- 2 files changed, 17 insertions(+), 25 deletions(-) diff --git a/.gitignore b/.gitignore index 22609edc36..9f58c56551 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,7 @@ flashinfer/__config__.py flashinfer/jit/aot_config.py src/generated/ csrc/aot_default_additional_params.h +*.out # DS_Store files .DS_store diff --git a/libflashinfer/tests/hip/test_transpose_4x4_half_registers.cpp b/libflashinfer/tests/hip/test_transpose_4x4_half_registers.cpp index 4a4354fb3c..8a7e837045 100644 --- a/libflashinfer/tests/hip/test_transpose_4x4_half_registers.cpp +++ b/libflashinfer/tests/hip/test_transpose_4x4_half_registers.cpp @@ -62,50 +62,40 @@ __device__ __forceinline__ void transpose_4x4_half_registers(uint32_t *R, // === ROUND 1: Exchange with neighbor (XOR with 1) === // T0↔T1, T2↔T3 partial exchange - uint32_t r0_exchanged = __shfl_xor(R[0], 0x1); - uint32_t r1_exchanged = __shfl_xor(R[1], 0x1); - - // Debug first exchange - if (lane_id == 0) { - uint32_t debug_regs[2] = {r0_exchanged, r1_exchanged}; - debug_print_registers("Round1-Exchange", lane_id, lane_in_group, - debug_regs, 2, 0); - } // Update based on thread position if (lane_in_group < 2) { + uint32_t r0_exchanged = __shfl_xor(R[0], 0x1); // Top half (T0, T1) update R[0] if (lane_in_group & 1) { // T1 - temp_regs[0] = (R[0] & 0xFFFF0000) | (r0_exchanged & 0x0000FFFF); + R[0] = (R[0] & 0x0000FFFF) | (r0_exchanged << 16); } else { // T0 - temp_regs[0] = (R[0] & 0x0000FFFF) | (r0_exchanged & 0xFFFF0000); + r0_exchanged >>= 16; + R[0] = (R[0] & 0xFFFF0000) | (r0_exchanged); } - // Keep R[1] unchanged - temp_regs[1] = R[1]; } else { + uint32_t r1_exchanged = __shfl_xor(R[1], 0x1); // Bottom half (T2, T3) update R[1] - if (lane_in_group & 1) { // T3 - 
temp_regs[1] = (R[1] & 0xFFFF0000) | (r1_exchanged & 0x0000FFFF); + if (lane_in_group & 1) { // T1 + R[1] = (R[1] & 0x0000FFFF) | (r1_exchanged << 16); } - else { // T2 - temp_regs[1] = (R[1] & 0x0000FFFF) | (r1_exchanged & 0xFFFF0000); + else { // T0 + R[1] = (R[1] & 0xFFFF0000) | (r1_exchanged >> 16); } - // Keep R[0] unchanged - temp_regs[0] = R[0]; } // Debug after first recombination - if (lane_id == 0) { - debug_print_registers("Round1-Exchange", lane_id, lane_in_group, - temp_regs, 2, 0); + if (lane_id == 3) { + debug_print_registers("After Round 1 shuffles", lane_id, lane_in_group, + R, 2, 0); } - +#if 0 // === ROUND 2: Exchange with one hop (XOR with 2) === // T0↔T2, T1↔T3 exchange R[0] and R[1] - uint32_t temp0_exchanged = __shfl_xor(temp_regs[0], 2); - uint32_t temp1_exchanged = __shfl_xor(temp_regs[1], 2); + uint32_t temp0_exchanged = __shfl_xor(R[0], 0x2); + uint32_t temp1_exchanged = __shfl_xor(R[1], 0x2); // Debug second exchange if (lane_id < 4) { @@ -177,6 +167,7 @@ __device__ __forceinline__ void transpose_4x4_half_registers(uint32_t *R, debug_print_registers("Final-Result", lane_id, lane_in_group, out, 2, 0); } +#endif } // Helper function to convert two uint16_t values to a single uint32_t From 6dd2063e5a57ccc7e07baa12da01cf02d41e6edd Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Thu, 7 Aug 2025 03:54:33 -0400 Subject: [PATCH 018/109] Working transpose test --- .../hip/test_transpose_4x4_half_registers.cpp | 206 ++++++++++-------- 1 file changed, 119 insertions(+), 87 deletions(-) diff --git a/libflashinfer/tests/hip/test_transpose_4x4_half_registers.cpp b/libflashinfer/tests/hip/test_transpose_4x4_half_registers.cpp index 8a7e837045..dedf960e1b 100644 --- a/libflashinfer/tests/hip/test_transpose_4x4_half_registers.cpp +++ b/libflashinfer/tests/hip/test_transpose_4x4_half_registers.cpp @@ -44,17 +44,75 @@ __device__ __forceinline__ void debug_print_registers(const char *stage, printf("]\n"); } +__device__ __forceinline__ void 
transpose_4x4_half_registers_opt(uint32_t *R, + uint32_t *out) +{ + // Calculate lane within 4-thread group + uint32_t lane_id = threadIdx.x % 64; + uint32_t lane_in_group = lane_id % 4; + + // === ROUND 1: Exchange with neighbor (XOR with 1) === + // Remove conditionals by using masks and bit operations + + // Exchange values with neighbor + uint32_t r0_exchanged = __shfl_xor(R[0], 0x1); + uint32_t r1_exchanged = __shfl_xor(R[1], 0x1); + + // Register selection mask: 0xFFFFFFFF for R[0] if lane<2, 0 otherwise + uint32_t r0_mask = ~((lane_in_group >> 1) * 0xFFFFFFFF); + // Register selection mask: 0xFFFFFFFF for R[1] if lane>=2, 0 otherwise + uint32_t r1_mask = (lane_in_group >> 1) * 0xFFFFFFFF; + + // Bit selection based on odd/even thread + uint32_t shift = (lane_in_group & 1) * 16; + uint32_t keep_mask = + 0xFFFF0000 >> shift; // 0xFFFF0000 for even, 0x0000FFFF for odd + + // Apply the masks to merge old and exchanged values + R[0] = (R[0] & keep_mask & r0_mask) | + ((lane_in_group & 1) ? ((r0_exchanged & 0xFFFF) << 16) & r0_mask + : ((r0_exchanged >> 16) & 0xFFFF) & r0_mask); + + R[1] = (R[1] & keep_mask & r1_mask) | + ((lane_in_group & 1) ? ((r1_exchanged & 0xFFFF) << 16) & r1_mask + : ((r1_exchanged >> 16) & 0xFFFF) & r1_mask); + + // === ROUND 2: Exchange with one hop (XOR with 2) === + uint32_t temp0_exchanged = __shfl_xor(R[0], 0x2); + uint32_t temp1_exchanged = __shfl_xor(R[1], 0x2); + + // Use predicated assignments instead of if-else + uint32_t r0_old = R[0]; + uint32_t r1_old = R[1]; + + // Branchless swap using masks + R[0] = (lane_in_group < 2) ? r0_old : temp1_exchanged; + R[1] = (lane_in_group < 2) ? 
temp0_exchanged : r1_old; + + // === ROUND 3: Exchange with neighbor again (XOR with 1) === + r0_exchanged = __shfl_xor(R[0], 0x1); + r1_exchanged = __shfl_xor(R[1], 0x1); + + // Swap register selectors for round 3 (inverse of round 1) + r0_mask = (lane_in_group >> 1) * 0xFFFFFFFF; + r1_mask = ~((lane_in_group >> 1) * 0xFFFFFFFF); + + // Apply the masks to merge old and exchanged values (same logic as round 1) + out[0] = (R[0] & keep_mask & r0_mask) | + ((lane_in_group & 1) ? ((r0_exchanged & 0xFFFF) << 16) & r0_mask + : ((r0_exchanged >> 16) & 0xFFFF) & r0_mask); + + out[1] = (R[1] & keep_mask & r1_mask) | + ((lane_in_group & 1) ? ((r1_exchanged & 0xFFFF) << 16) & r1_mask + : ((r1_exchanged >> 16) & 0xFFFF) & r1_mask); +} + __device__ __forceinline__ void transpose_4x4_half_registers(uint32_t *R, uint32_t *out) { - // Each thread has 4 half-precision values in 2 registers: - // R[0] = [B[lane_id][0], B[lane_id][1]] - // R[1] = [B[lane_id][2], B[lane_id][3]] - // Calculate lane within 4-thread group uint32_t lane_id = threadIdx.x % 64; uint32_t lane_in_group = lane_id % 4; - uint32_t temp_regs[2]; if (lane_id == 0) { debug_print_registers("Initial", lane_id, lane_in_group, R, 2, 0); @@ -71,8 +129,7 @@ __device__ __forceinline__ void transpose_4x4_half_registers(uint32_t *R, R[0] = (R[0] & 0x0000FFFF) | (r0_exchanged << 16); } else { // T0 - r0_exchanged >>= 16; - R[0] = (R[0] & 0xFFFF0000) | (r0_exchanged); + R[0] = (R[0] & 0xFFFF0000) | (r0_exchanged >> 16); } } else { @@ -91,83 +148,57 @@ __device__ __forceinline__ void transpose_4x4_half_registers(uint32_t *R, debug_print_registers("After Round 1 shuffles", lane_id, lane_in_group, R, 2, 0); } -#if 0 + // === ROUND 2: Exchange with one hop (XOR with 2) === // T0↔T2, T1↔T3 exchange R[0] and R[1] uint32_t temp0_exchanged = __shfl_xor(R[0], 0x2); uint32_t temp1_exchanged = __shfl_xor(R[1], 0x2); - // Debug second exchange - if (lane_id < 4) { - uint32_t debug_regs[2] = {temp0_exchanged, temp1_exchanged}; - 
debug_print_registers("Round2-Exchange", lane_id, lane_in_group, - debug_regs, 2, 0); - } - // Swap entire registers based on thread position if (lane_in_group < 2) { - // Top threads (T0, T1) get R[0] from partner, keep own R[1] - temp_regs[0] = temp0_exchanged; - // Keep R[1] unchanged + R[1] = temp0_exchanged; } else { // Bottom threads (T2, T3) get R[1] from partner, keep own R[0] - temp_regs[1] = temp1_exchanged; - // Keep R[0] unchanged + R[0] = temp1_exchanged; } - // Debug after second recombination - if (lane_id < 4) { - debug_print_registers("Round2-Result", lane_id, lane_in_group, - temp_regs, 2, 0); + if (lane_id == 0) { + debug_print_registers("After Round 2 shuffles", lane_id, lane_in_group, + R, 2, 0); } // === ROUND 3: Exchange with neighbor again (XOR with 1) === // T0↔T1, T2↔T3 exchange remaining parts - uint32_t final0_exchanged = __shfl_xor(temp_regs[0], 1); - uint32_t final1_exchanged = __shfl_xor(temp_regs[1], 1); - - // Debug third exchange - if (lane_id < 4) { - uint32_t debug_regs[2] = {final0_exchanged, final1_exchanged}; - debug_print_registers("Round3-Exchange", lane_id, lane_in_group, - debug_regs, 2, 0); - } - // Final combination based on thread position if (lane_in_group < 2) { - // Top half (T0, T1) update R[1] + uint32_t r1_exchanged = __shfl_xor(R[1], 0x1); + // Top half (T0, T1) update R[0] if (lane_in_group & 1) { // T1 - out[1] = - (temp_regs[1] & 0xFFFF0000) | (final1_exchanged & 0x0000FFFF); + R[1] = (R[1] & 0x0000FFFF) | (r1_exchanged << 16); } else { // T0 - out[1] = - (temp_regs[1] & 0x0000FFFF) | (final1_exchanged & 0xFFFF0000); + R[1] = (R[1] & 0xFFFF0000) | (r1_exchanged >> 16); } - // Keep R[0] unchanged - out[0] = temp_regs[0]; } else { - // Bottom half (T2, T3) update R[0] - if (lane_in_group & 1) { // T3 - out[0] = - (temp_regs[0] & 0xFFFF0000) | (final0_exchanged & 0x0000FFFF); + uint32_t r1_exchanged = __shfl_xor(R[0], 0x1); + // Bottom half (T2, T3) update R[1] + if (lane_in_group & 1) { // T1 + R[0] = (R[0] & 
0x0000FFFF) | (r1_exchanged << 16); } - else { // T2 - out[0] = - (temp_regs[0] & 0x0000FFFF) | (final0_exchanged & 0xFFFF0000); + else { // T0 + R[0] = (R[0] & 0xFFFF0000) | (r1_exchanged >> 16); } - // Keep R[1] unchanged - out[1] = temp_regs[1]; } - // Debug final result - if (lane_id < 4) { - debug_print_registers("Final-Result", lane_id, lane_in_group, out, 2, - 0); + if (lane_id == 3) { + debug_print_registers("After Round 2 shuffles", lane_id, lane_in_group, + R, 2, 0); } -#endif + + out[0] = R[0]; + out[1] = R[1]; } // Helper function to convert two uint16_t values to a single uint32_t @@ -208,6 +239,7 @@ __global__ void test_transpose_kernel(uint16_t *output) // Call the transpose function uint32_t out[2]; transpose_4x4_half_registers(R, out); + // transpose_4x4_half_registers_opt(R, out); // Unpack the transposed results uint16_t transposed[4]; @@ -248,38 +280,38 @@ int main() std::cout << "Testing matrix transposition with shuffle operations..." << std::endl; - // for (int group = 0; group < num_threads / 4; group++) { - // std::cout << "\nGroup " << group << " results:" << std::endl; - - // for (int lane = 0; lane < 4; lane++) { - // int thread_idx = group * 4 + lane; - - // // Print original values - // std::cout << "Thread " << thread_idx << " original: "; - // for (int i = 0; i < 4; i++) { - // std::cout << h_output[thread_idx * 8 + i] << " "; - // } - // std::cout << std::endl; - - // // Print and verify transposed values - // std::cout << "Thread " << thread_idx << " transposed: "; - // for (int i = 0; i < 4; i++) { - // uint16_t actual = h_output[thread_idx * 8 + 4 + i]; - // std::cout << actual << " "; - - // // Expected after transpose: Thread N gets column N - // // Thread 0 should have [0*100+0, 1*100+0, 2*100+0, 3*100+0] - // // Thread 1 should have [0*100+1, 1*100+1, 2*100+1, 3*100+1] - // uint16_t expected = i * 100 + lane; - - // if (actual != expected) { - // success = false; - // std::cout << "(Expected: " << expected << ") "; - // } - 
// } - // std::cout << std::endl; - // } - // } + for (int group = 0; group < num_threads / 4; group++) { + std::cout << "\nGroup " << group << " results:" << std::endl; + + for (int lane = 0; lane < 4; lane++) { + int thread_idx = group * 4 + lane; + + // Print original values + std::cout << "Thread " << thread_idx << " original: "; + for (int i = 0; i < 4; i++) { + std::cout << h_output[thread_idx * 8 + i] << " "; + } + std::cout << std::endl; + + // Print and verify transposed values + std::cout << "Thread " << thread_idx << " transposed: "; + for (int i = 0; i < 4; i++) { + uint16_t actual = h_output[thread_idx * 8 + 4 + i]; + std::cout << actual << " "; + + // Expected after transpose: Thread N gets column N + // Thread 0 should have [0*100+0, 1*100+0, 2*100+0, 3*100+0] + // Thread 1 should have [0*100+1, 1*100+1, 2*100+1, 3*100+1] + uint16_t expected = i * 100 + lane; + + if (actual != expected) { + success = false; + std::cout << "(Expected: " << expected << ") "; + } + } + std::cout << std::endl; + } + } if (success) { std::cout << "\nTranspose test PASSED! All values correctly transposed." 
From 2ffa2b1546f2bca317490ae524fc2da314d9131f Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Thu, 7 Aug 2025 05:33:13 -0400 Subject: [PATCH 019/109] Optimizations for transposed loads --- .../hip/test_transpose_4x4_half_registers.cpp | 130 ++++++++++++------ 1 file changed, 88 insertions(+), 42 deletions(-) diff --git a/libflashinfer/tests/hip/test_transpose_4x4_half_registers.cpp b/libflashinfer/tests/hip/test_transpose_4x4_half_registers.cpp index dedf960e1b..762954377e 100644 --- a/libflashinfer/tests/hip/test_transpose_4x4_half_registers.cpp +++ b/libflashinfer/tests/hip/test_transpose_4x4_half_registers.cpp @@ -51,60 +51,106 @@ __device__ __forceinline__ void transpose_4x4_half_registers_opt(uint32_t *R, uint32_t lane_id = threadIdx.x % 64; uint32_t lane_in_group = lane_id % 4; - // === ROUND 1: Exchange with neighbor (XOR with 1) === - // Remove conditionals by using masks and bit operations - - // Exchange values with neighbor - uint32_t r0_exchanged = __shfl_xor(R[0], 0x1); - uint32_t r1_exchanged = __shfl_xor(R[1], 0x1); - - // Register selection mask: 0xFFFFFFFF for R[0] if lane<2, 0 otherwise - uint32_t r0_mask = ~((lane_in_group >> 1) * 0xFFFFFFFF); - // Register selection mask: 0xFFFFFFFF for R[1] if lane>=2, 0 otherwise - uint32_t r1_mask = (lane_in_group >> 1) * 0xFFFFFFFF; + if (lane_id == 0) { + debug_print_registers("Initial", lane_id, lane_in_group, R, 2, 0); + } - // Bit selection based on odd/even thread + // === ROUND 1: Exchange with neighbor (XOR with 1) === + // T0↔T1, T2↔T3 partial exchange + uint32_t reg_idx = (lane_in_group >> 1) & 0x1; + uint32_t exchanged_val = __shfl_xor(R[reg_idx], 0x1); uint32_t shift = (lane_in_group & 1) * 16; - uint32_t keep_mask = - 0xFFFF0000 >> shift; // 0xFFFF0000 for even, 0x0000FFFF for odd - - // Apply the masks to merge old and exchanged values - R[0] = (R[0] & keep_mask & r0_mask) | - ((lane_in_group & 1) ? 
((r0_exchanged & 0xFFFF) << 16) & r0_mask - : ((r0_exchanged >> 16) & 0xFFFF) & r0_mask); + uint32_t keep_mask = 0xFFFF0000 >> shift; + int right_shift_amount = 16 * (1 - (lane_in_group & 1)); + int left_shift_amount = 16 * (lane_in_group & 1); + R[reg_idx] = (R[reg_idx] & keep_mask) | + ((exchanged_val >> right_shift_amount) << left_shift_amount); + + // if (lane_in_group & 1) { // Odd threads (1, 3) + // R[reg_idx] = (R[reg_idx] & keep_mask) | (exchanged_val << 16); + // } + // else { // Even threads (0, 2) + // R[reg_idx] = (R[reg_idx] & keep_mask) | (exchanged_val >> 16); + // } + + // // Update based on thread position + // if (lane_in_group < 2) { + // uint32_t r0_exchanged = __shfl_xor(R[0], 0x1); + // // Top half (T0, T1) update R[0] + // if (lane_in_group & 1) { // T1 + // R[0] = (R[0] & 0x0000FFFF) | (r0_exchanged << 16); + // } + // else { // T0 + // R[0] = (R[0] & 0xFFFF0000) | (r0_exchanged >> 16); + // } + // } + // else { + // uint32_t r1_exchanged = __shfl_xor(R[1], 0x1); + // // Bottom half (T2, T3) update R[1] + // if (lane_in_group & 1) { // T1 + // R[1] = (R[1] & 0x0000FFFF) | (r1_exchanged << 16); + // } + // else { // T0 + // R[1] = (R[1] & 0xFFFF0000) | (r1_exchanged >> 16); + // } + // } - R[1] = (R[1] & keep_mask & r1_mask) | - ((lane_in_group & 1) ? 
((r1_exchanged & 0xFFFF) << 16) & r1_mask - : ((r1_exchanged >> 16) & 0xFFFF) & r1_mask); + // Debug after first recombination + if (lane_id == 03) { + debug_print_registers("After Round 1 shuffles", lane_id, lane_in_group, + R, 2, 0); + } // === ROUND 2: Exchange with one hop (XOR with 2) === + // T0↔T2, T1↔T3 exchange R[0] and R[1] uint32_t temp0_exchanged = __shfl_xor(R[0], 0x2); uint32_t temp1_exchanged = __shfl_xor(R[1], 0x2); - // Use predicated assignments instead of if-else - uint32_t r0_old = R[0]; - uint32_t r1_old = R[1]; + // Swap entire registers based on thread position + if (lane_in_group < 2) { + R[1] = temp0_exchanged; + } + else { + // Bottom threads (T2, T3) get R[1] from partner, keep own R[0] + R[0] = temp1_exchanged; + } - // Branchless swap using masks - R[0] = (lane_in_group < 2) ? r0_old : temp1_exchanged; - R[1] = (lane_in_group < 2) ? temp0_exchanged : r1_old; + if (lane_id == 0) { + debug_print_registers("After Round 2 shuffles", lane_id, lane_in_group, + R, 2, 0); + } // === ROUND 3: Exchange with neighbor again (XOR with 1) === - r0_exchanged = __shfl_xor(R[0], 0x1); - r1_exchanged = __shfl_xor(R[1], 0x1); + // T0↔T1, T2↔T3 exchange remaining parts - // Swap register selectors for round 3 (inverse of round 1) - r0_mask = (lane_in_group >> 1) * 0xFFFFFFFF; - r1_mask = ~((lane_in_group >> 1) * 0xFFFFFFFF); + if (lane_in_group < 2) { + uint32_t r1_exchanged = __shfl_xor(R[1], 0x1); + // Top half (T0, T1) update R[0] + if (lane_in_group & 1) { // T1 + R[1] = (R[1] & 0x0000FFFF) | (r1_exchanged << 16); + } + else { // T0 + R[1] = (R[1] & 0xFFFF0000) | (r1_exchanged >> 16); + } + } + else { + uint32_t r1_exchanged = __shfl_xor(R[0], 0x1); + // Bottom half (T2, T3) update R[1] + if (lane_in_group & 1) { // T1 + R[0] = (R[0] & 0x0000FFFF) | (r1_exchanged << 16); + } + else { // T0 + R[0] = (R[0] & 0xFFFF0000) | (r1_exchanged >> 16); + } + } - // Apply the masks to merge old and exchanged values (same logic as round 1) - out[0] = (R[0] & 
keep_mask & r0_mask) | - ((lane_in_group & 1) ? ((r0_exchanged & 0xFFFF) << 16) & r0_mask - : ((r0_exchanged >> 16) & 0xFFFF) & r0_mask); + if (lane_id == 3) { + debug_print_registers("After Round 2 shuffles", lane_id, lane_in_group, + R, 2, 0); + } - out[1] = (R[1] & keep_mask & r1_mask) | - ((lane_in_group & 1) ? ((r1_exchanged & 0xFFFF) << 16) & r1_mask - : ((r1_exchanged >> 16) & 0xFFFF) & r1_mask); + out[0] = R[0]; + out[1] = R[1]; } __device__ __forceinline__ void transpose_4x4_half_registers(uint32_t *R, @@ -238,8 +284,8 @@ __global__ void test_transpose_kernel(uint16_t *output) // Call the transpose function uint32_t out[2]; - transpose_4x4_half_registers(R, out); - // transpose_4x4_half_registers_opt(R, out); + // transpose_4x4_half_registers(R, out); + transpose_4x4_half_registers_opt(R, out); // Unpack the transposed results uint16_t transposed[4]; From b8e617ae60aa854d89abbff9db2e2f760e56b0f6 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Thu, 7 Aug 2025 10:41:01 -0400 Subject: [PATCH 020/109] transpose 4x4 test case. 
--- .../hip/test_transpose_4x4_half_registers.cpp | 111 +++--------------- 1 file changed, 16 insertions(+), 95 deletions(-) diff --git a/libflashinfer/tests/hip/test_transpose_4x4_half_registers.cpp b/libflashinfer/tests/hip/test_transpose_4x4_half_registers.cpp index 762954377e..9e40648324 100644 --- a/libflashinfer/tests/hip/test_transpose_4x4_half_registers.cpp +++ b/libflashinfer/tests/hip/test_transpose_4x4_half_registers.cpp @@ -4,8 +4,7 @@ #include // Define WARP_FULL_MASK for HIP -constexpr uint64_t WARP_FULL_MASK = - 0xffffffffffffffffULL; // 64-bit mask for HIP +constexpr uint64_t WARP_FULL_MASK = 0xffffffffffffffffULL; __device__ __forceinline__ void debug_print_registers(const char *stage, uint32_t lane_id, @@ -44,17 +43,12 @@ __device__ __forceinline__ void debug_print_registers(const char *stage, printf("]\n"); } -__device__ __forceinline__ void transpose_4x4_half_registers_opt(uint32_t *R, - uint32_t *out) +__device__ __forceinline__ void transpose_4x4_half_registers_opt(uint32_t *R) { // Calculate lane within 4-thread group uint32_t lane_id = threadIdx.x % 64; uint32_t lane_in_group = lane_id % 4; - if (lane_id == 0) { - debug_print_registers("Initial", lane_id, lane_in_group, R, 2, 0); - } - // === ROUND 1: Exchange with neighbor (XOR with 1) === // T0↔T1, T2↔T3 partial exchange uint32_t reg_idx = (lane_in_group >> 1) & 0x1; @@ -66,95 +60,27 @@ __device__ __forceinline__ void transpose_4x4_half_registers_opt(uint32_t *R, R[reg_idx] = (R[reg_idx] & keep_mask) | ((exchanged_val >> right_shift_amount) << left_shift_amount); - // if (lane_in_group & 1) { // Odd threads (1, 3) - // R[reg_idx] = (R[reg_idx] & keep_mask) | (exchanged_val << 16); - // } - // else { // Even threads (0, 2) - // R[reg_idx] = (R[reg_idx] & keep_mask) | (exchanged_val >> 16); - // } - - // // Update based on thread position - // if (lane_in_group < 2) { - // uint32_t r0_exchanged = __shfl_xor(R[0], 0x1); - // // Top half (T0, T1) update R[0] - // if (lane_in_group & 1) { // T1 
- // R[0] = (R[0] & 0x0000FFFF) | (r0_exchanged << 16); - // } - // else { // T0 - // R[0] = (R[0] & 0xFFFF0000) | (r0_exchanged >> 16); - // } - // } - // else { - // uint32_t r1_exchanged = __shfl_xor(R[1], 0x1); - // // Bottom half (T2, T3) update R[1] - // if (lane_in_group & 1) { // T1 - // R[1] = (R[1] & 0x0000FFFF) | (r1_exchanged << 16); - // } - // else { // T0 - // R[1] = (R[1] & 0xFFFF0000) | (r1_exchanged >> 16); - // } - // } - - // Debug after first recombination - if (lane_id == 03) { - debug_print_registers("After Round 1 shuffles", lane_id, lane_in_group, - R, 2, 0); - } - // === ROUND 2: Exchange with one hop (XOR with 2) === // T0↔T2, T1↔T3 exchange R[0] and R[1] - uint32_t temp0_exchanged = __shfl_xor(R[0], 0x2); - uint32_t temp1_exchanged = __shfl_xor(R[1], 0x2); - // Swap entire registers based on thread position - if (lane_in_group < 2) { - R[1] = temp0_exchanged; - } - else { - // Bottom threads (T2, T3) get R[1] from partner, keep own R[0] - R[0] = temp1_exchanged; - } + uint32_t is_top = 1 - reg_idx; + uint32_t temp0 = __shfl_xor(R[0], 0x2); + uint32_t temp1 = __shfl_xor(R[1], 0x2); - if (lane_id == 0) { - debug_print_registers("After Round 2 shuffles", lane_id, lane_in_group, - R, 2, 0); - } + // Compute both possibilities and select + R[0] = R[0] * is_top + temp1 * reg_idx; + R[1] = temp0 * is_top + R[1] * reg_idx; // === ROUND 3: Exchange with neighbor again (XOR with 1) === // T0↔T1, T2↔T3 exchange remaining parts - if (lane_in_group < 2) { - uint32_t r1_exchanged = __shfl_xor(R[1], 0x1); - // Top half (T0, T1) update R[0] - if (lane_in_group & 1) { // T1 - R[1] = (R[1] & 0x0000FFFF) | (r1_exchanged << 16); - } - else { // T0 - R[1] = (R[1] & 0xFFFF0000) | (r1_exchanged >> 16); - } - } - else { - uint32_t r1_exchanged = __shfl_xor(R[0], 0x1); - // Bottom half (T2, T3) update R[1] - if (lane_in_group & 1) { // T1 - R[0] = (R[0] & 0x0000FFFF) | (r1_exchanged << 16); - } - else { // T0 - R[0] = (R[0] & 0xFFFF0000) | (r1_exchanged >> 16); 
- } - } - - if (lane_id == 3) { - debug_print_registers("After Round 2 shuffles", lane_id, lane_in_group, - R, 2, 0); - } - - out[0] = R[0]; - out[1] = R[1]; + reg_idx = 1 - reg_idx; + exchanged_val = __shfl_xor(R[reg_idx], 0x1); + R[reg_idx] = (R[reg_idx] & keep_mask) | + ((exchanged_val >> right_shift_amount) << left_shift_amount); } -__device__ __forceinline__ void transpose_4x4_half_registers(uint32_t *R, - uint32_t *out) +__device__ __forceinline__ void transpose_4x4_half_registers_naive(uint32_t *R) { // Calculate lane within 4-thread group uint32_t lane_id = threadIdx.x % 64; @@ -242,9 +168,6 @@ __device__ __forceinline__ void transpose_4x4_half_registers(uint32_t *R, debug_print_registers("After Round 2 shuffles", lane_id, lane_in_group, R, 2, 0); } - - out[0] = R[0]; - out[1] = R[1]; } // Helper function to convert two uint16_t values to a single uint32_t @@ -283,14 +206,12 @@ __global__ void test_transpose_kernel(uint16_t *output) R[1] = pack_half2(row_elements[2], row_elements[3]); // Call the transpose function - uint32_t out[2]; - // transpose_4x4_half_registers(R, out); - transpose_4x4_half_registers_opt(R, out); + transpose_4x4_half_registers_opt(R); // Unpack the transposed results uint16_t transposed[4]; - unpack_half2(out[0], transposed[0], transposed[1]); - unpack_half2(out[1], transposed[2], transposed[3]); + unpack_half2(R[0], transposed[0], transposed[1]); + unpack_half2(R[1], transposed[2], transposed[3]); // Write output - store both original and transposed values for verification for (int i = 0; i < 4; i++) { From 2c003914afc4e6922dcd98bcfd21029d294bb022 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Fri, 8 Aug 2025 12:25:50 -0400 Subject: [PATCH 021/109] various compilation error fixes --- .../attention/generic/permuted_smem.cuh | 14 ++++ .../flashinfer/attention/generic/prefill.cuh | 73 ++++++++++--------- .../include/gpu_iface/memory_ops.hpp | 24 +++--- libflashinfer/include/gpu_iface/mma_ops.hpp | 16 ++-- 
.../include/gpu_iface/vec_dtypes.hpp | 8 +- 5 files changed, 78 insertions(+), 57 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh b/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh index e9ab8b6292..6f71514872 100644 --- a/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh @@ -7,6 +7,7 @@ #define FLASHINFER_PERMUTED_SMEM_CUH_ #include "gpu_iface/memory_ops.hpp" +#include "gpu_iface/mma_ops.hpp" #include "gpu_iface/platform.hpp" #if 0 @@ -165,6 +166,19 @@ template struct smem_t #endif } + template + __device__ __forceinline__ void + load_fragment_4x4_transposed(uint32_t offset, T *frag) + { +#if defined(PLATFORM_HIP_DEVICE) + auto smem_t_ptr = reinterpret_cast(base + offset); + flashinfer::gpu_iface::mma::load_fragment_transpose_4x4_half_registers( + smem_t_ptr, frag); +#else + ldmatrix_m8n8x4(offset, frag); +#endif + } + template __device__ __forceinline__ void store_fragment(uint32_t offset, const T *frag) diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index 089f9531f3..a7dca2b843 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -137,24 +137,25 @@ struct KernelTraits using AttentionVariant = AttentionVariant_; #if defined(PLATFORM_HIP_DEVICE) - static_assert( - sizeof(DTypeKV_) != 1, - "8-bit types not supported for CDNA3") static constexpr uint32_t - NUM_THREADS = NUM_WARPS_Q * NUM_WARPS_KV * 64; - constexpr uint32_t WARP_THREAD_COLS = 16; - constexpr uint32_t WARP_THREAD_ROWS = 4; - constexpr uint32_t HALF_ELEMS_PER_THREAD = 4; - constexpr uint32_t INT32_ELEMS_PER_THREAD = 2; - constexpr uint32_t VECTOR_BIT_WIDTH = HALF_ELEMS_PER_THREAD * 16; + static_assert(sizeof(DTypeKV_) != 1, "8-bit types not supported for 
CDNA3"); + + using SmemBasePtrTy = uint2; + static constexpr uint32_t NUM_THREADS = NUM_WARPS_Q * NUM_WARPS_KV * 64; + static constexpr uint32_t WARP_THREAD_COLS = 16; + static constexpr uint32_t WARP_THREAD_ROWS = 4; + static constexpr uint32_t HALF_ELEMS_PER_THREAD = 4; + static constexpr uint32_t INT32_ELEMS_PER_THREAD = 2; + static constexpr uint32_t VECTOR_BIT_WIDTH = HALF_ELEMS_PER_THREAD * 16; // FIXME: Update with a proper swizzle pattern. Linear is used primarily // for intial testing. static constexpr SwizzleMode SWIZZLE_MODE_Q = SwizzleMode::kLinear; static constexpr SwizzleMode SWIZZLE_MODE_KV = SwizzleMode::kLinear; - static constexpr SmemBasePtrTy = uint2; + // Presently we use 16x4 thread layout for all cases. static constexpr uint32_t KV_THR_LAYOUT_ROW = WARP_THREAD_ROWS; static constexpr uint32_t KV_THR_LAYOUT_COL = WARP_THREAD_COLS; #else + using SmemBasePtrTy = uint4; static constexpr uint32_t NUM_THREADS = NUM_WARPS_Q * NUM_WARPS_KV * 32; constexpr uint32_t WARP_THREAD_COLS = 8; constexpr uint32_t WARP_THREAD_ROWS = 4; @@ -166,7 +167,6 @@ struct KernelTraits static constexpr SwizzleMode SWIZZLE_MODE_KV = (sizeof(DTypeKV_) == 1 && HEAD_DIM_VO == 64) ? SwizzleMode::k64B : SwizzleMode::k128B; - static constexpr SmemBasePtrTy = uint4; static constexpr uint32_t KV_THR_LAYOUT_ROW = SWIZZLE_MODE_KV == SwizzleMode::k128B ? 
WARP_THREAD_ROWS : WARP_THREAD_COLS; @@ -357,12 +357,17 @@ q_frag_apply_llama_rope_with_pos(T *x_first_half, } template -__device__ __forceinline__ void produce_kv_helper_(uint32_t warp_idx, - uint32_t lane_idx) +__device__ __forceinline__ void +produce_kv_helper_(uint32_t warp_idx, + uint32_t lane_idx, + smem_t smem, + uint32_t *smem_offset, + typename KTraits::DTypeKV **gptr) { using DTypeKV = typename KTraits::DTypeKV; constexpr uint32_t WARP_THREAD_COLS = KTraits::WARP_THREAD_COLS; constexpr uint32_t WARP_THREAD_ROWS = KTraits::WARP_THREAD_ROWS; + constexpr uint32_t NUM_WARPS = KTraits::NUM_WARPS; constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q; constexpr uint32_t NUM_MMA_D = @@ -432,11 +437,8 @@ produce_kv(smem_t smem, { // NOTE: for fp8, this function doesn't work for head_dim = 64 at the moment using DTypeKV = typename KTraits::DTypeKV; - constexpr uint32_t CTA_TILE_KV = KTraits::CTA_TILE_KV; constexpr uint32_t NUM_WARPS = KTraits::NUM_WARPS; constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q; - constexpr uint32_t NUM_MMA_D = - produce_v ? KTraits::NUM_MMA_D_VO : KTraits::NUM_MMA_D_QK; constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; constexpr uint32_t UPCAST_STRIDE = produce_v ? 
KTraits::UPCAST_STRIDE_V : KTraits::UPCAST_STRIDE_K; @@ -444,13 +446,13 @@ produce_kv(smem_t smem, lane_idx = tid.x; if constexpr (KTraits::SWIZZLE_MODE_KV == SwizzleMode::k128B) { - produce_kv_helper_(uint32_t warp_idx, - uint32_t lane_idx) + produce_kv_helper_( + warp_idx, lane_idx, smem, smem_offset, gptr); } #if defined(PLATFORM_HIP_DEVICE) else if constexpr (KTraits::SWIZZLE_MODE_KV == SwizzleMode::kLinear) { - produce_kv_helper_(uint32_t warp_idx, - uint32_t lane_idx) + produce_kv_helper_( + warp_idx, lane_idx, smem, smem_offset, gptr); } #endif else { @@ -637,8 +639,8 @@ load_q_global_smem(uint32_t packed_offset, const uint32_t lane_idx = tid.x, warp_idx_x = get_warp_idx_q(tid.y); - uint32_t row = tid / WARP_THREAD_COLS; - uint32_t col = tid % WARP_THREAD_COLS; + uint32_t row = lane_idx / WARP_THREAD_COLS; + uint32_t col = lane_idx % WARP_THREAD_COLS; if (get_warp_idx_kv(tid.z) == 0) { uint32_t q_smem_offset_w = @@ -904,7 +906,7 @@ compute_qk(smem_t *q_smem, { constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q; constexpr uint32_t UPCAST_STRIDE_K = KTraits::UPCAST_STRIDE_K; - constexpr uint32_t Q_SMEM_COLUMN_ADVANCE = + constexpr uint32_t QK_SMEM_COLUMN_ADVANCE = 16 / KTraits::HALF_ELEMS_PER_THREAD; uint32_t a_frag[KTraits::NUM_MMA_Q][KTraits::INT32_ELEMS_PER_THREAD], @@ -914,14 +916,14 @@ compute_qk(smem_t *q_smem, for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_QK; ++mma_d) { #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { - q_smem->ldmatrix_m8n8x4(*q_smem_offset_r, a_frag[mma_q]); + q_smem->load_fragment(*q_smem_offset_r, a_frag[mma_q]); *q_smem_offset_r = q_smem->template advance_offset_by_row<16, UPCAST_STRIDE_Q>( *q_smem_offset_r); } *q_smem_offset_r = - q_smem->template advance_offset_by_column( + q_smem->template advance_offset_by_column( *q_smem_offset_r, mma_d) - KTraits::NUM_MMA_Q * 16 * UPCAST_STRIDE_Q; @@ -948,7 +950,7 @@ compute_qk(smem_t *q_smem, (typename KTraits::DTypeKV *)b_frag_f8); } else { - 
k_smem->ldmatrix_m8n8x4(*k_smem_offset_r, b_frag); + k_smem->load_fragment_4x4_transposed(*k_smem_offset_r, b_frag); } *k_smem_offset_r = k_smem->template advance_offset_by_row<16, UPCAST_STRIDE_K>( @@ -992,18 +994,20 @@ compute_qk(smem_t *q_smem, } if constexpr (sizeof(typename KTraits::DTypeKV) == 1) { if (mma_d % 2 == 1) { - *k_smem_offset_r = k_smem->template advance_offset_by_column<2>( - *k_smem_offset_r, mma_d / 2); + *k_smem_offset_r = k_smem->template advance_offset_by_column< + QK_SMEM_COLUMN_ADVANCE>(*k_smem_offset_r, mma_d / 2); } *k_smem_offset_r -= KTraits::NUM_MMA_KV * 16 * UPCAST_STRIDE_K; } else { - *k_smem_offset_r = k_smem->template advance_offset_by_column<2>( - *k_smem_offset_r, mma_d) - - KTraits::NUM_MMA_KV * 16 * UPCAST_STRIDE_K; + *k_smem_offset_r = + k_smem + ->template advance_offset_by_column( + *k_smem_offset_r, mma_d) - + KTraits::NUM_MMA_KV * 16 * UPCAST_STRIDE_K; } } - *q_smem_offset_r -= KTraits::NUM_MMA_D_QK * Q_SMEM_COLUMN_ADVANCE; + *q_smem_offset_r -= KTraits::NUM_MMA_D_QK * QK_SMEM_COLUMN_ADVANCE; *k_smem_offset_r -= KTraits::NUM_MMA_D_QK * sizeof(typename KTraits::DTypeKV); } @@ -1937,8 +1941,9 @@ SinglePrefillWithKVCacheDevice(const Params params, uint32_t k_smem_offset_r = k_smem.template get_permuted_offset( get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + - 8 * (lane_idx / 16) + lane_idx % 8, - (lane_idx % 16) / 8), + HALF_ELEMS_PER_THREAD * (lane_idx / 16) + + lane_idx % HALF_ELEMS_PER_THREAD, + (lane_idx % 16) / HALF_ELEMS_PER_THREAD), v_smem_offset_r = v_smem.template get_permuted_offset( get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + diff --git a/libflashinfer/include/gpu_iface/memory_ops.hpp b/libflashinfer/include/gpu_iface/memory_ops.hpp index c811d0b501..94048502cd 100644 --- a/libflashinfer/include/gpu_iface/memory_ops.hpp +++ b/libflashinfer/include/gpu_iface/memory_ops.hpp @@ -33,16 +33,16 @@ enum class PrefetchMode // Include platform-specific implementations #if defined(PLATFORM_CUDA_DEVICE) #include 
"backend/cuda/memory_ops.cuh" -namespace detail = flashinfer::gpu_iface::memory::detail::cuda; +namespace mem_detail = flashinfer::gpu_iface::memory::detail::cuda; #elif defined(PLATFORM_HIP_DEVICE) #include "backend/hip/memory_ops_hip.h" -namespace detail = flashinfer::gpu_iface::memory::detail::hip; +namespace mem_detail = flashinfer::gpu_iface::memory::detail::hip; #endif /** * @brief Commits pending asynchronous memory operations to a group */ -__device__ __forceinline__ void commit_group() { detail::commit_group(); } +__device__ __forceinline__ void commit_group() { mem_detail::commit_group(); } /** * @brief Waits until N most recent groups of async operations are complete @@ -51,7 +51,7 @@ __device__ __forceinline__ void commit_group() { detail::commit_group(); } */ template __device__ __forceinline__ void wait_group() { - detail::wait_group(); + mem_detail::wait_group(); } /** @@ -65,14 +65,14 @@ template __device__ __forceinline__ void wait_group() template __device__ __forceinline__ void load_128b(T *smem_ptr, const T *gmem_ptr) { - detail::load_128b(smem_ptr, gmem_ptr); + mem_detail::load_128b(smem_ptr, gmem_ptr); } template __device__ __forceinline__ void load_64b(T *smem_ptr, const T *gmem_ptr) { #if defined(PLATFORM_HIP_DEVICE) - detail::load_64b(smem_ptr, gmem_ptr); + mem_detail::load_64b(smem_ptr, gmem_ptr); #else #error "load_64b not implemented for this platform" #endif @@ -92,7 +92,8 @@ template __device__ __forceinline__ void pred_load_128b(T *smem_ptr, const T *gmem_ptr, bool predicate) { - detail::pred_load_128b(smem_ptr, gmem_ptr, predicate); + mem_detail::pred_load_128b(smem_ptr, gmem_ptr, + predicate); } template @@ -100,7 +101,8 @@ __device__ __forceinline__ void pred_load_64b(T *smem_ptr, const T *gmem_ptr, bool predicate) { #if defined(PLATFORM_HIP_DEVICE) - detail::pred_load_64b(smem_ptr, gmem_ptr, predicate); + mem_detail::pred_load_64b(smem_ptr, gmem_ptr, + predicate); #else #error "pred_load_64b not implemented for this platform" 
#endif @@ -118,7 +120,7 @@ pred_load_64b(T *smem_ptr, const T *gmem_ptr, bool predicate) template __device__ __forceinline__ void load(T *smem_ptr, const T *gmem_ptr) { - detail::load(smem_ptr, gmem_ptr); + mem_detail::load(smem_ptr, gmem_ptr); } /** @@ -139,8 +141,8 @@ template (smem_ptr, gmem_ptr, - predicate); + mem_detail::pred_load(smem_ptr, gmem_ptr, + predicate); } } // namespace memory diff --git a/libflashinfer/include/gpu_iface/mma_ops.hpp b/libflashinfer/include/gpu_iface/mma_ops.hpp index 890d838bb3..113c0aa7f6 100644 --- a/libflashinfer/include/gpu_iface/mma_ops.hpp +++ b/libflashinfer/include/gpu_iface/mma_ops.hpp @@ -9,10 +9,10 @@ // Include platform-specific implementations #if defined(PLATFORM_CUDA_DEVICE) #include "backend/cuda/mma.cuh" -namespace detail = flashinfer::gpu_iface::mma_impl::cuda; +namespace mma_detail = flashinfer::gpu_iface::mma_impl::cuda; #elif defined(PLATFORM_HIP_DEVICE) #include "backend/hip/mma_hip.h" -namespace detail = flashinfer::gpu_iface::mma_impl::hip; +namespace mma_detail = flashinfer::gpu_iface::mma_impl::hip; #endif namespace flashinfer @@ -34,24 +34,24 @@ namespace mma template __device__ __forceinline__ void load_fragment(uint32_t *R, const T *smem_ptr) { - detail::load_fragment(R, smem_ptr); + mma_detail::load_fragment(R, smem_ptr); } template __device__ __forceinline__ void load_fragment_transpose(uint32_t *R, const T *smem_ptr, uint32_t stride) { - detail::load_fragment_transpose(R, smem_ptr, stride); + mma_detail::load_fragment_transpose(R, smem_ptr, stride); } #if defined(PLATFORM_HIP_DEVICE) && defined(__gfx942__) template __device__ __forceinline__ void -load_fragment_transpose_4x4_half_registers(uint32_t *R, const T *smem_ptr) +load_fragment_transpose_4x4_half_registers(const T *smem_ptr, uint32_t *R) { - static_assert(std::is_same::value, + static_assert(std::is_same::value, "Only __half is supported for the 4x4 register transpose"); - detail::load_fragment_4x4_half_registers(R, smem_ptr); + 
mma_detail::load_fragment_4x4_half_registers(R, smem_ptr); } #endif @@ -69,7 +69,7 @@ __device__ __forceinline__ void amdgcn_mfma_fp32_16x16x16fp16(float *C, uint32_t *A, uint32_t *B) { #if defined(PLATFORM_HIP_DEVICE) - detail::amdgcn_mfma_fp32_16x16x16fp16(C, A, B); + mma_detail::amdgcn_mfma_fp32_16x16x16fp16(C, A, B); #else FLASHINFER_RUNTIME_ASSERT( "MMA f16f16f32 not supported on this architecture"); diff --git a/libflashinfer/include/gpu_iface/vec_dtypes.hpp b/libflashinfer/include/gpu_iface/vec_dtypes.hpp index 3a92de2c05..b769286a94 100644 --- a/libflashinfer/include/gpu_iface/vec_dtypes.hpp +++ b/libflashinfer/include/gpu_iface/vec_dtypes.hpp @@ -17,17 +17,17 @@ namespace vec_dtypes // Include the appropriate backend implementation #if defined(PLATFORM_CUDA_DEVICE) #include "backend/cuda/vec_dtypes.cuh" -namespace detail = flashinfer::gpu_iface::vec_dtypes::detail::cuda; +namespace vec_t_detail = flashinfer::gpu_iface::vec_dtypes::detail::cuda; #elif defined(PLATFORM_HIP_DEVICE) #include "backend/hip/vec_dtypes_hip.h" #define HIP_ENABLE_WARP_SYNC_BUILTINS 1 -namespace detail_t = flashinfer::gpu_iface::vec_dtypes::detail::hip; +namespace vec_t_detail = flashinfer::gpu_iface::vec_dtypes::detail::hip; #endif // Re-export types and functions from the appropriate backend // This allows code to use flashinfer::gpu_iface::vec_dtypes::vec_t -using detail_t::vec_cast; -using detail_t::vec_t; +using vec_t_detail::vec_cast; +using vec_t_detail::vec_t; } // namespace vec_dtypes } // namespace gpu_iface From c418cdc17f38e60b22cebf4b83a35bcefd56dd04 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Fri, 8 Aug 2025 13:25:25 -0400 Subject: [PATCH 022/109] Compilation fixes --- .../flashinfer/attention/generic/prefill.cuh | 208 +++++++++--------- .../include/gpu_iface/backend/hip/mma_hip.h | 2 +- 2 files changed, 111 insertions(+), 99 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh 
b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index a7dca2b843..3cd943a778 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -356,13 +356,16 @@ q_frag_apply_llama_rope_with_pos(T *x_first_half, } } -template -__device__ __forceinline__ void -produce_kv_helper_(uint32_t warp_idx, - uint32_t lane_idx, - smem_t smem, - uint32_t *smem_offset, - typename KTraits::DTypeKV **gptr) +template +__device__ __forceinline__ void produce_kv_helper_( + uint32_t warp_idx, + uint32_t lane_idx, + smem_t smem, + uint32_t *smem_offset, + typename KTraits::DTypeKV **gptr, + const uint32_t stride_n, + const uint32_t kv_idx_base, + const uint32_t kv_len) { using DTypeKV = typename KTraits::DTypeKV; constexpr uint32_t WARP_THREAD_COLS = KTraits::WARP_THREAD_COLS; @@ -426,14 +429,14 @@ produce_kv_helper_(uint32_t warp_idx, * \param kv_len The length of kv tensor. */ template -__device__ __forceinline__ void -produce_kv(smem_t smem, - uint32_t *smem_offset, - typename KTraits::DTypeKV **gptr, - const uint32_t stride_n, - const uint32_t kv_idx_base, - const uint32_t kv_len, - const dim3 tid = threadIdx) +__device__ __forceinline__ void produce_kv( + smem_t smem, + uint32_t *smem_offset, + typename KTraits::DTypeKV **gptr, + const uint32_t stride_n, + const uint32_t kv_idx_base, + const uint32_t kv_len, + const dim3 tid = threadIdx) { // NOTE: for fp8, this function doesn't work for head_dim = 64 at the moment using DTypeKV = typename KTraits::DTypeKV; @@ -446,13 +449,15 @@ produce_kv(smem_t smem, lane_idx = tid.x; if constexpr (KTraits::SWIZZLE_MODE_KV == SwizzleMode::k128B) { - produce_kv_helper_( - warp_idx, lane_idx, smem, smem_offset, gptr); + produce_kv_helper_( + warp_idx, lane_idx, smem, smem_offset, gptr, stride_n, kv_idx_base, + kv_len); } #if defined(PLATFORM_HIP_DEVICE) else if constexpr (KTraits::SWIZZLE_MODE_KV == SwizzleMode::kLinear) { - produce_kv_helper_( - 
warp_idx, lane_idx, smem, smem_offset, gptr); + produce_kv_helper_( + warp_idx, lane_idx, smem, smem_offset, gptr, stride_n, kv_idx_base, + kv_len); } #endif else { @@ -475,15 +480,15 @@ produce_kv(smem_t smem, } template -__device__ __forceinline__ void -page_produce_kv(smem_t smem, - uint32_t *smem_offset, - const paged_kv_t &paged_kv, - const uint32_t kv_idx_base, - const size_t *thr_local_kv_offset, - const uint32_t kv_len, - const dim3 tid = threadIdx) +__device__ __forceinline__ void page_produce_kv( + smem_t smem, + uint32_t *smem_offset, + const paged_kv_t + &paged_kv, + const uint32_t kv_idx_base, + const size_t *thr_local_kv_offset, + const uint32_t kv_len, + const dim3 tid = threadIdx) { // NOTE: for fp8, this function doesn't work for head_dim = 64 at the moment using DType = typename KTraits::DTypeKV; @@ -615,20 +620,22 @@ __device__ __forceinline__ void init_states( } template -__device__ __forceinline__ void -load_q_global_smem(uint32_t packed_offset, - const uint32_t qo_upper_bound, - typename KTraits::DTypeQ *q_ptr_base, - const uint32_t q_stride_n, - const uint32_t q_stride_h, - const uint_fastdiv group_size, - smem_t *q_smem, - const dim3 tid = threadIdx) +__device__ __forceinline__ void load_q_global_smem( + uint32_t packed_offset, + const uint32_t qo_upper_bound, + typename KTraits::DTypeQ *q_ptr_base, + const uint32_t q_stride_n, + const uint32_t q_stride_h, + const uint_fastdiv group_size, + smem_t *q_smem, + const dim3 tid = threadIdx) { using DTypeQ = typename KTraits::DTypeQ; constexpr uint32_t WARP_THREAD_COLS = KTraits::WARP_THREAD_COLS; constexpr uint32_t WARP_THREAD_ROWS = KTraits::WARP_THREAD_ROWS; constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; + constexpr uint32_t NUM_MMA_D_QK = KTraits::NUM_MMA_D_QK; + constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q; #if defined(PLATFORM_HIP_DEVICE) constexpr uint32_t COLUMN_RESET_OFFSET = @@ -690,7 +697,7 @@ __device__ __forceinline__ void 
q_smem_inplace_apply_rotary( const uint32_t qo_len, const uint32_t kv_len, const uint_fastdiv group_size, - smem_t *q_smem, + smem_t *q_smem, uint32_t *q_smem_offset_r, float (*rope_freq)[4], const dim3 tid = threadIdx) @@ -749,7 +756,7 @@ template __device__ __forceinline__ void q_smem_inplace_apply_rotary_with_pos( const uint32_t q_packed_idx_base, const typename KTraits::IdType *q_rope_offset, - smem_t *q_smem, + smem_t *q_smem, const uint_fastdiv group_size, uint32_t *q_smem_offset_r, float (*rope_freq)[4], @@ -797,12 +804,12 @@ __device__ __forceinline__ void q_smem_inplace_apply_rotary_with_pos( } template -__device__ __forceinline__ void -k_smem_inplace_apply_rotary(const uint32_t kv_idx_base, - smem_t *k_smem, - uint32_t *k_smem_offset_r, - float (*rope_freq)[4], - const dim3 tid = threadIdx) +__device__ __forceinline__ void k_smem_inplace_apply_rotary( + const uint32_t kv_idx_base, + smem_t *k_smem, + uint32_t *k_smem_offset_r, + float (*rope_freq)[4], + const dim3 tid = threadIdx) { using DTypeKV = typename KTraits::DTypeKV; static_assert(sizeof(DTypeKV) == 2); @@ -897,12 +904,13 @@ k_smem_inplace_apply_rotary(const uint32_t kv_idx_base, } template -__device__ __forceinline__ void -compute_qk(smem_t *q_smem, - uint32_t *q_smem_offset_r, - smem_t *k_smem, - uint32_t *k_smem_offset_r, - typename KTraits::DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][8]) +__device__ __forceinline__ void compute_qk( + smem_t *q_smem, + uint32_t *q_smem_offset_r, + smem_t *k_smem, + uint32_t *k_smem_offset_r, + typename KTraits::DTypeQKAccum ( + *s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD]) { constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q; constexpr uint32_t UPCAST_STRIDE_K = KTraits::UPCAST_STRIDE_K; @@ -1013,18 +1021,18 @@ compute_qk(smem_t *q_smem, } template -__device__ __forceinline__ void -logits_transform(const Params ¶ms, - typename KTraits::AttentionVariant variant, - const uint32_t batch_idx, - const uint32_t qo_packed_idx_base, - const 
uint32_t kv_idx_base, - const uint32_t qo_len, - const uint32_t kv_len, - const uint_fastdiv group_size, - DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][8], - const dim3 tid = threadIdx, - const uint32_t kv_head_idx = blockIdx.z) +__device__ __forceinline__ void logits_transform( + const Params ¶ms, + typename KTraits::AttentionVariant variant, + const uint32_t batch_idx, + const uint32_t qo_packed_idx_base, + const uint32_t kv_idx_base, + const uint32_t qo_len, + const uint32_t kv_len, + const uint_fastdiv group_size, + DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD], + const dim3 tid = threadIdx, + const uint32_t kv_head_idx = blockIdx.z) { const uint32_t lane_idx = tid.x; uint32_t q[KTraits::NUM_MMA_Q][2], r[KTraits::NUM_MMA_Q][2]; @@ -1098,7 +1106,8 @@ logits_mask(const Params ¶ms, const uint32_t kv_len, const uint32_t chunk_end, const uint_fastdiv group_size, - typename KTraits::DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][8], + typename KTraits::DTypeQKAccum ( + *s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD], const dim3 tid = threadIdx, const uint32_t kv_head_idx = blockIdx.z) { @@ -1148,8 +1157,9 @@ logits_mask(const Params ¶ms, template __device__ __forceinline__ void update_mdo_states( typename KTraits::AttentionVariant variant, - typename KTraits::DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][8], - float (*o_frag)[KTraits::NUM_MMA_D_VO][8], + typename KTraits::DTypeQKAccum ( + *s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD], + float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD], typename KTraits::DTypeQKAccum (*m)[2], float (*d)[2]) { @@ -1284,12 +1294,13 @@ __device__ __forceinline__ void update_mdo_states( } template -__device__ __forceinline__ void -compute_sfm_v(smem_t *v_smem, - uint32_t *v_smem_offset_r, - typename KTraits::DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][8], - float (*o_frag)[KTraits::NUM_MMA_D_VO][8], - float (*d)[2]) +__device__ __forceinline__ void compute_sfm_v( 
+ smem_t *v_smem, + uint32_t *v_smem_offset_r, + typename KTraits::DTypeQKAccum ( + *s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD], + float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD], + float (*d)[2]) { constexpr uint32_t UPCAST_STRIDE_V = KTraits::UPCAST_STRIDE_V; @@ -1400,10 +1411,10 @@ compute_sfm_v(smem_t *v_smem, } template -__device__ __forceinline__ void -normalize_d(float (*o_frag)[KTraits::NUM_MMA_D_VO][8], - typename KTraits::DTypeQKAccum (*m)[2], - float (*d)[2]) +__device__ __forceinline__ void normalize_d( + float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD], + typename KTraits::DTypeQKAccum (*m)[2], + float (*d)[2]) { using AttentionVariant = typename KTraits::AttentionVariant; if constexpr (AttentionVariant::use_softmax) { @@ -1461,14 +1472,14 @@ finalize_m(typename KTraits::AttentionVariant variant, * threadIdx.z. */ template -__device__ __forceinline__ void -threadblock_sync_mdo_states(float (*o_frag)[KTraits::NUM_MMA_D_VO][8], - typename KTraits::SharedStorage *smem_storage, - typename KTraits::DTypeQKAccum (*m)[2], - float (*d)[2], - const uint32_t warp_idx, - const uint32_t lane_idx, - const dim3 tid = threadIdx) +__device__ __forceinline__ void threadblock_sync_mdo_states( + float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD], + typename KTraits::SharedStorage *smem_storage, + typename KTraits::DTypeQKAccum (*m)[2], + float (*d)[2], + const uint32_t warp_idx, + const uint32_t lane_idx, + const dim3 tid = threadIdx) { // only necessary when blockDim.z > 1 if constexpr (KTraits::NUM_WARPS_KV > 1) { @@ -1609,16 +1620,16 @@ threadblock_sync_mdo_states(float (*o_frag)[KTraits::NUM_MMA_D_VO][8], } template -__device__ __forceinline__ void -write_o_reg_gmem(float (*o_frag)[KTraits::NUM_MMA_D_VO][8], - smem_t *o_smem, - typename KTraits::DTypeO *o_ptr_base, - const uint32_t o_packed_idx_base, - const uint32_t qo_upper_bound, - const uint32_t o_stride_n, - const uint32_t 
o_stride_h, - const uint_fastdiv group_size, - const dim3 tid = threadIdx) +__device__ __forceinline__ void write_o_reg_gmem( + float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD], + smem_t *o_smem, + typename KTraits::DTypeO *o_ptr_base, + const uint32_t o_packed_idx_base, + const uint32_t qo_upper_bound, + const uint32_t o_stride_n, + const uint32_t o_stride_h, + const uint_fastdiv group_size, + const dim3 tid = threadIdx) { using DTypeO = typename KTraits::DTypeO; constexpr uint32_t UPCAST_STRIDE_O = KTraits::UPCAST_STRIDE_O; @@ -1805,13 +1816,12 @@ SinglePrefillWithKVCacheDevice(const Params params, KTraits::SWIZZLE_MODE_Q; [[maybe_unused]] constexpr SwizzleMode SWIZZLE_MODE_KV = KTraits::SWIZZLE_MODE_KV; - [[maybe_unused]] constexpr SmemBasePtrTy = KTraits::SmemBasePtrTy; [[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_ROW = KTraits::KV_THR_LAYOUT_ROW; [[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_COL = KTraits::KV_THR_LAYOUT_COL; [[maybe_unused]] constexpr MaskMode MASK_MODE = KTraits::MASK_MODE; - [[maybe_unused]] constexpr HALF_ELEMS_PER_THREAD = + [[maybe_unused]] constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; DTypeQ *q = params.q; @@ -1868,7 +1878,8 @@ SinglePrefillWithKVCacheDevice(const Params params, const uint32_t qo_packed_idx_base = (bx * NUM_WARPS_Q + get_warp_idx_q(tid.y)) * NUM_MMA_Q * 16; - smem_t qo_smem(smem_storage.q_smem); + smem_t qo_smem( + smem_storage.q_smem); const uint32_t o_stride_n = num_qo_heads * HEAD_DIM_VO, o_stride_h = HEAD_DIM_VO; DTypeQ *q_ptr_base = q + (kv_head_idx * group_size) * q_stride_h; @@ -1897,7 +1908,8 @@ SinglePrefillWithKVCacheDevice(const Params params, block.sync(); } - smem_t k_smem(smem_storage.k_smem), + smem_t k_smem( + smem_storage.k_smem), v_smem(smem_storage.v_smem); const uint32_t num_iterations = ceil_div( diff --git a/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h b/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h index 
0ead063ff6..097012dac7 100644 --- a/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h +++ b/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h @@ -149,7 +149,7 @@ template __device__ __forceinline__ void load_fragment_4x4_half_registers(uint32_t *R, const T *smem_ptr) { - static_assert(std::is_same_v(), "Only half type is supported"); + static_assert(std::is_same_v, "Only half type is supported"); // Each thread loads 4 __half values in two 32b registers. load_fragment(R, smem_ptr); // transposes the values in four adjacent threads. The function does the From fbe0a77a05c7f63553428f7ede2badac0e4f007b Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Fri, 8 Aug 2025 14:09:12 -0400 Subject: [PATCH 023/109] Various compiler error fixes. --- .../flashinfer/attention/generic/cascade.cuh | 2 - .../flashinfer/attention/generic/page.cuh | 1 - .../attention/generic/permuted_smem.cuh | 40 ++++++++++++++++++- .../flashinfer/attention/generic/prefill.cuh | 39 +++++++----------- .../include/gpu_iface/backend/hip/mma_hip.h | 2 +- libflashinfer/include/gpu_iface/mma_ops.hpp | 2 +- 6 files changed, 56 insertions(+), 30 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/cascade.cuh b/libflashinfer/include/flashinfer/attention/generic/cascade.cuh index 899c77cdea..dc31ecd7db 100644 --- a/libflashinfer/include/flashinfer/attention/generic/cascade.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/cascade.cuh @@ -440,7 +440,6 @@ PersistentVariableLengthMergeStatesKernel(DTypeIn *__restrict__ V, uint32_t cta_id = blockIdx.x; uint32_t num_ctas = gridDim.x; const uint32_t seq_len = seq_len_ptr ? 
*seq_len_ptr : max_seq_len; - uint32_t num_iters = ceil_div(seq_len * num_heads, num_ctas); constexpr uint32_t vec_bits = sizeof(DTypeIn) * vec_size * 8; constexpr uint32_t head_dim = vec_size * bdx; extern __shared__ uint8_t smem[]; @@ -564,7 +563,6 @@ PersistentVariableLengthAttentionSumKernel(DTypeIn *__restrict__ V, uint32_t cta_id = blockIdx.x; uint32_t num_ctas = gridDim.x; const uint32_t seq_len = seq_len_ptr ? *seq_len_ptr : max_seq_len; - uint32_t num_iters = ceil_div(seq_len * num_heads, num_ctas); constexpr uint32_t vec_bits = sizeof(DTypeIn) * vec_size * 8; constexpr uint32_t head_dim = vec_size * bdx; extern __shared__ uint8_t smem[]; diff --git a/libflashinfer/include/flashinfer/attention/generic/page.cuh b/libflashinfer/include/flashinfer/attention/generic/page.cuh index 4973c01305..2871fa233a 100644 --- a/libflashinfer/include/flashinfer/attention/generic/page.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/page.cuh @@ -308,7 +308,6 @@ __global__ void AppendPagedKVCacheKernel(paged_kv_t paged_kv, size_t append_v_stride_h) { uint32_t tx = threadIdx.x, ty = threadIdx.y; - uint32_t num_heads = paged_kv.num_heads; uint32_t head_idx = ty; uint32_t cta_id = blockIdx.x; uint32_t num_ctas = gridDim.x; diff --git a/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh b/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh index 6f71514872..8b292149bf 100644 --- a/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh @@ -175,7 +175,7 @@ template struct smem_t flashinfer::gpu_iface::mma::load_fragment_transpose_4x4_half_registers( smem_t_ptr, frag); #else - ldmatrix_m8n8x4(offset, frag); + static_assert(false, "Not supported on current platform"); #endif } @@ -279,11 +279,49 @@ template struct smem_t smem_ptr, reinterpret_cast(gptr)); } + template + __device__ __forceinline__ void + load_vector_async(uint32_t offset, const T 
*gptr, bool predicate) + { +#if defined(PLATFORM_HIP_DEVICE) + load_64b_async(offset, gptr, predicate); +#else + load_128b_async(offset, gptr, predicate); +#endif + } + + template + __device__ __forceinline__ void load_vector_async(uint32_t offset, + const T *gptr) + { +#if defined(PLATFORM_HIP_DEVICE) + load_64b_async(offset, gptr); +#else + load_128b_async(offset, gptr); +#endif + } + template __device__ __forceinline__ void store_128b(uint32_t offset, T *gptr) { *reinterpret_cast(gptr) = *(base + offset); } + + template + __device__ __forceinline__ void store_64b(uint32_t offset, T *gptr) + { + *reinterpret_cast(gptr) = *(base + offset); + } + + template + __device__ __forceinline__ void store_vector(uint32_t offset, T *gptr) + { +#if defined(PLATFORM_HIP_DEVICE) + store_64b(offset, gptr); +#else + store_128b(offset, gptr); +#endif + } }; } // namespace flashinfer diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index 3cd943a778..3d6c0854c2 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -393,13 +393,8 @@ __device__ __forceinline__ void produce_kv_helper_( for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) { #pragma unroll for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(DTypeKV)); ++j) { -#if defined(PLATFORM_HIP_DEVICE) - smem.template load_64b_async(*smem_offset, *gptr, - kv_idx < kv_len); -#else - smem.template load_128b_async(*smem_offset, *gptr, - kv_idx < kv_len); -#endif + smem.template load_vector_async(*smem_offset, *gptr, + kv_idx < kv_len); *smem_offset = smem.template advance_offset_by_column( *smem_offset, j); @@ -467,8 +462,8 @@ __device__ __forceinline__ void produce_kv( static_assert(NUM_MMA_KV * 2 % NUM_WARPS_Q == 0); #pragma unroll for (uint32_t i = 0; i < NUM_MMA_KV * 2 / NUM_WARPS_Q; ++i) { - smem.template load_128b_async(*smem_offset, *gptr, - 
kv_idx < kv_len); + smem.template load_vector_async(*smem_offset, *gptr, + kv_idx < kv_len); *smem_offset = smem.template advance_offset_by_row( *smem_offset); @@ -515,8 +510,8 @@ __device__ __forceinline__ void page_produce_kv( : paged_kv.k_data + thr_local_kv_offset[i]; #pragma unroll for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(DType)); ++j) { - smem.template load_128b_async(*smem_offset, gptr, - kv_idx < kv_len); + smem.template load_vector_async(*smem_offset, gptr, + kv_idx < kv_len); *smem_offset = smem.template advance_offset_by_column<8>(*smem_offset, j); gptr += 8 * upcast_size(); @@ -538,8 +533,8 @@ __device__ __forceinline__ void page_produce_kv( for (uint32_t i = 0; i < NUM_MMA_KV * 2 / NUM_WARPS_Q; ++i) { DType *gptr = produce_v ? paged_kv.v_data + thr_local_kv_offset[i] : paged_kv.k_data + thr_local_kv_offset[i]; - smem.template load_128b_async(*smem_offset, gptr, - kv_idx < kv_len); + smem.template load_vector_async(*smem_offset, gptr, + kv_idx < kv_len); kv_idx += NUM_WARPS * 8; *smem_offset = smem.template advance_offset_by_row( @@ -668,15 +663,10 @@ __device__ __forceinline__ void load_q_global_smem( for (uint32_t mma_do = 0; mma_do < KTraits::NUM_MMA_D_QK / 4; ++mma_do) { -#if defined(PLATFORM_HIP_DEVICE) // load q fragment from gmem to smem - q_smem - ->template load_128b_async( - q_smem_offset_w, q_ptr, q_idx < qo_upper_bound); -#else - q_smem->template load_64b_async( - q_smem_offset_w, q_ptr, q_idx < qo_upper_bound); -#endif + q_smem->template load_vector_async< + SharedMemFillMode::kNoFill>(q_smem_offset_w, q_ptr, + q_idx < qo_upper_bound); q_smem_offset_w = q_smem->template advance_offset_by_column< WARP_THREAD_COLS>(q_smem_offset_w, mma_do); q_ptr += HALF_ELEMS_PER_THREAD * upcast_size(); @@ -1727,7 +1717,7 @@ __device__ __forceinline__ void write_o_reg_gmem( mma_do < KTraits::NUM_MMA_D_VO / 4; ++mma_do) { if (o_idx < qo_upper_bound) { - o_smem->store_128b(o_smem_offset_w, o_ptr); + o_smem->store_vector(o_smem_offset_w, o_ptr); } o_ptr 
+= 8 * upcast_size(); o_smem_offset_w = @@ -1909,8 +1899,9 @@ SinglePrefillWithKVCacheDevice(const Params params, } smem_t k_smem( - smem_storage.k_smem), - v_smem(smem_storage.v_smem); + smem_storage.k_smem); + smem_t v_smem( + smem_storage.v_smem); const uint32_t num_iterations = ceil_div( MASK_MODE == MaskMode::kCausal diff --git a/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h b/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h index 097012dac7..4330683609 100644 --- a/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h +++ b/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h @@ -147,7 +147,7 @@ amdgcn_mfma_fp32_16x16x16fp16(float *C, uint32_t *A, uint32_t *B) /// the registers for a group of four consecuitive threads. template __device__ __forceinline__ void -load_fragment_4x4_half_registers(uint32_t *R, const T *smem_ptr) +load_fragment_4x4_half_registers(const T *smem_ptr, uint32_t *R) { static_assert(std::is_same_v, "Only half type is supported"); // Each thread loads 4 __half values in two 32b registers. 
diff --git a/libflashinfer/include/gpu_iface/mma_ops.hpp b/libflashinfer/include/gpu_iface/mma_ops.hpp index 113c0aa7f6..bd5adb30b5 100644 --- a/libflashinfer/include/gpu_iface/mma_ops.hpp +++ b/libflashinfer/include/gpu_iface/mma_ops.hpp @@ -51,7 +51,7 @@ load_fragment_transpose_4x4_half_registers(const T *smem_ptr, uint32_t *R) { static_assert(std::is_same::value, "Only __half is supported for the 4x4 register transpose"); - mma_detail::load_fragment_4x4_half_registers(R, smem_ptr); + mma_detail::load_fragment_4x4_half_registers(smem_ptr, R); } #endif From cffa9dd423cb15921985f2a39904f7731950ad05 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Fri, 8 Aug 2025 14:18:35 -0400 Subject: [PATCH 024/109] Fix wrong header guard --- libflashinfer/include/gpu_iface/mma_ops.hpp | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/libflashinfer/include/gpu_iface/mma_ops.hpp b/libflashinfer/include/gpu_iface/mma_ops.hpp index bd5adb30b5..513bdc22e9 100644 --- a/libflashinfer/include/gpu_iface/mma_ops.hpp +++ b/libflashinfer/include/gpu_iface/mma_ops.hpp @@ -28,8 +28,6 @@ namespace mma * \param R pointer to the fragment * \param smem_ptr pointer to the shared memory */ -// Call this load fragment -// inside mma there is impl of load template __device__ __forceinline__ void load_fragment(uint32_t *R, const T *smem_ptr) @@ -44,13 +42,10 @@ load_fragment_transpose(uint32_t *R, const T *smem_ptr, uint32_t stride) mma_detail::load_fragment_transpose(R, smem_ptr, stride); } -#if defined(PLATFORM_HIP_DEVICE) && defined(__gfx942__) -template +#if defined(PLATFORM_HIP_DEVICE) __device__ __forceinline__ void -load_fragment_transpose_4x4_half_registers(const T *smem_ptr, uint32_t *R) +load_fragment_transpose_4x4_half_registers(const half *smem_ptr, uint32_t *R) { - static_assert(std::is_same::value, - "Only __half is supported for the 4x4 register transpose"); mma_detail::load_fragment_4x4_half_registers(smem_ptr, R); } #endif From 
3c9a5de8997d9ff3557494f6b3326765323dfa77 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Sun, 10 Aug 2025 03:51:10 -0400 Subject: [PATCH 025/109] port llama rotary transforms to HIP. --- .../attention/generic/permuted_smem.cuh | 4 - .../flashinfer/attention/generic/prefill.cuh | 137 +++++++++--------- 2 files changed, 70 insertions(+), 71 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh b/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh index 8b292149bf..ca92b1f7fe 100644 --- a/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh @@ -159,8 +159,6 @@ template struct smem_t static_assert(sizeof(T) == 4, "Only 32-bit fragment loading supported"); reinterpret_cast(frag)[0] = *reinterpret_cast(base + offset); - reinterpret_cast(&frag[2])[0] = - *reinterpret_cast(base + (offset ^ 0x1)); #else ldmatrix_m8n8x4(offset, frag); #endif @@ -187,8 +185,6 @@ template struct smem_t static_assert(sizeof(T) == 4, "Only 32-bit fragment storing supported"); *reinterpret_cast(base + offset) = reinterpret_cast(frag)[0]; - *reinterpret_cast(base + (offset ^ 0x1)) = - reinterpret_cast(&frag[2])[0]; #else stmatrix_m8n8x4(offset, frag); #endif diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index 3d6c0854c2..50d0d86479 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -173,6 +173,7 @@ struct KernelTraits static constexpr uint32_t KV_THR_LAYOUT_COL = SWIZZLE_MODE_KV == SwizzleMode::k128B ? 
8 : 4; #endif + static constexpr uint32_t THREADS_PER_ROW_GROUP = WARP_THREAD_COLS / 2; static constexpr uint32_t UPCAST_STRIDE_Q = HEAD_DIM_QK / upcast_size(); static constexpr uint32_t UPCAST_STRIDE_K = @@ -276,7 +277,7 @@ get_warp_idx(const uint32_t tid_y = threadIdx.y, * \note The sin/cos computation is slow, especially for A100 GPUs which has low * non tensor-ops flops, will optimize in the future. */ -template +template __device__ __forceinline__ void k_frag_apply_llama_rope(T *x_first_half, T *x_second_half, @@ -285,12 +286,17 @@ k_frag_apply_llama_rope(T *x_first_half, { static_assert(sizeof(T) == 2); #pragma unroll - for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + for (uint32_t reg_id = 0; reg_id < HALF_ELEMS_PER_THREAD; ++reg_id) { float cos, sin, tmp; // 0 1 | 2 3 // --------- // 4 5 | 6 7 + +#if defined(PLATFORM_HIP_DEVICE) + uint32_t i = reg_id / 2, j = reg_id % 2; +#else uint32_t i = reg_id / 4, j = (reg_id % 4) / 2; +#endif __sincosf(float(kv_offset + 8 * i) * rope_freq[2 * j + reg_id % 2], &sin, &cos); tmp = x_first_half[reg_id]; @@ -300,7 +306,7 @@ k_frag_apply_llama_rope(T *x_first_half, } } -template +template __device__ __forceinline__ void q_frag_apply_llama_rope(T *x_first_half, T *x_second_half, @@ -309,7 +315,7 @@ q_frag_apply_llama_rope(T *x_first_half, const uint_fastdiv group_size) { #pragma unroll - for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + for (uint32_t reg_id = 0; reg_id < HALF_ELEMS_PER_THREAD; ++reg_id) { float cos, sin, tmp; // 0 1 | 4 5 // --------- @@ -329,7 +335,7 @@ q_frag_apply_llama_rope(T *x_first_half, } } -template +template __device__ __forceinline__ void q_frag_apply_llama_rope_with_pos(T *x_first_half, T *x_second_half, @@ -342,12 +348,18 @@ q_frag_apply_llama_rope_with_pos(T *x_first_half, static_cast(q_rope_offset[qo_packed_offset / group_size]), static_cast(q_rope_offset[(qo_packed_offset + 8) / group_size])}; #pragma unroll - for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + for (uint32_t reg_id = 0; 
reg_id < HALF_ELEMS_PER_THREAD; ++reg_id) { float cos, sin, tmp; // 0 1 | 4 5 // --------- // 2 3 | 6 7 - uint32_t i = ((reg_id % 4) / 2), j = (reg_id / 4); +#if defined(PLATFORM_HIP_DEVICE) + const uint32_t i = reg_id / 2; + const uint32_t j = reg_id % 2; +#else + const uint32_t i = (reg_id % 4) / 2; + const uint32_t j = reg_id / 4; +#endif __sincosf(pos[i] * rope_freq[2 * j + reg_id % 2], &sin, &cos); tmp = x_first_half[reg_id]; x_first_half[reg_id] = (tmp * cos - (float)x_second_half[reg_id] * sin); @@ -554,15 +566,8 @@ init_rope_freq(float (*rope_freq)[4], constexpr uint32_t HEAD_DIM = KTraits::NUM_MMA_D_QK * 16; const uint32_t lane_idx = tid_x; -#if defined(PLATFORM_HIP_DEVICE) - // MI300: 8 threads handle 8 elements (1 element per thread) - constexpr uint32_t THREADS_PER_ROW = 8; - constexpr uint32_t ELEMS_PER_THREAD = 1; -#else - // NVIDIA: 4 threads handle 8 elements (2 elements per thread) - constexpr uint32_t THREADS_PER_ROW = 4; - constexpr uint32_t ELEMS_PER_THREAD = 2; -#endif + constexpr uint32_t THREADS_PER_ROW = KTraits::THREADS_PER_ROW_GROUP; + constexpr uint32_t ELEMS_PER_THREAD = 8 / THREADS_PER_ROW; #pragma unroll for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO / 2; ++mma_d) { @@ -699,14 +704,6 @@ __device__ __forceinline__ void q_smem_inplace_apply_rotary( static_assert(KTraits::NUM_MMA_D_QK % 4 == 0, "NUM_MMA_D_QK must be a multiple of 4"); -#if defined(PLATFORM_HIP_DEVICE) - // MI300: 8 threads handle a row of 8 elements - const uint32_t pos_group_idx = lane_idx / 8; -#else - // NVIDIA: 4 threads handle a row of 8 elements - const uint32_t pos_group_idx = lane_idx / 4; -#endif - #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { uint32_t q_smem_offset_r_first_half = *q_smem_offset_r; @@ -721,12 +718,13 @@ __device__ __forceinline__ void q_smem_inplace_apply_rotary( KTraits::NUM_MMA_D_QK>(q_smem_offset_r_first_half, 0); q_smem->template load_fragment(q_smem_offset_r_last_half, q_frag_local[1]); - 
q_frag_apply_llama_rope( + q_frag_apply_llama_rope( (typename KTraits::DTypeQ *)q_frag_local[0], (typename KTraits::DTypeQ *)q_frag_local[1], rope_freq[mma_di], q_packed_idx + kv_len * group_size - qo_len * group_size + - mma_q * 16 + pos_group_idx, + mma_q * 16 + lane_idx / KTraits::THREADS_PER_ROW_GROUP, group_size); q_smem->template store_fragment(q_smem_offset_r_last_half, q_frag_local[1]); @@ -755,7 +753,7 @@ __device__ __forceinline__ void q_smem_inplace_apply_rotary_with_pos( if (get_warp_idx_kv(tid.z) == 0) { constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q; const uint32_t lane_idx = tid.x; - uint32_t q_frag_local[2][4]; + uint32_t q_frag_local[2][KTraits::INT32_ELEMS_PER_THREAD]; static_assert(KTraits::NUM_MMA_D_QK % 4 == 0, "NUM_MMA_D_QK must be a multiple of 4"); #pragma unroll @@ -765,24 +763,26 @@ __device__ __forceinline__ void q_smem_inplace_apply_rotary_with_pos( for (uint32_t mma_di = 0; mma_di < KTraits::NUM_MMA_D_QK / 2; ++mma_di) { - q_smem->ldmatrix_m8n8x4(q_smem_offset_r_first_half, - q_frag_local[0]); + q_smem->load_fragment(q_smem_offset_r_first_half, + q_frag_local[0]); uint32_t q_smem_offset_r_last_half = q_smem->template advance_offset_by_column< KTraits::NUM_MMA_D_QK>(q_smem_offset_r_first_half, 0); - q_smem->ldmatrix_m8n8x4(q_smem_offset_r_last_half, - q_frag_local[1]); - q_frag_apply_llama_rope_with_pos( + q_smem->load_fragment(q_smem_offset_r_last_half, + q_frag_local[1]); + q_frag_apply_llama_rope_with_pos< + typename KTraits::DTypeQ, typename KTraits::IdType, + KTraits::HALF_ELEMS_PER_THREAD>( (typename KTraits::DTypeQ *)q_frag_local[0], (typename KTraits::DTypeQ *)q_frag_local[1], rope_freq[mma_di], - q_packed_idx_base + mma_q * 16 + lane_idx / 4, group_size, - q_rope_offset); - q_smem->stmatrix_m8n8x4(q_smem_offset_r_last_half, - q_frag_local[1]); - q_smem->stmatrix_m8n8x4(q_smem_offset_r_first_half, - q_frag_local[0]); + q_packed_idx_base + mma_q * 16 + + lane_idx / KTraits::THREADS_PER_ROW_GROUP, + group_size, 
q_rope_offset); + q_smem->store_fragment(q_smem_offset_r_last_half, + q_frag_local[1]); + q_smem->store_fragment(q_smem_offset_r_first_half, + q_frag_local[0]); q_smem_offset_r_first_half = q_smem->template advance_offset_by_column<2>( q_smem_offset_r_first_half, mma_di); @@ -804,6 +804,8 @@ __device__ __forceinline__ void k_smem_inplace_apply_rotary( using DTypeKV = typename KTraits::DTypeKV; static_assert(sizeof(DTypeKV) == 2); constexpr uint32_t UPCAST_STRIDE_K = KTraits::UPCAST_STRIDE_K; + constexpr uint32_t THREADS_PER_ROW_GROUP = KTraits::THREADS_PER_ROW_GROUP; + constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; uint32_t k_frag_local[2][KTraits::INT32_ELEMS_PER_THREAD]; const uint32_t lane_idx = tid.x; if constexpr (KTraits::NUM_MMA_D_QK == 4 && KTraits::NUM_WARPS_Q == 4) { @@ -817,25 +819,24 @@ __device__ __forceinline__ void k_smem_inplace_apply_rotary( static_assert( KTraits::NUM_MMA_KV % 2 == 0, "when NUM_MMA_D_QK == 4, NUM_MMA_KV must be a multiple of 2"); - uint32_t kv_idx = kv_idx_base + (warp_idx / 2) * 16 + lane_idx / 4; + uint32_t kv_idx = kv_idx_base + (warp_idx / 2) * 16 + + lane_idx / THREADS_PER_ROW_GROUP; *k_smem_offset_r = (*k_smem_offset_r ^ (0x2 * (warp_idx % 2))) + (warp_idx / 2) * 16 * UPCAST_STRIDE_K; #pragma unroll for (uint32_t i = 0; i < KTraits::NUM_MMA_KV / 2; ++i) { uint32_t k_smem_offset_r_first_half = *k_smem_offset_r; uint32_t mma_di = (warp_idx % 2); - k_smem->ldmatrix_m8n8x4(k_smem_offset_r_first_half, - k_frag_local[0]); + k_smem->load_fragment(k_smem_offset_r_first_half, k_frag_local[0]); uint32_t k_smem_offset_r_last_half = k_smem->template advance_offset_by_column<4>( k_smem_offset_r_first_half, 0); - k_smem->ldmatrix_m8n8x4(k_smem_offset_r_last_half, k_frag_local[1]); - k_frag_apply_llama_rope((DTypeKV *)k_frag_local[0], - (DTypeKV *)k_frag_local[1], - rope_freq[mma_di], kv_idx); - k_smem->stmatrix_m8n8x4(k_smem_offset_r_last_half, k_frag_local[1]); - 
k_smem->stmatrix_m8n8x4(k_smem_offset_r_first_half, - k_frag_local[0]); + k_smem->load_fragment(k_smem_offset_r_last_half, k_frag_local[1]); + k_frag_apply_llama_rope( + (DTypeKV *)k_frag_local[0], (DTypeKV *)k_frag_local[1], + rope_freq[mma_di], kv_idx); + k_smem->store_fragment(k_smem_offset_r_last_half, k_frag_local[1]); + k_smem->store_fragment(k_smem_offset_r_first_half, k_frag_local[0]); *k_smem_offset_r += 32 * UPCAST_STRIDE_K; kv_idx += 32; } @@ -856,7 +857,7 @@ __device__ __forceinline__ void k_smem_inplace_apply_rotary( // ... uint32_t kv_idx = kv_idx_base + (warp_idx_z * KTraits::NUM_MMA_KV * 16) + - lane_idx / 4; + lane_idx / THREADS_PER_ROW_GROUP; *k_smem_offset_r = *k_smem_offset_r ^ (0x2 * warp_idx_x); #pragma unroll for (uint32_t i = 0; i < KTraits::NUM_MMA_KV; ++i) { @@ -866,20 +867,20 @@ __device__ __forceinline__ void k_smem_inplace_apply_rotary( j < KTraits::NUM_MMA_D_QK / (2 * KTraits::NUM_WARPS_Q); ++j) { uint32_t mma_di = warp_idx_x + j * KTraits::NUM_WARPS_Q; - k_smem->ldmatrix_m8n8x4(k_smem_offset_r_first_half, - k_frag_local[0]); + k_smem->load_fragment(k_smem_offset_r_first_half, + k_frag_local[0]); uint32_t k_smem_offset_r_last_half = k_smem->template advance_offset_by_column< KTraits::NUM_MMA_D_QK>(k_smem_offset_r_first_half, 0); - k_smem->ldmatrix_m8n8x4(k_smem_offset_r_last_half, - k_frag_local[1]); - k_frag_apply_llama_rope((DTypeKV *)k_frag_local[0], - (DTypeKV *)k_frag_local[1], - rope_freq[mma_di], kv_idx); - k_smem->stmatrix_m8n8x4(k_smem_offset_r_last_half, - k_frag_local[1]); - k_smem->stmatrix_m8n8x4(k_smem_offset_r_first_half, - k_frag_local[0]); + k_smem->load_fragment(k_smem_offset_r_last_half, + k_frag_local[1]); + k_frag_apply_llama_rope( + (DTypeKV *)k_frag_local[0], (DTypeKV *)k_frag_local[1], + rope_freq[mma_di], kv_idx); + k_smem->store_fragment(k_smem_offset_r_last_half, + k_frag_local[1]); + k_smem->store_fragment(k_smem_offset_r_first_half, + k_frag_local[0]); k_smem_offset_r_first_half = k_smem->template 
advance_offset_by_column< 2 * KTraits::NUM_WARPS_Q>(k_smem_offset_r_first_half, @@ -1027,12 +1028,13 @@ __device__ __forceinline__ void logits_transform( const uint32_t lane_idx = tid.x; uint32_t q[KTraits::NUM_MMA_Q][2], r[KTraits::NUM_MMA_Q][2]; float logits = 0., logitsTransformed = 0.; + constexpr uint32_t TPR = KTraits::THREADS_PER_ROW_GROUP; #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t j = 0; j < 2; ++j) { - group_size.divmod(qo_packed_idx_base + mma_q * 16 + lane_idx / 4 + + group_size.divmod(qo_packed_idx_base + mma_q * 16 + lane_idx / TPR + 8 * j, q[mma_q][j], r[mma_q][j]); } @@ -1046,8 +1048,8 @@ __device__ __forceinline__ void logits_transform( for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { const uint32_t q_idx = q[mma_q][(reg_id % 4) / 2], kv_idx = kv_idx_base + mma_kv * 16 + - 2 * (lane_idx % 4) + 8 * (reg_id / 4) + - reg_id % 2; + 2 * (lane_idx % TPR) + + 8 * (reg_id / 4) + reg_id % 2; const uint32_t qo_head_idx = kv_head_idx * group_size + r[mma_q][(reg_id % 4) / 2]; @@ -1106,12 +1108,13 @@ logits_mask(const Params ¶ms, constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; using DTypeQKAccum = typename KTraits::DTypeQKAccum; constexpr MaskMode MASK_MODE = KTraits::MASK_MODE; + constexpr uint32_t TPR = KTraits::THREADS_PER_ROW_GROUP; uint32_t q[NUM_MMA_Q][2], r[NUM_MMA_Q][2]; #pragma unroll for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t j = 0; j < 2; ++j) { - group_size.divmod(qo_packed_idx_base + mma_q * 16 + lane_idx / 4 + + group_size.divmod(qo_packed_idx_base + mma_q * 16 + lane_idx / TPR + 8 * j, q[mma_q][j], r[mma_q][j]); } @@ -1125,8 +1128,8 @@ logits_mask(const Params ¶ms, for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { const uint32_t q_idx = q[mma_q][(reg_id % 4) / 2], kv_idx = kv_idx_base + mma_kv * 16 + - 2 * (lane_idx % 4) + 8 * (reg_id / 4) + - reg_id % 2; + 2 * (lane_idx % TPR) + + 8 * (reg_id / 4) + reg_id % 2; const uint32_t 
qo_head_idx = kv_head_idx * group_size + r[mma_q][(reg_id % 4) / 2]; const bool mask = From 548faa281d44a75d45ddcc12cfcf53ef6faf5dce Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Sun, 10 Aug 2025 12:11:56 -0400 Subject: [PATCH 026/109] Port ancillary kernels to CDNA3 thread layout. --- .../flashinfer/attention/generic/prefill.cuh | 78 ++++++++++++------- 1 file changed, 52 insertions(+), 26 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index 50d0d86479..f3df5be8e6 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -306,7 +306,7 @@ k_frag_apply_llama_rope(T *x_first_half, } } -template +template __device__ __forceinline__ void q_frag_apply_llama_rope(T *x_first_half, T *x_second_half, @@ -1045,11 +1045,18 @@ __device__ __forceinline__ void logits_transform( #pragma unroll for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { #pragma unroll - for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + for (uint32_t reg_id = 0; reg_id < KTraits::HALF_ELEMS_PER_THREAD; + ++reg_id) + { +#if defined(PLATFORM_HIP_DEVICE) + const uint32_t i = reg_id / 2; +#else + const uint32_t i = reg_id / 4; +#endif const uint32_t q_idx = q[mma_q][(reg_id % 4) / 2], kv_idx = kv_idx_base + mma_kv * 16 + - 2 * (lane_idx % TPR) + - 8 * (reg_id / 4) + reg_id % 2; + 2 * (lane_idx % TPR) + 8 * i + + reg_id % 2; const uint32_t qo_head_idx = kv_head_idx * group_size + r[mma_q][(reg_id % 4) / 2]; @@ -1125,11 +1132,18 @@ logits_mask(const Params ¶ms, #pragma unroll for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) { #pragma unroll - for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + for (uint32_t reg_id = 0; reg_id < KTraits::HALF_ELEMS_PER_THREAD; + ++reg_id) + { +#if defined(PLATFORM_HIP_DEVICE) + const uint32_t i = reg_id / 2; +#else + const uint32_t i = reg_id / 4; +#endif 
const uint32_t q_idx = q[mma_q][(reg_id % 4) / 2], kv_idx = kv_idx_base + mma_kv * 16 + - 2 * (lane_idx % TPR) + - 8 * (reg_id / 4) + reg_id % 2; + 2 * (lane_idx % TPR) + 8 * i + + reg_id % 2; const uint32_t qo_head_idx = kv_head_idx * group_size + r[mma_q][(reg_id % 4) / 2]; const bool mask = @@ -1430,7 +1444,9 @@ __device__ __forceinline__ void normalize_d( #pragma unroll for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { #pragma unroll - for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + for (uint32_t reg_id = 0; + reg_id < KTraits::HALF_ELEMS_PER_THREAD; ++reg_id) + { o_frag[mma_q][mma_d][reg_id] = o_frag[mma_q][mma_d][reg_id] * d_rcp[mma_q][(reg_id % 4) / 2]; @@ -1474,23 +1490,29 @@ __device__ __forceinline__ void threadblock_sync_mdo_states( const uint32_t lane_idx, const dim3 tid = threadIdx) { + constexpr uint32_t TPR = KTraits::THREADS_PER_ROW_GROUP; + static_assert(WARP_SIZE % TPR == 0, + "THREADS_PER_ROW_GROUP must divide WARP_SIZE"); + constexpr uint32_t GROUPS_PER_WARP = WARP_SIZE / TPR; + const uint32_t lane_group_idx = lane_idx / TPR; + // only necessary when blockDim.z > 1 if constexpr (KTraits::NUM_WARPS_KV > 1) { float *smem_o = smem_storage->cta_sync_o_smem; float2 *smem_md = smem_storage->cta_sync_md_smem; - // o: [num_warps, NUM_MMA_Q, NUM_MMA_D_VO, WARP_SIZE(32), 8] - // md: [num_warps, NUM_MMA_Q, 16, 2 (m/d)] + // o: [num_warps, NUM_MMA_Q, NUM_MMA_D_VO, WARP_SIZE, + // HALF_ELEMS_PER_THREAD] md: [num_warps, NUM_MMA_Q, 16, 2 (m/d)] #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { - vec_t::memcpy( + vec_t::memcpy( smem_o + (((warp_idx * KTraits::NUM_MMA_Q + mma_q) * KTraits::NUM_MMA_D_VO + mma_d) * WARP_SIZE + lane_idx) * - 8, + KTraits::HALF_ELEMS_PER_THREAD, o_frag[mma_q][mma_d]); } } @@ -1501,8 +1523,8 @@ __device__ __forceinline__ void threadblock_sync_mdo_states( #pragma unroll for (uint32_t j = 0; j < 2; ++j) { 
smem_md[((warp_idx * KTraits::NUM_MMA_Q + mma_q) * 2 + j) * - 8 + - lane_idx / 4] = + GROUPS_PER_WARP + + lane_group_idx] = make_float2(float(m[mma_q][j]), d[mma_q][j]); } } @@ -1523,8 +1545,8 @@ __device__ __forceinline__ void threadblock_sync_mdo_states( mma_q) * 2 + j) * - 8 + - lane_idx / 4]; + GROUPS_PER_WARP + + lane_group_idx]; float m_prev = m_new, d_prev = d_new; m_new = max(m_new, md.x); d_new = @@ -1540,8 +1562,8 @@ __device__ __forceinline__ void threadblock_sync_mdo_states( mma_q) * 2 + j) * - 8 + - lane_idx / 4]; + GROUPS_PER_WARP + + lane_group_idx]; float mi = md.x; o_scale[j][i] = gpu_iface::math::ptx_exp2(float(mi - m_new)); @@ -1553,11 +1575,11 @@ __device__ __forceinline__ void threadblock_sync_mdo_states( #pragma unroll for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { - vec_t o_new; + vec_t o_new; o_new.fill(0.f); #pragma unroll for (uint32_t i = 0; i < KTraits::NUM_WARPS_KV; ++i) { - vec_t oi; + vec_t oi; oi.load(smem_o + ((((i * KTraits::NUM_WARPS_Q + get_warp_idx_q(tid.y)) * KTraits::NUM_MMA_Q + @@ -1566,10 +1588,12 @@ __device__ __forceinline__ void threadblock_sync_mdo_states( mma_d) * WARP_SIZE + lane_idx) * - 8); + KTraits::HALF_ELEMS_PER_THREAD); #pragma unroll - for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + for (uint32_t reg_id = 0; + reg_id < KTraits::HALF_ELEMS_PER_THREAD; ++reg_id) + { o_new[reg_id] += oi[reg_id] * o_scale[(reg_id % 4) / 2][i]; } @@ -1586,11 +1610,11 @@ __device__ __forceinline__ void threadblock_sync_mdo_states( #pragma unroll for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { - vec_t o_new; + vec_t o_new; o_new.fill(0.f); #pragma unroll for (uint32_t i = 0; i < KTraits::NUM_WARPS_KV; ++i) { - vec_t oi; + vec_t oi; oi.load(smem_o + ((((i * KTraits::NUM_WARPS_Q + get_warp_idx_q(tid.y)) * KTraits::NUM_MMA_Q + @@ -1599,9 +1623,11 @@ __device__ __forceinline__ void threadblock_sync_mdo_states( mma_d) * WARP_SIZE + lane_idx) * - 8); + KTraits::HALF_ELEMS_PER_THREAD); #pragma 
unroll - for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + for (uint32_t reg_id = 0; + reg_id < KTraits::HALF_ELEMS_PER_THREAD; ++reg_id) + { o_new[reg_id] += oi[reg_id]; } } From 83fd9c9b0e6b68472e901c83f3eac3fa2f17e6a6 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Sun, 10 Aug 2025 12:44:46 -0400 Subject: [PATCH 027/109] wip --- .../flashinfer/attention/generic/prefill.cuh | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index f3df5be8e6..05dd0e0f92 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -1652,6 +1652,8 @@ __device__ __forceinline__ void write_o_reg_gmem( { using DTypeO = typename KTraits::DTypeO; constexpr uint32_t UPCAST_STRIDE_O = KTraits::UPCAST_STRIDE_O; + constexpr uint32_t TPR = KTraits::THREADS_PER_ROW_GROUP; + constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; const uint32_t warp_idx_x = get_warp_idx_q(tid.y); const uint32_t lane_idx = tid.x; @@ -1661,7 +1663,7 @@ __device__ __forceinline__ void write_o_reg_gmem( #pragma unroll for (uint32_t j = 0; j < 2; ++j) { uint32_t q, r; - group_size.divmod(o_packed_idx_base + lane_idx / 4 + + group_size.divmod(o_packed_idx_base + lane_idx / TPR + mma_q * 16 + j * 8, q, r); const uint32_t o_idx = q; @@ -1671,12 +1673,12 @@ __device__ __forceinline__ void write_o_reg_gmem( if (o_idx < qo_upper_bound) { *reinterpret_cast( o_ptr_base + q * o_stride_n + r * o_stride_h + - mma_d * 16 + (lane_idx % 4) * 2) = + mma_d * 16 + (lane_idx % TPR) * 2) = *reinterpret_cast( &o_frag[mma_q][mma_d][j * 2]); *reinterpret_cast( o_ptr_base + q * o_stride_n + r * o_stride_h + - mma_d * 16 + 8 + (lane_idx % 4) * 2) = + mma_d * 16 + 8 + (lane_idx % TPR) * 2) = *reinterpret_cast( &o_frag[mma_q][mma_d][4 + j * 2]); } @@ -1691,9 +1693,10 @@ 
__device__ __forceinline__ void write_o_reg_gmem( #pragma unroll for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { - uint32_t o_frag_f16[8 / 2]; - vec_cast::template cast<8>( - (DTypeO *)o_frag_f16, o_frag[mma_q][mma_d]); + uint32_t o_frag_f16[HALF_ELEMS_PER_THREAD / 2]; + vec_cast::template cast< + HALF_ELEMS_PER_THREAD>((DTypeO *)o_frag_f16, + o_frag[mma_q][mma_d]); #ifdef FLASHINFER_STMATRIX_M8N8X4_ENABLED uint32_t o_smem_offset_w = From 06b25c0894bb4c9ce1c5e7255cfed0a0d3cc9933 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Tue, 12 Aug 2025 02:04:31 -0400 Subject: [PATCH 028/109] Fix merge issue --- libflashinfer/include/gpu_iface/mma_ops.hpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/libflashinfer/include/gpu_iface/mma_ops.hpp b/libflashinfer/include/gpu_iface/mma_ops.hpp index 3d0800273c..e08a5f82e7 100644 --- a/libflashinfer/include/gpu_iface/mma_ops.hpp +++ b/libflashinfer/include/gpu_iface/mma_ops.hpp @@ -28,6 +28,8 @@ namespace mma * \param R pointer to the fragment * \param smem_ptr pointer to the shared memory */ +// Call this load fragment +// inside mma there is impl of load template __device__ __forceinline__ void load_fragment(uint32_t *R, const T *smem_ptr) @@ -42,9 +44,10 @@ load_fragment_transpose(uint32_t *R, const T *smem_ptr, uint32_t stride) { mma_detail::load_fragment_transpose(R, smem_ptr, stride); } -#if defined(PLATFORM_HIP_DEVICE) +#if defined(PLATFORM_HIP_DEVICE) && defined(__gfx942__) +template __device__ __forceinline__ void -load_fragment_transpose_4x4_half_registers(const half *smem_ptr, uint32_t *R) +load_fragment_transpose_4x4_half_registers(uint32_t *R, const T *smem_ptr) { static_assert(std::is_same::value, "Only __half is supported for the 4x4 register transpose"); From a9a4df11b246f27221e1fcb1286cedc2b68e9469 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Tue, 12 Aug 2025 02:13:45 -0400 Subject: [PATCH 029/109] Update compute_qk to use mma ops --- 
.../include/flashinfer/attention/generic/prefill.cuh | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index 05dd0e0f92..599bcbc375 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -960,8 +960,6 @@ __device__ __forceinline__ void compute_qk( if constexpr (std::is_same_v) { -#warning "TODO: mma_sync_m16n16k16_row_col_f16f16f32 ...." -#if 0 if (mma_d == 0) { mma::mma_sync_m16n16k16_row_col_f16f16f32< typename KTraits::DTypeQ, MMAMode::kInit>( @@ -972,11 +970,13 @@ __device__ __forceinline__ void compute_qk( typename KTraits::DTypeQ>(s_frag[mma_q][mma_kv], a_frag[mma_q], b_frag); } -#endif } else if (std::is_same_v) { -#warning "Not yet implemented" -#if 0 +#if defined(PLATFORM_HIP_DEVICE) + static_assert( + false, + "FP16 DTypeQKAccum not yet implemented for CDNA3"); +#endif if (mma_d == 0) { mma::mma_sync_m16n16k16_row_col_f16f16f16< MMAMode::kInit>((uint32_t *)s_frag[mma_q][mma_kv], @@ -987,7 +987,6 @@ __device__ __forceinline__ void compute_qk( (uint32_t *)s_frag[mma_q][mma_kv], a_frag[mma_q], b_frag); } -#endif } } } From 7ef584b9eecb1248feaaa6810af5929fa3563732 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Tue, 12 Aug 2025 02:19:28 -0400 Subject: [PATCH 030/109] Update all kernel launches to use WARP_SIZE for thread count.
--- .../include/flashinfer/attention/generic/prefill.cuh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index 599bcbc375..b2560e08d5 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -2254,7 +2254,7 @@ gpuError_t SinglePrefillWithKVCacheDispatched(Params params, void *args[] = {(void *)¶ms}; dim3 nblks(ceil_div(qo_len * group_size, CTA_TILE_Q), 1, num_kv_heads); - dim3 nthrs(32, NUM_WARPS_Q, NUM_WARPS_KV); // FIXME + dim3 nthrs(WARP_SIZE, NUM_WARPS_Q, NUM_WARPS_KV); FI_GPU_CALL(gpuLaunchKernel((void *)kernel, nblks, nthrs, args, smem_size, stream)); @@ -2272,7 +2272,7 @@ gpuError_t SinglePrefillWithKVCacheDispatched(Params params, void *args[] = {(void *)¶ms}; dim3 nblks(ceil_div(qo_len * group_size, CTA_TILE_Q), num_chunks, num_kv_heads); - dim3 nthrs(32, NUM_WARPS_Q, NUM_WARPS_KV); + dim3 nthrs(WARP_SIZE, NUM_WARPS_Q, NUM_WARPS_KV); FI_GPU_CALL(gpuLaunchKernel((void *)kernel, nblks, nthrs, args, smem_size, stream)); @@ -3091,7 +3091,7 @@ BatchPrefillWithRaggedKVCacheDispatched(Params params, } dim3 nblks(padded_batch_size, 1, num_kv_heads); - dim3 nthrs(32, NUM_WARPS_Q, NUM_WARPS_KV); + dim3 nthrs(WARP_SIZE, NUM_WARPS_Q, NUM_WARPS_KV); constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM_QK / 16; constexpr uint32_t NUM_MMA_D_VO = HEAD_DIM_VO / 16; using DTypeQKAccum = @@ -3219,7 +3219,7 @@ BatchPrefillWithPagedKVCacheDispatched(Params params, } dim3 nblks(padded_batch_size, 1, num_kv_heads); - dim3 nthrs(32, NUM_WARPS_Q, NUM_WARPS_KV); + dim3 nthrs(WARP_SIZE, NUM_WARPS_Q, NUM_WARPS_KV); constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM_QK / 16; constexpr uint32_t NUM_MMA_D_VO = HEAD_DIM_VO / 16; From f51c30e228358fdedef7c74bf0cac1a72c22f441 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Tue, 12 Aug 2025 12:25:09 -0400 Subject: 
[PATCH 031/109] Update CUDA path in compute_qk. --- .../flashinfer/attention/generic/prefill.cuh | 39 ++++++++++++------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index b2560e08d5..6d6c8e664d 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -932,7 +932,7 @@ __device__ __forceinline__ void compute_qk( #if defined(PLATFORM_HIP_DEVICE) static_assert(false, "FP8 support not yet implemented for CDNA3"); -#endif +#else uint32_t b_frag_f8[2]; if (mma_d % 2 == 0) { k_smem->ldmatrix_m8n8x4_left_half(*k_smem_offset_r, @@ -947,9 +947,14 @@ __device__ __forceinline__ void compute_qk( vec_cast:: template cast<8>((typename KTraits::DTypeQ *)b_frag, (typename KTraits::DTypeKV *)b_frag_f8); +#endif } else { +#if defined(PLATFORM_HIP_DEVICE) k_smem->load_fragment_4x4_transposed(*k_smem_offset_r, b_frag); +#else + k_smem->load_fragment(*k_smem_offset_r, b_frag); +#endif } *k_smem_offset_r = k_smem->template advance_offset_by_row<16, UPCAST_STRIDE_K>( @@ -976,7 +981,7 @@ __device__ __forceinline__ void compute_qk( static_assert( false, "FP16 DTypeQKAccum not yet implemented for CDNA3"); -#endif +#else if (mma_d == 0) { mma::mma_sync_m16n16k16_row_col_f16f16f16< MMAMode::kInit>((uint32_t *)s_frag[mma_q][mma_kv], @@ -987,6 +992,7 @@ __device__ __forceinline__ void compute_qk( (uint32_t *)s_frag[mma_q][mma_kv], a_frag[mma_q], b_frag); } +#endif } } } @@ -1309,16 +1315,18 @@ __device__ __forceinline__ void compute_sfm_v( float (*d)[2]) { constexpr uint32_t UPCAST_STRIDE_V = KTraits::UPCAST_STRIDE_V; + constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; typename KTraits::DTypeQ s_frag_f16[KTraits::NUM_MMA_Q][KTraits::NUM_MMA_KV] - [8]; + [HALF_ELEMS_PER_THREAD]; if constexpr (std::is_same_v) { #pragma unroll for (uint32_t 
mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { - vec_cast::template cast<8>( - s_frag_f16[mma_q][mma_kv], s_frag[mma_q][mma_kv]); + vec_cast::template cast< + HALF_ELEMS_PER_THREAD>(s_frag_f16[mma_q][mma_kv], + s_frag[mma_q][mma_kv]); } } } @@ -1328,8 +1336,6 @@ __device__ __forceinline__ void compute_sfm_v( for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { -#warning "TODO: m16k16_rowsum_f16f16f32 ..........." -#if 0 if constexpr (std::is_same_v) { @@ -1337,10 +1343,15 @@ __device__ __forceinline__ void compute_sfm_v( s_frag_f16[mma_q][mma_kv]); } else { +#if defined(PLATFORM_HIP_DEVICE) + static_assert( + !std::is_same_v::value, + "FP16 reduction path not implemented for CDNA3"); +#else mma::m16k16_rowsum_f16f16f32(d[mma_q], s_frag[mma_q][mma_kv]); - } #endif + } } } } @@ -1351,8 +1362,10 @@ __device__ __forceinline__ void compute_sfm_v( for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { uint32_t b_frag[4]; if constexpr (sizeof(typename KTraits::DTypeKV) == 1) { -#warning "Not yet implemented......" 
-#if 0 +#if defined(PLATFORM_HIP_DEVICE) + static_assert(false, + "FP8 V path not implemented for CDNA3 yet"); +#else uint32_t b_frag_f8[2]; if (mma_d % 2 == 0) { v_smem->ldmatrix_m8n8x4_trans_left_half(*v_smem_offset_r, @@ -1366,9 +1379,9 @@ __device__ __forceinline__ void compute_sfm_v( frag_layout_swizzle_16b_to_8b_trans(b_frag_f8[0]); b_frag_f8[1] = frag_layout_swizzle_16b_to_8b_trans(b_frag_f8[1]); - vec_cast::template - cast<8>((typename KTraits::DTypeQ *)b_frag, - (typename KTraits::DTypeKV *)b_frag_f8); + vec_cast:: + template cast<8>((typename KTraits::DTypeQ *)b_frag, + (typename KTraits::DTypeKV *)b_frag_f8); swap(b_frag[1], b_frag[2]); #endif } From d3377ce050c8c94fb3da4e0aa4edae86a7b55df3 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Tue, 12 Aug 2025 13:08:47 -0400 Subject: [PATCH 032/109] WIP... --- .../include/flashinfer/attention/generic/prefill.cuh | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index 6d6c8e664d..44e2856f33 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -1242,8 +1242,11 @@ __device__ __forceinline__ void update_mdo_states( } } else if constexpr (std::is_same_v) { -#warning "Not implemented yet ...." -#if 0 +#if defined(PLATFORM_HIP_DEVICE) + static_assert( + false, + "Half precision accumulator not yet implemented for AMD"); +#else const half2 sm_scale = __float2half2_rn(variant.sm_scale_log2); #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { From 42b3965f02c4165ee3366bc5699c22f1c073b4f5 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Wed, 13 Aug 2025 11:09:48 -0400 Subject: [PATCH 033/109] Implementation of update_mdo_states. 
--- .../flashinfer/attention/generic/prefill.cuh | 81 +++++++++++++++++-- 1 file changed, 73 insertions(+), 8 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index 44e2856f33..4021daba09 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -154,6 +154,11 @@ struct KernelTraits // Presently we use 16x4 thread layout for all cases. static constexpr uint32_t KV_THR_LAYOUT_ROW = WARP_THREAD_ROWS; static constexpr uint32_t KV_THR_LAYOUT_COL = WARP_THREAD_COLS; + // The constant is defined based on the matrix layout of the "D/C" + // accumulator matrix in a D = A*B+C computation. On CDNA3 the D/C matrices + // are distributed as four 4x16 bands across the 64 threads. Each thread + // owns one element from four different rows. + static constexpr uint32_t NUM_MMA_ACCUM_CHUNKS_PER_THREAD = 4; #else using SmemBasePtrTy = uint4; static constexpr uint32_t NUM_THREADS = NUM_WARPS_Q * NUM_WARPS_KV * 32; @@ -172,6 +177,14 @@ struct KernelTraits : WARP_THREAD_COLS; static constexpr uint32_t KV_THR_LAYOUT_COL = SWIZZLE_MODE_KV == SwizzleMode::k128B ? 8 : 4; + + // The constant is defined based on the matrix layout of the "D/C" + // accumulator matrix in a D = A*B+C computation. On CUDA for + // m16n8k16 mma ops the D/C matrix is distributed as 4 8x8 block and each + // thread stores eight elements from two different rows. 
+ // Refer: + // https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-i8-f8 + static constexpr uint32_t NUM_MMA_ACCUM_CHUNKS_PER_THREAD = 2; #endif static constexpr uint32_t THREADS_PER_ROW_GROUP = WARP_THREAD_COLS / 2; static constexpr uint32_t UPCAST_STRIDE_Q = @@ -590,9 +603,12 @@ template __device__ __forceinline__ void init_states( typename KTraits::AttentionVariant variant, float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD], - typename KTraits::DTypeQKAccum (*m)[2], - float (*d)[2]) + typename KTraits::DTypeQKAccum ( + *m)[KTraits::NUM_MMA_ACCUM_CHUNKS_PER_THREAD], + float (*d)[KTraits::NUM_MMA_ACCUM_CHUNKS_PER_THREAD]) { + constexpr uint32_t NUM_MMA_ACCUM_CHUNKS_PER_THREAD = + KTraits::NUM_MMA_ACCUM_CHUNKS_PER_THREAD; #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll @@ -610,7 +626,7 @@ __device__ __forceinline__ void init_states( #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll - for (uint32_t j = 0; j < 2; ++j) { + for (uint32_t j = 0; j < NUM_MMA_ACCUM_CHUNKS_PER_THREAD; ++j) { m[mma_q][j] = typename KTraits::DTypeQKAccum(-gpu_iface::math::inf); d[mma_q][j] = 1.f; @@ -1172,11 +1188,14 @@ __device__ __forceinline__ void update_mdo_states( typename KTraits::DTypeQKAccum ( *s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD], float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD], - typename KTraits::DTypeQKAccum (*m)[2], - float (*d)[2]) + typename KTraits::DTypeQKAccum ( + *m)[KTraits::NUM_MMA_ACCUM_CHUNKS_PER_THREAD], + float (*d)[KTraits::NUM_MMA_ACCUM_CHUNKS_PER_THREAD]) { using DTypeQKAccum = typename KTraits::DTypeQKAccum; using AttentionVariant = typename KTraits::AttentionVariant; + constexpr uint32_t NUM_MMA_ACCUM_CHUNKS_PER_THREAD = + KTraits::NUM_MMA_ACCUM_CHUNKS_PER_THREAD; constexpr bool use_softmax = AttentionVariant::use_softmax; if constexpr (use_softmax) { @@ -1185,19 
+1204,62 @@ __device__ __forceinline__ void update_mdo_states( #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll - for (uint32_t j = 0; j < 2; ++j) { + for (uint32_t j = 0; j < NUM_MMA_ACCUM_CHUNKS_PER_THREAD; ++j) { float m_prev = m[mma_q][j]; #pragma unroll for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { +#if defined(PLATFORM_HIP_DEVICE) + m[mma_q][j] = + max(m[mma_q][j], s_frag[mma_q][mma_kv][j]); +#else float m_local = max(max(s_frag[mma_q][mma_kv][j * 2 + 0], s_frag[mma_q][mma_kv][j * 2 + 1]), max(s_frag[mma_q][mma_kv][j * 2 + 4], s_frag[mma_q][mma_kv][j * 2 + 5])); m[mma_q][j] = max(m[mma_q][j], m_local); +#endif + } +#if defined(PLATFORM_HIP_DEVICE) + // Butterfly reduction across all threads in the band (16 + // threads) for CDNA3's 64-thread wavefront + m[mma_q][j] = + max(m[mma_q][j], gpu_iface::math::shfl_xor_sync( + m[mma_q][j], 0x8)); // 16 apart + m[mma_q][j] = + max(m[mma_q][j], gpu_iface::math::shfl_xor_sync( + m[mma_q][j], 0x4)); // 8 apart + m[mma_q][j] = + max(m[mma_q][j], gpu_iface::math::shfl_xor_sync( + m[mma_q][j], 0x2)); // 4 apart + m[mma_q][j] = + max(m[mma_q][j], gpu_iface::math::shfl_xor_sync( + m[mma_q][j], 0x1)); // 2 apart + + float o_scale = gpu_iface::math::ptx_exp2( + m_prev * sm_scale - m[mma_q][j] * sm_scale); + d[mma_q][j] *= o_scale; + + // Scale output fragments for this specific row +#pragma unroll + for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; + ++mma_d) + { + o_frag[mma_q][mma_d][j] *= o_scale; // Direct indexing } + + // Convert logits to probabilities for this row +#pragma unroll + for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; + ++mma_kv) + { + s_frag[mma_q][mma_kv][j] = gpu_iface::math::ptx_exp2( + s_frag[mma_q][mma_kv][j] * sm_scale - + m[mma_q][j] * sm_scale); + } +#else m[mma_q][j] = max(m[mma_q][j], gpu_iface::math::shfl_xor_sync(m[mma_q][j], 0x2)); @@ -1238,6 +1300,7 @@ __device__ __forceinline__ void update_mdo_states( 
s_frag[mma_q][mma_kv][j * 2 + 5] * sm_scale - m[mma_q][j] * sm_scale); } +#endif } } } @@ -1860,6 +1923,8 @@ SinglePrefillWithKVCacheDevice(const Params params, [[maybe_unused]] constexpr MaskMode MASK_MODE = KTraits::MASK_MODE; [[maybe_unused]] constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; + [[maybe_unused]] constexpr uint32_t NUM_MMA_ACCUM_CHUNKS_PER_THREAD = + KTraits::NUM_MMA_ACCUM_CHUNKS_PER_THREAD; DTypeQ *q = params.q; DTypeKV *k = params.k; @@ -1899,8 +1964,8 @@ SinglePrefillWithKVCacheDevice(const Params params, DTypeQKAccum s_frag[NUM_MMA_Q][NUM_MMA_KV][HALF_ELEMS_PER_THREAD]; alignas( 16) float o_frag[NUM_MMA_Q][NUM_MMA_D_VO][HALF_ELEMS_PER_THREAD]; - DTypeQKAccum m[NUM_MMA_Q][2]; - float d[NUM_MMA_Q][2]; + DTypeQKAccum m[NUM_MMA_Q][NUM_MMA_ACCUM_CHUNKS_PER_THREAD]; + float d[NUM_MMA_Q][NUM_MMA_ACCUM_CHUNKS_PER_THREAD]; float rope_freq[NUM_MMA_D_QK / 2][4]; if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { From 16a9d151e63a4e7d7a3217d047a23d4cea66ae10 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Wed, 13 Aug 2025 13:52:42 -0400 Subject: [PATCH 034/109] WIP compute_sfm_v --- .../include/flashinfer/attention/generic/prefill.cuh | 10 ++++++++-- libflashinfer/include/gpu_iface/backend/hip/mma_hip.h | 10 ++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index 4021daba09..3dc53fa31f 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -967,7 +967,11 @@ __device__ __forceinline__ void compute_qk( } else { #if defined(PLATFORM_HIP_DEVICE) - k_smem->load_fragment_4x4_transposed(*k_smem_offset_r, b_frag); + // TODO: We need to validate the layout of K. Whether a + // transposed load is needed or whether K is pre-transposed. 
+ // k_smem->load_fragment_4x4_transposed(*k_smem_offset_r, + // b_frag); + k_smem->load_fragment(*k_smem_offset_r, b_frag); #else k_smem->load_fragment(*k_smem_offset_r, b_frag); #endif @@ -1378,10 +1382,12 @@ __device__ __forceinline__ void compute_sfm_v( typename KTraits::DTypeQKAccum ( *s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD], float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD], - float (*d)[2]) + float (*d)[KTraits::NUM_MMA_ACCUM_CHUNKS_PER_THREAD]) { constexpr uint32_t UPCAST_STRIDE_V = KTraits::UPCAST_STRIDE_V; constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; + constexpr uint32_t NUM_MMA_ACCUM_CHUNKS_PER_THREAD = + KTraits::NUM_MMA_ACCUM_CHUNKS_PER_THREAD; typename KTraits::DTypeQ s_frag_f16[KTraits::NUM_MMA_Q][KTraits::NUM_MMA_KV] [HALF_ELEMS_PER_THREAD]; diff --git a/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h b/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h index f00ba36a93..2238406b13 100644 --- a/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h +++ b/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h @@ -168,6 +168,16 @@ load_fragment_4x4_half_registers(const T *smem_ptr, uint32_t *R) transpose_4x4_half_registers(R); } +// TODO: Verify correct matrix multiplication order for rowsum on CDNA3 +// Current assumption: s_frag × ones_vector = row_sums +// Need to validate: +// 1. How compute_qk stores Q×K^T result in s_frag for CDNA3 +// 2. Whether K is pre-transposed or transposed during fragment loading +// 3. If we need s_frag × M1 or M1 × s_frag for correct row sums +// +// Test with known input matrices to verify: +// - s_frag layout matches expected Q×K^T result +// - rowsum produces correct per-row sums template __device__ __forceinline__ void m16k16_rowsum_f16f16f32(float *d, DType *s_frag) { From 96ead8413b0a6664e1c26d4831d2658fe8fa7141 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Wed, 13 Aug 2025 16:05:09 -0400 Subject: [PATCH 035/109] WIP broken... 
--- .../flashinfer/attention/generic/prefill.cuh | 95 ++++++++++--------- 1 file changed, 49 insertions(+), 46 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index 3dc53fa31f..8ddb06cc93 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -9,7 +9,7 @@ #include "gpu_iface/fastdiv.cuh" #include "gpu_iface/math_ops.hpp" #include "gpu_iface/memory_ops.hpp" -// #include "gpu_iface/mma_ops.hpp" +#include "gpu_iface/mma_ops.hpp" #include "gpu_iface/platform.hpp" #include "gpu_iface/utils.cuh" @@ -33,10 +33,10 @@ DEFINE_HAS_MEMBER(maybe_k_rope_offset) namespace cg = flashinfer::gpu_iface::cg; namespace memory = flashinfer::gpu_iface::memory; -// namespace mma = gpu_iface::mma; +namespace mma = gpu_iface::mma; using gpu_iface::vec_dtypes::vec_cast; -// using mma::MMAMode; +using mma::MMAMode; constexpr uint32_t WARP_SIZE = gpu_iface::kWarpSize; @@ -158,7 +158,7 @@ struct KernelTraits // accumulator matrix in a D = A*B+C computation. On CDNA3 the D/C matrices // are distributed as four 4x16 bands across the 64 threads. Each thread // owns one element from four different rows. - static constexpr uint32_t NUM_MMA_ACCUM_CHUNKS_PER_THREAD = 4; + static constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD = 4; #else using SmemBasePtrTy = uint4; static constexpr uint32_t NUM_THREADS = NUM_WARPS_Q * NUM_WARPS_KV * 32; @@ -184,7 +184,7 @@ struct KernelTraits // thread stores eight elements from two different rows. 
// Refer: // https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-i8-f8 - static constexpr uint32_t NUM_MMA_ACCUM_CHUNKS_PER_THREAD = 2; + static constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD = 2; #endif static constexpr uint32_t THREADS_PER_ROW_GROUP = WARP_THREAD_COLS / 2; static constexpr uint32_t UPCAST_STRIDE_Q = @@ -603,12 +603,11 @@ template __device__ __forceinline__ void init_states( typename KTraits::AttentionVariant variant, float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD], - typename KTraits::DTypeQKAccum ( - *m)[KTraits::NUM_MMA_ACCUM_CHUNKS_PER_THREAD], - float (*d)[KTraits::NUM_MMA_ACCUM_CHUNKS_PER_THREAD]) + typename KTraits::DTypeQKAccum (*m)[KTraits::NUM_ACCUM_ROWS_PER_THREAD], + float (*d)[KTraits::NUM_ACCUM_ROWS_PER_THREAD]) { - constexpr uint32_t NUM_MMA_ACCUM_CHUNKS_PER_THREAD = - KTraits::NUM_MMA_ACCUM_CHUNKS_PER_THREAD; + constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD = + KTraits::NUM_ACCUM_ROWS_PER_THREAD; #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll @@ -626,7 +625,7 @@ __device__ __forceinline__ void init_states( #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll - for (uint32_t j = 0; j < NUM_MMA_ACCUM_CHUNKS_PER_THREAD; ++j) { + for (uint32_t j = 0; j < NUM_ACCUM_ROWS_PER_THREAD; ++j) { m[mma_q][j] = typename KTraits::DTypeQKAccum(-gpu_iface::math::inf); d[mma_q][j] = 1.f; @@ -1192,14 +1191,13 @@ __device__ __forceinline__ void update_mdo_states( typename KTraits::DTypeQKAccum ( *s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD], float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD], - typename KTraits::DTypeQKAccum ( - *m)[KTraits::NUM_MMA_ACCUM_CHUNKS_PER_THREAD], - float (*d)[KTraits::NUM_MMA_ACCUM_CHUNKS_PER_THREAD]) + typename KTraits::DTypeQKAccum (*m)[KTraits::NUM_ACCUM_ROWS_PER_THREAD], + float (*d)[KTraits::NUM_ACCUM_ROWS_PER_THREAD]) { using 
DTypeQKAccum = typename KTraits::DTypeQKAccum; using AttentionVariant = typename KTraits::AttentionVariant; - constexpr uint32_t NUM_MMA_ACCUM_CHUNKS_PER_THREAD = - KTraits::NUM_MMA_ACCUM_CHUNKS_PER_THREAD; + constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD = + KTraits::NUM_ACCUM_ROWS_PER_THREAD; constexpr bool use_softmax = AttentionVariant::use_softmax; if constexpr (use_softmax) { @@ -1208,7 +1206,7 @@ __device__ __forceinline__ void update_mdo_states( #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll - for (uint32_t j = 0; j < NUM_MMA_ACCUM_CHUNKS_PER_THREAD; ++j) { + for (uint32_t j = 0; j < NUM_ACCUM_ROWS_PER_THREAD; ++j) { float m_prev = m[mma_q][j]; #pragma unroll for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; @@ -1382,12 +1380,15 @@ __device__ __forceinline__ void compute_sfm_v( typename KTraits::DTypeQKAccum ( *s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD], float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD], - float (*d)[KTraits::NUM_MMA_ACCUM_CHUNKS_PER_THREAD]) + float (*d)[KTraits::NUM_ACCUM_ROWS_PER_THREAD]) { constexpr uint32_t UPCAST_STRIDE_V = KTraits::UPCAST_STRIDE_V; constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; - constexpr uint32_t NUM_MMA_ACCUM_CHUNKS_PER_THREAD = - KTraits::NUM_MMA_ACCUM_CHUNKS_PER_THREAD; + constexpr uint32_t INT32_ELEMS_PER_THREAD = KTraits::INT32_ELEMS_PER_THREAD; + constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD = + KTraits::NUM_ACCUM_ROWS_PER_THREAD; + constexpr uint32_t V_SMEM_COLUMN_ADVANCE = + 16 / KTraits::HALF_ELEMS_PER_THREAD; typename KTraits::DTypeQ s_frag_f16[KTraits::NUM_MMA_Q][KTraits::NUM_MMA_KV] [HALF_ELEMS_PER_THREAD]; @@ -1417,7 +1418,7 @@ __device__ __forceinline__ void compute_sfm_v( else { #if defined(PLATFORM_HIP_DEVICE) static_assert( - !std::is_same_v::value, + !std::is_same_v, "FP16 reduction path not implemented for CDNA3"); #else mma::m16k16_rowsum_f16f16f32(d[mma_q], @@ -1432,7 +1433,7 @@ 
__device__ __forceinline__ void compute_sfm_v( for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { #pragma unroll for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { - uint32_t b_frag[4]; + uint32_t b_frag[INT32_ELEMS_PER_THREAD]; if constexpr (sizeof(typename KTraits::DTypeKV) == 1) { #if defined(PLATFORM_HIP_DEVICE) static_assert(false, @@ -1458,13 +1459,14 @@ __device__ __forceinline__ void compute_sfm_v( #endif } else { -#warning "TODO ldmatrix_m8n8x4_trans ............" +#if defined(PLATFORM_HIP_DEVICE) + v_smem->load_fragment_4x4_transposed(*v_smem_offset_r, b_frag); +#else v_smem->ldmatrix_m8n8x4_trans(*v_smem_offset_r, b_frag); +#endif } #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { -#warning "TODO mma_sync_m16n16k16_row_col_f16f16f32 ............" -#if 0 if constexpr (std::is_same_v) { @@ -1479,18 +1481,17 @@ __device__ __forceinline__ void compute_sfm_v( o_frag[mma_q][mma_d], (uint32_t *)s_frag[mma_q][mma_kv], b_frag); } -#endif } if constexpr (sizeof(typename KTraits::DTypeKV) == 1) { if (mma_d % 2 == 1) { *v_smem_offset_r = - v_smem->template advance_offset_by_column<2>( - *v_smem_offset_r, mma_d / 2); + v_smem->template advance_offset_by_column< + V_SMEM_COLUMN_ADVANCE>(*v_smem_offset_r, mma_d / 2); } } else { - *v_smem_offset_r = v_smem->template advance_offset_by_column<2>( - *v_smem_offset_r, mma_d); + *v_smem_offset_r = v_smem->template advance_offset_by_column< + V_SMEM_COLUMN_ADVANCE>(*v_smem_offset_r, mma_d); } } *v_smem_offset_r = @@ -1504,17 +1505,17 @@ __device__ __forceinline__ void compute_sfm_v( template __device__ __forceinline__ void normalize_d( float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD], - typename KTraits::DTypeQKAccum (*m)[2], - float (*d)[2]) + typename KTraits::DTypeQKAccum (*m)[KTraits::NUM_ACCUM_ROWS_PER_THREAD], + float (*d)[KTraits::NUM_ACCUM_ROWS_PER_THREAD]) { using AttentionVariant = typename KTraits::AttentionVariant; if constexpr 
(AttentionVariant::use_softmax) { - float d_rcp[KTraits::NUM_MMA_Q][2]; + float d_rcp[KTraits::NUM_MMA_Q][KTraits::NUM_ACCUM_ROWS_PER_THREAD]; // compute reciprocal of d #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll - for (uint32_t j = 0; j < 2; ++j) { + for (uint32_t j = 0; j < KTraits::NUM_ACCUM_ROWS_PER_THREAD; ++j) { d_rcp[mma_q][j] = (m[mma_q][j] != typename KTraits::DTypeQKAccum(-gpu_iface::math::inf)) @@ -1541,15 +1542,15 @@ __device__ __forceinline__ void normalize_d( } template -__device__ __forceinline__ void -finalize_m(typename KTraits::AttentionVariant variant, - typename KTraits::DTypeQKAccum (*m)[2]) +__device__ __forceinline__ void finalize_m( + typename KTraits::AttentionVariant variant, + typename KTraits::DTypeQKAccum (*m)[KTraits::NUM_ACCUM_ROWS_PER_THREAD]) { if constexpr (variant.use_softmax) { #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll - for (uint32_t j = 0; j < 2; ++j) { + for (uint32_t j = 0; j < KTraits::NUM_ACCUM_ROWS_PER_THREAD; ++j) { if (m[mma_q][j] != typename KTraits::DTypeQKAccum(-gpu_iface::math::inf)) { @@ -1568,13 +1569,15 @@ template __device__ __forceinline__ void threadblock_sync_mdo_states( float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD], typename KTraits::SharedStorage *smem_storage, - typename KTraits::DTypeQKAccum (*m)[2], - float (*d)[2], + typename KTraits::DTypeQKAccum (*m)[KTraits::NUM_ACCUM_ROWS_PER_THREAD], + float (*d)[KTraits::NUM_ACCUM_ROWS_PER_THREAD], const uint32_t warp_idx, const uint32_t lane_idx, const dim3 tid = threadIdx) { constexpr uint32_t TPR = KTraits::THREADS_PER_ROW_GROUP; + constexpr uint32_t NARPT = KTraits::NUM_ACCUM_ROWS_PER_THREAD; + static_assert(WARP_SIZE % TPR == 0, "THREADS_PER_ROW_GROUP must divide WARP_SIZE"); constexpr uint32_t GROUPS_PER_WARP = WARP_SIZE / TPR; @@ -1605,7 +1608,7 @@ __device__ __forceinline__ void threadblock_sync_mdo_states( #pragma unroll for 
(uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll - for (uint32_t j = 0; j < 2; ++j) { + for (uint32_t j = 0; j < NARPT; ++j) { smem_md[((warp_idx * KTraits::NUM_MMA_Q + mma_q) * 2 + j) * GROUPS_PER_WARP + lane_group_idx] = @@ -1619,7 +1622,7 @@ __device__ __forceinline__ void threadblock_sync_mdo_states( for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { float o_scale[2][KTraits::NUM_WARPS_KV]; #pragma unroll - for (uint32_t j = 0; j < 2; ++j) { + for (uint32_t j = 0; j < NARPT; ++j) { float m_new = -gpu_iface::math::inf, d_new = 1.f; #pragma unroll for (uint32_t i = 0; i < KTraits::NUM_WARPS_KV; ++i) { @@ -1929,8 +1932,8 @@ SinglePrefillWithKVCacheDevice(const Params params, [[maybe_unused]] constexpr MaskMode MASK_MODE = KTraits::MASK_MODE; [[maybe_unused]] constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; - [[maybe_unused]] constexpr uint32_t NUM_MMA_ACCUM_CHUNKS_PER_THREAD = - KTraits::NUM_MMA_ACCUM_CHUNKS_PER_THREAD; + [[maybe_unused]] constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD = + KTraits::NUM_ACCUM_ROWS_PER_THREAD; DTypeQ *q = params.q; DTypeKV *k = params.k; @@ -1970,8 +1973,8 @@ SinglePrefillWithKVCacheDevice(const Params params, DTypeQKAccum s_frag[NUM_MMA_Q][NUM_MMA_KV][HALF_ELEMS_PER_THREAD]; alignas( 16) float o_frag[NUM_MMA_Q][NUM_MMA_D_VO][HALF_ELEMS_PER_THREAD]; - DTypeQKAccum m[NUM_MMA_Q][NUM_MMA_ACCUM_CHUNKS_PER_THREAD]; - float d[NUM_MMA_Q][NUM_MMA_ACCUM_CHUNKS_PER_THREAD]; + DTypeQKAccum m[NUM_MMA_Q][NUM_ACCUM_ROWS_PER_THREAD]; + float d[NUM_MMA_Q][NUM_ACCUM_ROWS_PER_THREAD]; float rope_freq[NUM_MMA_D_QK / 2][4]; if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { From 08f5ed104309888678af202fe8f830b6756c4e72 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Thu, 14 Aug 2025 00:46:55 -0400 Subject: [PATCH 036/109] Ported threadblock_sync_mdo_states --- .../flashinfer/attention/generic/prefill.cuh | 57 +++++++++++++------ 1 file changed, 39 insertions(+), 18 
deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index 8ddb06cc93..12a4f734e3 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -159,6 +159,12 @@ struct KernelTraits // are distributed as four 4x16 bands across the 64 threads. Each thread // owns one element from four different rows. static constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD = 4; + // Number of threads that collaboratively handle the same set of matrix rows + // in attention score computation and cross-warp synchronization. + // CUDA: 4 threads (each thread handles 2 elements from same row group) + // CDNA3: 16 threads (each thread handles 1 element from same row group) + static constexpr uint32_t THREADS_PER_MATRIX_ROW_SET = 16; + #else using SmemBasePtrTy = uint4; static constexpr uint32_t NUM_THREADS = NUM_WARPS_Q * NUM_WARPS_KV * 32; @@ -185,8 +191,8 @@ struct KernelTraits // Refer: // https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-i8-f8 static constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD = 2; + static constexpr uint32_t THREADS_PER_MATRIX_ROW_SET = 16; #endif - static constexpr uint32_t THREADS_PER_ROW_GROUP = WARP_THREAD_COLS / 2; static constexpr uint32_t UPCAST_STRIDE_Q = HEAD_DIM_QK / upcast_size(); static constexpr uint32_t UPCAST_STRIDE_K = @@ -579,7 +585,7 @@ init_rope_freq(float (*rope_freq)[4], constexpr uint32_t HEAD_DIM = KTraits::NUM_MMA_D_QK * 16; const uint32_t lane_idx = tid_x; - constexpr uint32_t THREADS_PER_ROW = KTraits::THREADS_PER_ROW_GROUP; + constexpr uint32_t THREADS_PER_ROW = KTraits::THREADS_PER_MATRIX_ROW_SET; constexpr uint32_t ELEMS_PER_THREAD = 8 / THREADS_PER_ROW; #pragma unroll @@ -739,7 +745,8 @@ __device__ __forceinline__ void q_smem_inplace_apply_rotary( (typename KTraits::DTypeQ *)q_frag_local[1], rope_freq[mma_di], 
q_packed_idx + kv_len * group_size - qo_len * group_size + - mma_q * 16 + lane_idx / KTraits::THREADS_PER_ROW_GROUP, + mma_q * 16 + + lane_idx / KTraits::THREADS_PER_MATRIX_ROW_SET, group_size); q_smem->template store_fragment(q_smem_offset_r_last_half, q_frag_local[1]); @@ -792,7 +799,7 @@ __device__ __forceinline__ void q_smem_inplace_apply_rotary_with_pos( (typename KTraits::DTypeQ *)q_frag_local[1], rope_freq[mma_di], q_packed_idx_base + mma_q * 16 + - lane_idx / KTraits::THREADS_PER_ROW_GROUP, + lane_idx / KTraits::THREADS_PER_MATRIX_ROW_SET, group_size, q_rope_offset); q_smem->store_fragment(q_smem_offset_r_last_half, q_frag_local[1]); @@ -819,7 +826,8 @@ __device__ __forceinline__ void k_smem_inplace_apply_rotary( using DTypeKV = typename KTraits::DTypeKV; static_assert(sizeof(DTypeKV) == 2); constexpr uint32_t UPCAST_STRIDE_K = KTraits::UPCAST_STRIDE_K; - constexpr uint32_t THREADS_PER_ROW_GROUP = KTraits::THREADS_PER_ROW_GROUP; + constexpr uint32_t THREADS_PER_MATRIX_ROW_SET = + KTraits::THREADS_PER_MATRIX_ROW_SET; constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; uint32_t k_frag_local[2][KTraits::INT32_ELEMS_PER_THREAD]; const uint32_t lane_idx = tid.x; @@ -835,7 +843,7 @@ __device__ __forceinline__ void k_smem_inplace_apply_rotary( KTraits::NUM_MMA_KV % 2 == 0, "when NUM_MMA_D_QK == 4, NUM_MMA_KV must be a multiple of 2"); uint32_t kv_idx = kv_idx_base + (warp_idx / 2) * 16 + - lane_idx / THREADS_PER_ROW_GROUP; + lane_idx / THREADS_PER_MATRIX_ROW_SET; *k_smem_offset_r = (*k_smem_offset_r ^ (0x2 * (warp_idx % 2))) + (warp_idx / 2) * 16 * UPCAST_STRIDE_K; #pragma unroll @@ -872,7 +880,7 @@ __device__ __forceinline__ void k_smem_inplace_apply_rotary( // ... 
uint32_t kv_idx = kv_idx_base + (warp_idx_z * KTraits::NUM_MMA_KV * 16) + - lane_idx / THREADS_PER_ROW_GROUP; + lane_idx / THREADS_PER_MATRIX_ROW_SET; *k_smem_offset_r = *k_smem_offset_r ^ (0x2 * warp_idx_x); #pragma unroll for (uint32_t i = 0; i < KTraits::NUM_MMA_KV; ++i) { @@ -1052,7 +1060,7 @@ __device__ __forceinline__ void logits_transform( const uint32_t lane_idx = tid.x; uint32_t q[KTraits::NUM_MMA_Q][2], r[KTraits::NUM_MMA_Q][2]; float logits = 0., logitsTransformed = 0.; - constexpr uint32_t TPR = KTraits::THREADS_PER_ROW_GROUP; + constexpr uint32_t TPR = KTraits::THREADS_PER_MATRIX_ROW_SET; #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { @@ -1139,7 +1147,7 @@ logits_mask(const Params ¶ms, constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; using DTypeQKAccum = typename KTraits::DTypeQKAccum; constexpr MaskMode MASK_MODE = KTraits::MASK_MODE; - constexpr uint32_t TPR = KTraits::THREADS_PER_ROW_GROUP; + constexpr uint32_t TPR = KTraits::THREADS_PER_MATRIX_ROW_SET; uint32_t q[NUM_MMA_Q][2], r[NUM_MMA_Q][2]; #pragma unroll for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { @@ -1575,11 +1583,11 @@ __device__ __forceinline__ void threadblock_sync_mdo_states( const uint32_t lane_idx, const dim3 tid = threadIdx) { - constexpr uint32_t TPR = KTraits::THREADS_PER_ROW_GROUP; + constexpr uint32_t TPR = KTraits::THREADS_PER_MATRIX_ROW_SET; constexpr uint32_t NARPT = KTraits::NUM_ACCUM_ROWS_PER_THREAD; static_assert(WARP_SIZE % TPR == 0, - "THREADS_PER_ROW_GROUP must divide WARP_SIZE"); + "THREADS_PER_MATRIX_ROW_SET must divide WARP_SIZE"); constexpr uint32_t GROUPS_PER_WARP = WARP_SIZE / TPR; const uint32_t lane_group_idx = lane_idx / TPR; @@ -1587,8 +1595,12 @@ __device__ __forceinline__ void threadblock_sync_mdo_states( if constexpr (KTraits::NUM_WARPS_KV > 1) { float *smem_o = smem_storage->cta_sync_o_smem; float2 *smem_md = smem_storage->cta_sync_md_smem; - // o: [num_warps, NUM_MMA_Q, NUM_MMA_D_VO, WARP_SIZE, - // 
HALF_ELEMS_PER_THREAD] md: [num_warps, NUM_MMA_Q, 16, 2 (m/d)] + // o: [num_warps, + // NUM_MMA_Q, + // NUM_MMA_D_VO, + // WARP_SIZE, + // HALF_ELEMS_PER_THREAD] + // md: [num_warps, NUM_MMA_Q, 16, 2 (m/d)] #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll @@ -1609,7 +1621,8 @@ __device__ __forceinline__ void threadblock_sync_mdo_states( for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t j = 0; j < NARPT; ++j) { - smem_md[((warp_idx * KTraits::NUM_MMA_Q + mma_q) * 2 + j) * + smem_md[((warp_idx * KTraits::NUM_MMA_Q + mma_q) * NARPT + + j) * GROUPS_PER_WARP + lane_group_idx] = make_float2(float(m[mma_q][j]), d[mma_q][j]); @@ -1620,7 +1633,7 @@ __device__ __forceinline__ void threadblock_sync_mdo_states( __syncthreads(); #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { - float o_scale[2][KTraits::NUM_WARPS_KV]; + float o_scale[NARPT][KTraits::NUM_WARPS_KV]; #pragma unroll for (uint32_t j = 0; j < NARPT; ++j) { float m_new = -gpu_iface::math::inf, d_new = 1.f; @@ -1630,7 +1643,7 @@ __device__ __forceinline__ void threadblock_sync_mdo_states( get_warp_idx_q(tid.y)) * KTraits::NUM_MMA_Q + mma_q) * - 2 + + NARPT + j) * GROUPS_PER_WARP + lane_group_idx]; @@ -1647,7 +1660,7 @@ __device__ __forceinline__ void threadblock_sync_mdo_states( get_warp_idx_q(tid.y)) * KTraits::NUM_MMA_Q + mma_q) * - 2 + + NARPT + j) * GROUPS_PER_WARP + lane_group_idx]; @@ -1681,8 +1694,16 @@ __device__ __forceinline__ void threadblock_sync_mdo_states( for (uint32_t reg_id = 0; reg_id < KTraits::HALF_ELEMS_PER_THREAD; ++reg_id) { +#if defined(PLATFORM_HIP_DEVICE) + // CDNA3: Direct mapping - each reg_id corresponds + // to one accumulator row + o_new[reg_id] += oi[reg_id] * o_scale[reg_id][i]; +#else + // CUDA: Grouped mapping - 2 elements per + // accumulator row o_new[reg_id] += oi[reg_id] * o_scale[(reg_id % 4) / 2][i]; +#endif } } o_new.store(o_frag[mma_q][mma_d]); @@ -1739,7 
+1760,7 @@ __device__ __forceinline__ void write_o_reg_gmem( { using DTypeO = typename KTraits::DTypeO; constexpr uint32_t UPCAST_STRIDE_O = KTraits::UPCAST_STRIDE_O; - constexpr uint32_t TPR = KTraits::THREADS_PER_ROW_GROUP; + constexpr uint32_t TPR = KTraits::THREADS_PER_MATRIX_ROW_SET; constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; const uint32_t warp_idx_x = get_warp_idx_q(tid.y); const uint32_t lane_idx = tid.x; From 7893ddd9cf01b5e8077acf0171035a8f36cdae7e Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Thu, 14 Aug 2025 14:12:32 -0400 Subject: [PATCH 037/109] write_o_reg_smem ported --- .../flashinfer/attention/generic/prefill.cuh | 64 +- .../gpu_iface/backend/cuda/vec_dtypes.cuh | 1781 +++++++++++++++++ 2 files changed, 1823 insertions(+), 22 deletions(-) create mode 100644 libflashinfer/include/gpu_iface/backend/cuda/vec_dtypes.cuh diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index 12a4f734e3..802440b2eb 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -164,7 +164,6 @@ struct KernelTraits // CUDA: 4 threads (each thread handles 2 elements from same row group) // CDNA3: 16 threads (each thread handles 1 element from same row group) static constexpr uint32_t THREADS_PER_MATRIX_ROW_SET = 16; - #else using SmemBasePtrTy = uint4; static constexpr uint32_t NUM_THREADS = NUM_WARPS_Q * NUM_WARPS_KV * 32; @@ -191,7 +190,7 @@ struct KernelTraits // Refer: // https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-i8-f8 static constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD = 2; - static constexpr uint32_t THREADS_PER_MATRIX_ROW_SET = 16; + static constexpr uint32_t THREADS_PER_MATRIX_ROW_SET = 4; #endif static constexpr uint32_t UPCAST_STRIDE_Q = HEAD_DIM_QK / upcast_size(); @@ -1761,7 +1760,10 @@ __device__ 
__forceinline__ void write_o_reg_gmem( using DTypeO = typename KTraits::DTypeO; constexpr uint32_t UPCAST_STRIDE_O = KTraits::UPCAST_STRIDE_O; constexpr uint32_t TPR = KTraits::THREADS_PER_MATRIX_ROW_SET; + constexpr uint32_t NAPTR = KTraits::NUM_ACCUM_ROWS_PER_THREAD; constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; + constexpr uint32_t WARP_THREAD_COLS = KTraits::WARP_THREAD_COLS; + const uint32_t warp_idx_x = get_warp_idx_q(tid.y); const uint32_t lane_idx = tid.x; @@ -1769,7 +1771,7 @@ __device__ __forceinline__ void write_o_reg_gmem( #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll - for (uint32_t j = 0; j < 2; ++j) { + for (uint32_t j = 0; j < NAPTR; ++j) { uint32_t q, r; group_size.divmod(o_packed_idx_base + lane_idx / TPR + mma_q * 16 + j * 8, @@ -1779,16 +1781,22 @@ for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { if (o_idx < qo_upper_bound) { - *reinterpret_cast( - o_ptr_base + q * o_stride_n + r * o_stride_h + - mma_d * 16 + (lane_idx % TPR) * 2) = + auto base_addr = o_ptr_base + q * o_stride_n + + r * o_stride_h + mma_d * 16; + auto col_offset = lane_idx % 16; +#if defined(PLATFORM_HIP_DEVICE) + *(base_addr + col_offset) = o_frag[mma_q][mma_d][j]; +#else + *reinterpret_cast(base_addr + + col_offset * 2) = *reinterpret_cast( &o_frag[mma_q][mma_d][j * 2]); - *reinterpret_cast( - o_ptr_base + q * o_stride_n + r * o_stride_h + - mma_d * 16 + 8 + (lane_idx % TPR) * 2) = + + *reinterpret_cast(base_addr + 8 + + col_offset * 2) = *reinterpret_cast( - &o_frag[mma_q][mma_d][4 + j * 2]); + &o_frag[mma_q][mma_d][4 + j * 2]); +#endif } } } @@ -1817,41 +1825,53 @@ __device__ __forceinline__ void write_o_reg_gmem( uint32_t o_smem_offset_w = o_smem->template get_permuted_offset( (warp_idx_x * KTraits::NUM_MMA_Q + mma_q) * 16 + - lane_idx / 4, + lane_idx / TPR, mma_d * 2); +#if defined(PLATFORM_HIP_DEVICE) + ((uint32_t *)(o_smem->base +
- o_smem_offset_w))[lane_idx % 4] = + o_smem_offset_w))[lane_idx % TPR] = + o_frag_f16[0]; + // Move 2 elements forward in the same row + uint32_t offset_2 = o_smem_offset_w + 2; + ((uint32_t *)(o_smem->base + offset_2))[lane_idx % 16] = + o_frag_f16[1]; +#else + ((uint32_t *)(o_smem->base + + o_smem_offset_w))[lane_idx % TPR] = o_frag_f16[0]; ((uint32_t *)(o_smem->base + o_smem_offset_w + 8 * UPCAST_STRIDE_O))[lane_idx % 4] = o_frag_f16[1]; ((uint32_t *)(o_smem->base + - (o_smem_offset_w ^ 0x1)))[lane_idx % 4] = + (o_smem_offset_w ^ 0x1)))[lane_idx % TPR] = o_frag_f16[2]; ((uint32_t *)(o_smem->base + (o_smem_offset_w ^ 0x1) + 8 * UPCAST_STRIDE_O))[lane_idx % 4] = o_frag_f16[3]; +#endif #endif } } uint32_t o_smem_offset_w = o_smem->template get_permuted_offset( - warp_idx_x * KTraits::NUM_MMA_Q * 16 + lane_idx / 8, - lane_idx % 8); + warp_idx_x * KTraits::NUM_MMA_Q * 16 + + lane_idx / WARP_THREAD_COLS, + lane_idx % WARP_THREAD_COLS); #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll for (uint32_t j = 0; j < 2 * 2; ++j) { uint32_t q, r; - group_size.divmod(o_packed_idx_base + lane_idx / 8 + + group_size.divmod(o_packed_idx_base + + lane_idx / WARP_THREAD_COLS + mma_q * 16 + j * 4, q, r); const uint32_t o_idx = q; - DTypeO *o_ptr = o_ptr_base + q * o_stride_n + - r * o_stride_h + - (lane_idx % 8) * upcast_size(); + DTypeO *o_ptr = + o_ptr_base + q * o_stride_n + r * o_stride_h + + (lane_idx % WARP_THREAD_COLS) * upcast_size(); #pragma unroll for (uint32_t mma_do = 0; mma_do < KTraits::NUM_MMA_D_VO / 4; ++mma_do) @@ -1859,10 +1879,10 @@ __device__ __forceinline__ void write_o_reg_gmem( if (o_idx < qo_upper_bound) { o_smem->store_vector(o_smem_offset_w, o_ptr); } - o_ptr += 8 * upcast_size(); + o_ptr += WARP_THREAD_COLS * upcast_size(); o_smem_offset_w = - o_smem->template advance_offset_by_column<8>( - o_smem_offset_w, mma_do); + o_smem->template advance_offset_by_column< + WARP_THREAD_COLS>(o_smem_offset_w, mma_do); } 
o_smem_offset_w = o_smem->template advance_offset_by_row< 4, UPCAST_STRIDE_O>(o_smem_offset_w) - diff --git a/libflashinfer/include/gpu_iface/backend/cuda/vec_dtypes.cuh b/libflashinfer/include/gpu_iface/backend/cuda/vec_dtypes.cuh new file mode 100644 index 0000000000..e6d6396723 --- /dev/null +++ b/libflashinfer/include/gpu_iface/backend/cuda/vec_dtypes.cuh @@ -0,0 +1,1781 @@ +/* + * Copyright (c) 2023 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef VEC_DTYPES_CUH_ +#define VEC_DTYPES_CUH_ + +#include +#include +#include +#include + +#include + +namespace flashinfer +{ + +#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 900)) +#define FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED +#endif + +#define FLASHINFER_INLINE inline __attribute__((always_inline)) __device__ + +#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 < 120200) && \ + (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)) +// CUDA version < 12.2 and GPU architecture < 80 +FLASHINFER_INLINE __nv_bfloat162 make_bfloat162(const __nv_bfloat16 x, + const __nv_bfloat16 y) +{ + __nv_bfloat162 t; + t.x = x; + t.y = y; + return t; +} + +FLASHINFER_INLINE __nv_bfloat16 __hmul(const __nv_bfloat16 a, + const __nv_bfloat16 b) +{ + __nv_bfloat16 val; + const float fa = __bfloat162float(a); + const float fb = __bfloat162float(b); + // avoid ftz in device code + val = __float2bfloat16(__fmaf_ieee_rn(fa, fb, -0.0f)); + return val; +} + +FLASHINFER_INLINE 
__nv_bfloat162 __hmul2(const __nv_bfloat162 a, + const __nv_bfloat162 b) +{ + __nv_bfloat162 val; + val.x = __hmul(a.x, b.x); + val.y = __hmul(a.y, b.y); + return val; +} + +FLASHINFER_INLINE __nv_bfloat162 __floats2bfloat162_rn(const float a, + const float b) +{ + __nv_bfloat162 val; + val = __nv_bfloat162(__float2bfloat16_rn(a), __float2bfloat16_rn(b)); + return val; +} + +FLASHINFER_INLINE __nv_bfloat162 __float22bfloat162_rn(const float2 a) +{ + __nv_bfloat162 val = __floats2bfloat162_rn(a.x, a.y); + return val; +} +FLASHINFER_INLINE float2 __bfloat1622float2(const __nv_bfloat162 a) +{ + float hi_float; + float lo_float; + lo_float = __internal_bfloat162float(((__nv_bfloat162_raw)a).x); + hi_float = __internal_bfloat162float(((__nv_bfloat162_raw)a).y); + return make_float2(lo_float, hi_float); +} +#endif + +/******************* vec_t type cast *******************/ + +template struct vec_cast +{ + template + FLASHINFER_INLINE static void cast(dst_t *dst, const src_t *src) + { +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + dst[i] = (dst_t)src[i]; + } + } +}; + +template <> struct vec_cast +{ + template + FLASHINFER_INLINE static void cast(float *dst, const half *src) + { + if constexpr (vec_size == 1) { + dst[0] = (float)src[0]; + } + else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((float2 *)dst)[i] = __half22float2(((half2 *)src)[i]); + } + } + } +}; + +template <> struct vec_cast +{ + template + FLASHINFER_INLINE static void cast(half *dst, const float *src) + { + if constexpr (vec_size == 1) { + dst[0] = __float2half(src[0]); + } + else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((half2 *)dst)[i] = __float22half2_rn(((float2 *)src)[i]); + } + } + } +}; + +template constexpr FLASHINFER_INLINE int get_exponent_bits() +{ + if constexpr (std::is_same_v) { + return 4; + } + else if constexpr (std::is_same_v) { + return 5; + } + else if constexpr (std::is_same_v) { + return 5; + } + else if constexpr 
(std::is_same_v) { + return 8; + } +} + +template constexpr FLASHINFER_INLINE int get_mantissa_bits() +{ + if constexpr (std::is_same_v) { + return 3; + } + else if constexpr (std::is_same_v) { + return 2; + } + else if constexpr (std::is_same_v) { + return 11; + } + else if constexpr (std::is_same_v) { + return 7; + } +} + +/*! + * \brief Fallback to software fast dequant implementation if hardware + * dequantization is not available. + * \note Inspired by Marlin's fast dequantization, but here we don't have to + * permute weights order. + * \ref + * https://github.com/vllm-project/vllm/blob/6dffa4b0a6120159ef2fe44d695a46817aff65bc/csrc/quantization/fp8/fp8_marlin.cu#L120 + */ +template +__device__ void fast_dequant_f8f16x4(uint32_t *input, uint2 *output) +{ + uint32_t q = *input; + if constexpr (std::is_same_v && + std::is_same_v) + { + output->x = __byte_perm(0U, q, 0x5140); + output->y = __byte_perm(0U, q, 0x7362); + } + else { + constexpr int FP8_EXPONENT = get_exponent_bits(); + constexpr int FP8_MANTISSA = get_mantissa_bits(); + constexpr int FP16_EXPONENT = get_exponent_bits(); + + constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT; + // Calculate MASK for extracting mantissa and exponent + constexpr int MASK1 = 0x80000000; + constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA); + constexpr int MASK3 = MASK2 & 0x7fffffff; + constexpr int MASK = MASK3 | (MASK3 >> 16); + q = __byte_perm(q, q, 0x1302); + + // Extract and shift FP8 values to FP16 format + uint32_t Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + uint32_t Out2 = + ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT); + + constexpr int BIAS_OFFSET = + (1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); + // Construct and apply exponent bias + if constexpr (std::is_same_v) { + const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET)); + + // Convert to half2 and apply bias + *(half2 *)&(output->x) = + __hmul2(*reinterpret_cast(&Out1), bias_reg); + *(half2 
*)&(output->y) = + __hmul2(*reinterpret_cast(&Out2), bias_reg); + } + else { + constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23; + const nv_bfloat162 bias_reg = + __float2bfloat162_rn(*reinterpret_cast(&BIAS)); + // Convert to bfloat162 and apply bias + *(nv_bfloat162 *)&(output->x) = __hmul2( + *reinterpret_cast(&Out1), bias_reg); + *(nv_bfloat162 *)&(output->y) = __hmul2( + *reinterpret_cast(&Out2), bias_reg); + } + } +} + +template <> struct vec_cast +{ + template + FLASHINFER_INLINE static void cast(nv_bfloat16 *dst, + const __nv_fp8_e4m3 *src) + { + if constexpr (vec_size == 1) { + dst[0] = nv_bfloat16(src[0]); + } + else if constexpr (vec_size == 2) { + dst[0] = nv_bfloat16(src[0]); + dst[1] = nv_bfloat16(src[1]); + } + else { + static_assert(vec_size % 4 == 0, + "vec_size must be a multiple of 4"); +#pragma unroll + for (uint32_t i = 0; i < vec_size / 4; ++i) { + fast_dequant_f8f16x4<__nv_fp8_e4m3, nv_bfloat16>( + (uint32_t *)&src[i * 4], (uint2 *)&dst[i * 4]); + } + } + } +}; + +template <> struct vec_cast +{ + template + FLASHINFER_INLINE static void cast(nv_bfloat16 *dst, + const __nv_fp8_e5m2 *src) + { + if constexpr (vec_size == 1) { + dst[0] = nv_bfloat16(src[0]); + } + else if constexpr (vec_size == 2) { + dst[0] = nv_bfloat16(src[0]); + dst[1] = nv_bfloat16(src[1]); + } + else { + static_assert(vec_size % 4 == 0, + "vec_size must be a multiple of 4"); +#pragma unroll + for (uint32_t i = 0; i < vec_size / 4; ++i) { + fast_dequant_f8f16x4<__nv_fp8_e5m2, nv_bfloat16>( + (uint32_t *)&src[i * 4], (uint2 *)&dst[i * 4]); + } + } + } +}; + +template <> struct vec_cast<__nv_fp8_e4m3, half> +{ + template + FLASHINFER_INLINE static void cast(__nv_fp8_e4m3 *dst, const half *src) + { +#ifdef FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + if constexpr (vec_size == 1) { + dst[0] = __nv_fp8_e4m3(src[0]); + } + else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + uint16_t y; + uint32_t x = *(uint32_t *)&src[i * 2]; + asm 
volatile("cvt.rn.satfinite.e4m3x2.f16x2 %0, %1;" + : "=h"(y) + : "r"(x)); + *(uint16_t *)&dst[i * 2] = y; + } + } +#else +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + dst[i] = __nv_fp8_e4m3(src[i]); + } +#endif // FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + } +}; + +template <> struct vec_cast<__nv_fp8_e5m2, half> +{ + template + FLASHINFER_INLINE static void cast(__nv_fp8_e5m2 *dst, const half *src) + { +#ifdef FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + if constexpr (vec_size == 1) { + dst[0] = __nv_fp8_e5m2(src[0]); + } + else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + uint16_t y; + uint32_t x = *(uint32_t *)&src[i * 2]; + asm volatile("cvt.rn.satfinite.e5m2x2.f16x2 %0, %1;" + : "=h"(y) + : "r"(x)); + *(uint16_t *)&dst[i * 2] = y; + } + } +#else +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + dst[i] = __nv_fp8_e5m2(src[i]); + } +#endif // FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + } +}; + +template <> struct vec_cast +{ + template + FLASHINFER_INLINE static void cast(half *dst, const __nv_fp8_e4m3 *src) + { +#ifdef FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + if constexpr (vec_size == 1) { + dst[0] = half(src[0]); + } + else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + uint32_t y; + uint16_t x = *(uint16_t *)&src[i * 2]; + asm volatile("cvt.rn.f16x2.e4m3x2 %0, %1;" : "=r"(y) : "h"(x)); + *(uint32_t *)&dst[i * 2] = y; + } + } +#else + if constexpr (vec_size == 1) { + dst[0] = half(src[0]); + } + else if constexpr (vec_size == 2) { + dst[0] = half(src[0]); + dst[1] = half(src[1]); + } + else { + static_assert(vec_size % 4 == 0, + "vec_size must be a multiple of 4"); +#pragma unroll + for (uint32_t i = 0; i < vec_size / 4; ++i) { + fast_dequant_f8f16x4<__nv_fp8_e4m3, half>( + (uint32_t *)&src[i * 4], (uint2 *)&dst[i * 4]); + } + } +#endif // FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + } +}; + +template <> struct vec_cast +{ + template + FLASHINFER_INLINE static void cast(half *dst, const 
__nv_fp8_e5m2 *src) + { +#ifdef FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + if constexpr (vec_size == 1) { + dst[0] = half(src[0]); + } + else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + uint32_t y; + uint16_t x = *(uint16_t *)&src[i * 2]; + asm volatile("cvt.rn.f16x2.e5m2x2 %0, %1;" : "=r"(y) : "h"(x)); + *(uint32_t *)&dst[i * 2] = y; + } + } +#else + if constexpr (vec_size == 1) { + dst[0] = half(src[0]); + } + else if constexpr (vec_size == 2) { + dst[0] = half(src[0]); + dst[1] = half(src[1]); + } + else { + static_assert(vec_size % 4 == 0, + "vec_size must be a multiple of 4"); +#pragma unroll + for (uint32_t i = 0; i < vec_size / 4; ++i) { + fast_dequant_f8f16x4<__nv_fp8_e5m2, half>( + (uint32_t *)&src[i * 4], (uint2 *)&dst[i * 4]); + } + } +#endif // FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + } +}; + +template <> struct vec_cast +{ + template + FLASHINFER_INLINE static void cast(float *dst, const nv_bfloat16 *src) + { + if constexpr (vec_size == 1) { + dst[0] = (float)src[0]; + } + else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((float2 *)dst)[i] = + __bfloat1622float2(((nv_bfloat162 *)src)[i]); + } + } + } +}; + +template <> struct vec_cast +{ + template + FLASHINFER_INLINE static void cast(nv_bfloat16 *dst, const float *src) + { + if constexpr (vec_size == 1) { + dst[0] = nv_bfloat16(src[0]); + } + else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((nv_bfloat162 *)dst)[i] = + __float22bfloat162_rn(((float2 *)src)[i]); + } + } + } +}; + +template struct vec_t +{ + FLASHINFER_INLINE float_t &operator[](size_t i); + FLASHINFER_INLINE const float_t &operator[](size_t i) const; + FLASHINFER_INLINE void fill(float_t val); + FLASHINFER_INLINE void load(const float_t *ptr); + FLASHINFER_INLINE void store(float_t *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src); + template FLASHINFER_INLINE void cast_load(const T *ptr); + template FLASHINFER_INLINE void cast_store(T 
*ptr) const; + FLASHINFER_INLINE static void memcpy(float_t *dst, const float_t *src); + FLASHINFER_INLINE float_t *ptr(); +}; + +template +FLASHINFER_INLINE void cast_from_impl(vec_t &dst, + const vec_t &src) +{ + vec_cast::cast( + dst.ptr(), const_cast *>(&src)->ptr()); +} + +template +FLASHINFER_INLINE void cast_load_impl(vec_t &dst, + const src_float_t *src_ptr) +{ + if constexpr (std::is_same_v) { + dst.load(src_ptr); + } + else { + vec_t tmp; + tmp.load(src_ptr); + dst.cast_from(tmp); + } +} + +template +FLASHINFER_INLINE void cast_store_impl(tgt_float_t *dst_ptr, + const vec_t &src) +{ + if constexpr (std::is_same_v) { + src.store(dst_ptr); + } + else { + vec_t tmp; + tmp.cast_from(src); + tmp.store(dst_ptr); + } +} + +/******************* vec_t<__nv_fp8_e4m3> *******************/ + +// __nv_fp8_e4m3 x 1 +template <> struct vec_t<__nv_fp8_e4m3, 1> +{ + __nv_fp8_e4m3 data; + + FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) + { + return ((__nv_fp8_e4m3 *)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const + { + return ((const __nv_fp8_e4m3 *)(&data))[i]; + } + FLASHINFER_INLINE __nv_fp8_e4m3 *ptr() + { + return reinterpret_cast<__nv_fp8_e4m3 *>(&data); + } + FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val); + FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr); + FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) + { + cast_from_impl(*this, src); + } + template FLASHINFER_INLINE void cast_load(const T *ptr) + { + cast_load_impl(*this, ptr); + } + template FLASHINFER_INLINE void cast_store(T *ptr) const + { + cast_store_impl(ptr, *this); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst, + const __nv_fp8_e4m3 *src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::fill(__nv_fp8_e4m3 val) +{ + data = val; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::load(const __nv_fp8_e4m3 *ptr) +{ + data = *ptr; +} + +FLASHINFER_INLINE void 
vec_t<__nv_fp8_e4m3, 1>::store(__nv_fp8_e4m3 *ptr) const +{ + *ptr = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::memcpy(__nv_fp8_e4m3 *dst, + const __nv_fp8_e4m3 *src) +{ + *dst = *src; +} + +// __nv_fp8_e4m3 x 2 +template <> struct vec_t<__nv_fp8_e4m3, 2> +{ + __nv_fp8x2_e4m3 data; + + FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) + { + return ((__nv_fp8_e4m3 *)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const + { + return ((const __nv_fp8_e4m3 *)(&data))[i]; + } + FLASHINFER_INLINE __nv_fp8_e4m3 *ptr() + { + return reinterpret_cast<__nv_fp8_e4m3 *>(&data); + } + FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val); + FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr); + FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) + { + cast_from_impl(*this, src); + } + template FLASHINFER_INLINE void cast_load(const T *ptr) + { + cast_load_impl(*this, ptr); + } + template FLASHINFER_INLINE void cast_store(T *ptr) const + { + cast_store_impl(ptr, *this); + } + FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst, + const __nv_fp8_e4m3 *src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::fill(__nv_fp8_e4m3 val) +{ + data.__x = + (__nv_fp8x2_storage_t(val.__x) << 8) | __nv_fp8x2_storage_t(val.__x); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::load(const __nv_fp8_e4m3 *ptr) +{ + data = *((__nv_fp8x2_e4m3 *)ptr); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::store(__nv_fp8_e4m3 *ptr) const +{ + *((__nv_fp8x2_e4m3 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::memcpy(__nv_fp8_e4m3 *dst, + const __nv_fp8_e4m3 *src) +{ + *((__nv_fp8x2_e4m3 *)dst) = *((__nv_fp8x2_e4m3 *)src); +} + +// __nv_fp8_e4m3 x 4 + +template <> struct vec_t<__nv_fp8_e4m3, 4> +{ + __nv_fp8x4_e4m3 data; + + FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) + { + return ((__nv_fp8_e4m3 *)(&data))[i]; + } + FLASHINFER_INLINE const 
__nv_fp8_e4m3 &operator[](size_t i) const + { + return ((const __nv_fp8_e4m3 *)(&data))[i]; + } + FLASHINFER_INLINE __nv_fp8_e4m3 *ptr() + { + return reinterpret_cast<__nv_fp8_e4m3 *>(&data); + } + FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val); + FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr); + FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) + { + cast_from_impl(*this, src); + } + template FLASHINFER_INLINE void cast_load(const T *ptr) + { + cast_load_impl(*this, ptr); + } + template FLASHINFER_INLINE void cast_store(T *ptr) const + { + cast_store_impl(ptr, *this); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst, + const __nv_fp8_e4m3 *src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::fill(__nv_fp8_e4m3 val) +{ + data.__x = (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | + __nv_fp8x4_storage_t(val.__x); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::load(const __nv_fp8_e4m3 *ptr) +{ + data = *((__nv_fp8x4_e4m3 *)ptr); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::store(__nv_fp8_e4m3 *ptr) const +{ + *((__nv_fp8x4_e4m3 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::memcpy(__nv_fp8_e4m3 *dst, + const __nv_fp8_e4m3 *src) +{ + *((__nv_fp8x4_e4m3 *)dst) = *((__nv_fp8x4_e4m3 *)src); +} + +// __nv_fp8_e4m3 x 8 + +template <> struct vec_t<__nv_fp8_e4m3, 8> +{ + uint2 data; + + FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) + { + return ((__nv_fp8_e4m3 *)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const + { + return ((const __nv_fp8_e4m3 *)(&data))[i]; + } + FLASHINFER_INLINE __nv_fp8_e4m3 *ptr() + { + return reinterpret_cast<__nv_fp8_e4m3 *>(&data); + } + FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val); + FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr); + FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const; 
+ template + FLASHINFER_INLINE void cast_from(const vec_t &src) + { + cast_from_impl(*this, src); + } + template FLASHINFER_INLINE void cast_load(const T *ptr) + { + cast_load_impl(*this, ptr); + } + template FLASHINFER_INLINE void cast_store(T *ptr) const + { + cast_store_impl(ptr, *this); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst, + const __nv_fp8_e4m3 *src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::fill(__nv_fp8_e4m3 val) +{ + ((__nv_fp8x4_e4m3 *)(&data.x))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e4m3 *)(&data.y))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::load(const __nv_fp8_e4m3 *ptr) +{ + data = *((uint2 *)ptr); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::store(__nv_fp8_e4m3 *ptr) const +{ + *((uint2 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::memcpy(__nv_fp8_e4m3 *dst, + const __nv_fp8_e4m3 *src) +{ + *((uint2 *)dst) = *((uint2 *)src); +} + +// __nv_fp8_e4m3 x 16 or more +template struct vec_t<__nv_fp8_e4m3, vec_size> +{ + uint4 data[vec_size / 16]; + + FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) + { + return ((__nv_fp8_e4m3 *)data)[i]; + } + FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const + { + return ((const __nv_fp8_e4m3 *)data)[i]; + } + FLASHINFER_INLINE __nv_fp8_e4m3 *ptr() + { + return reinterpret_cast<__nv_fp8_e4m3 *>(&data); + } + FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val) + { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((__nv_fp8x4_e4m3 *)(&(data[i].x)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | + __nv_fp8x4_storage_t(val.__x); + 
((__nv_fp8x4_e4m3 *)(&(data[i].y)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | + __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e4m3 *)(&(data[i].z)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | + __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e4m3 *)(&(data[i].w)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | + __nv_fp8x4_storage_t(val.__x); + } + } + FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr) + { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + data[i] = ((uint4 *)ptr)[i]; + } + } + FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const + { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((uint4 *)ptr)[i] = data[i]; + } + } + template + FLASHINFER_INLINE void cast_from(const vec_t &src) + { + cast_from_impl(*this, src); + } + template FLASHINFER_INLINE void cast_load(const T *ptr) + { + cast_load_impl(*this, ptr); + } + template FLASHINFER_INLINE void cast_store(T *ptr) const + { + cast_store_impl(ptr, *this); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst, + const __nv_fp8_e4m3 *src) + { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((uint4 *)dst)[i] = ((uint4 *)src)[i]; + } + } +}; + +/******************* vec_t<__nv_fp8_e5m2> *******************/ + +// __nv_fp8_e5m2 x 1 +template <> struct vec_t<__nv_fp8_e5m2, 1> +{ + __nv_fp8_e5m2 data; + + FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) + { + return ((__nv_fp8_e5m2 *)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const + { + return ((const __nv_fp8_e5m2 *)(&data))[i]; + } + FLASHINFER_INLINE __nv_fp8_e5m2 *ptr() + { + return reinterpret_cast<__nv_fp8_e5m2 *>(&data); + } + FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val); + 
FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr);
FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const;
// Conversion helpers: element-wise casting is delegated to the
// cast_*_impl free functions shared by all vec_t specializations.
template <typename T>
FLASHINFER_INLINE void cast_from(const vec_t<T, 1> &src)
{
    cast_from_impl(*this, src);
}
template <typename T> FLASHINFER_INLINE void cast_load(const T *ptr)
{
    cast_load_impl(*this, ptr);
}
template <typename T> FLASHINFER_INLINE void cast_store(T *ptr) const
{
    cast_store_impl(ptr, *this);
}

FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst,
                                     const __nv_fp8_e5m2 *src);
};

// x1 specialization: a single scalar, all operations are trivial copies.
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::fill(__nv_fp8_e5m2 val)
{
    data = val;
}

FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::load(const __nv_fp8_e5m2 *ptr)
{
    data = *ptr;
}

FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::store(__nv_fp8_e5m2 *ptr) const
{
    *ptr = data;
}

FLASHINFER_INLINE void
vec_t<__nv_fp8_e5m2, 1>::memcpy(__nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src)
{
    *dst = *src;
}

// __nv_fp8_e5m2 x 2
template <> struct vec_t<__nv_fp8_e5m2, 2>
{
    // Two fp8 lanes packed into one 16-bit storage word.
    __nv_fp8x2_e5m2 data;

    FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i)
    {
        return ((__nv_fp8_e5m2 *)(&data))[i];
    }
    FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const
    {
        return ((const __nv_fp8_e5m2 *)(&data))[i];
    }
    FLASHINFER_INLINE __nv_fp8_e5m2 *ptr()
    {
        return reinterpret_cast<__nv_fp8_e5m2 *>(&data);
    }
    FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val);
    FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr);
    FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const;
    template <typename T>
    FLASHINFER_INLINE void cast_from(const vec_t<T, 2> &src)
    {
        cast_from_impl(*this, src);
    }
    template <typename T> FLASHINFER_INLINE void cast_load(const T *ptr)
    {
        cast_load_impl(*this, ptr);
    }
    template <typename T> FLASHINFER_INLINE void cast_store(T *ptr) const
    {
        cast_store_impl(ptr, *this);
    }

    FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst,
                                         const __nv_fp8_e5m2 *src);
};

// Broadcast one fp8 value into both byte lanes of the packed pair.
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::fill(__nv_fp8_e5m2 val)
{
    data.__x =
        (__nv_fp8x2_storage_t(val.__x) << 8) | __nv_fp8x2_storage_t(val.__x);
}

FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::load(const __nv_fp8_e5m2 *ptr)
{
    data = *((__nv_fp8x2_e5m2 *)ptr);
}

FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::store(__nv_fp8_e5m2 *ptr) const
{
    *((__nv_fp8x2_e5m2 *)ptr) = data;
}

FLASHINFER_INLINE void
vec_t<__nv_fp8_e5m2, 2>::memcpy(__nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src)
{
    *((__nv_fp8x2_e5m2 *)dst) = *((__nv_fp8x2_e5m2 *)src);
}

// __nv_fp8_e5m2 x 4
template <> struct vec_t<__nv_fp8_e5m2, 4>
{
    // Four fp8 lanes packed into one 32-bit storage word.
    __nv_fp8x4_e5m2 data;

    FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i)
    {
        return ((__nv_fp8_e5m2 *)(&data))[i];
    }
    FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const
    {
        return ((const __nv_fp8_e5m2 *)(&data))[i];
    }
    FLASHINFER_INLINE __nv_fp8_e5m2 *ptr()
    {
        return reinterpret_cast<__nv_fp8_e5m2 *>(&data);
    }
    FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val);
    FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr);
    FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const;
    template <typename T>
    FLASHINFER_INLINE void cast_from(const vec_t<T, 4> &src)
    {
        cast_from_impl(*this, src);
    }
    template <typename T> FLASHINFER_INLINE void cast_load(const T *ptr)
    {
        cast_load_impl(*this, ptr);
    }
    template <typename T> FLASHINFER_INLINE void cast_store(T *ptr) const
    {
        cast_store_impl(ptr, *this);
    }

    FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst,
                                         const __nv_fp8_e5m2 *src);
};

// Broadcast one fp8 value into all four byte lanes of the packed word.
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::fill(__nv_fp8_e5m2 val)
{
    data.__x = (__nv_fp8x4_storage_t(val.__x) << 24) |
               (__nv_fp8x4_storage_t(val.__x) << 16) |
               (__nv_fp8x4_storage_t(val.__x) << 8) |
               __nv_fp8x4_storage_t(val.__x);
}

FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::load(const __nv_fp8_e5m2 *ptr)
{
    data = *((__nv_fp8x4_e5m2 *)ptr);
}

FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::store(__nv_fp8_e5m2 *ptr) const
{
    *((__nv_fp8x4_e5m2 *)ptr) = data;
}

FLASHINFER_INLINE void
vec_t<__nv_fp8_e5m2, 4>::memcpy(__nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src)
{
    *((__nv_fp8x4_e5m2 *)dst) = *((__nv_fp8x4_e5m2 *)src);
}

// __nv_fp8_e5m2 x 8

template <> struct vec_t<__nv_fp8_e5m2, 8>
{
    // Eight fp8 lanes carried as one 64-bit (uint2) word.
    uint2 data;

    FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i)
    {
        return ((__nv_fp8_e5m2 *)(&data))[i];
    }
    FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const
    {
        return ((const __nv_fp8_e5m2 *)(&data))[i];
    }
    FLASHINFER_INLINE __nv_fp8_e5m2 *ptr()
    {
        return reinterpret_cast<__nv_fp8_e5m2 *>(&data);
    }
    FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val);
    FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr);
    FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const;
    template <typename T>
    FLASHINFER_INLINE void cast_from(const vec_t<T, 8> &src)
    {
        cast_from_impl(*this, src);
    }
    template <typename T> FLASHINFER_INLINE void cast_load(const T *ptr)
    {
        cast_load_impl(*this, ptr);
    }
    template <typename T> FLASHINFER_INLINE void cast_store(T *ptr) const
    {
        cast_store_impl(ptr, *this);
    }
    FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst,
                                         const __nv_fp8_e5m2 *src);
};

// Broadcast one fp8 value into all eight byte lanes, one 32-bit half at a
// time through the packed fp8x4 view.
FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::fill(__nv_fp8_e5m2 val)
{
    ((__nv_fp8x4_e5m2 *)(&data.x))->__x =
        (__nv_fp8x4_storage_t(val.__x) << 24) |
        (__nv_fp8x4_storage_t(val.__x) << 16) |
        (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
    ((__nv_fp8x4_e5m2 *)(&data.y))->__x =
        (__nv_fp8x4_storage_t(val.__x) << 24) |
        (__nv_fp8x4_storage_t(val.__x) << 16) |
        (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x);
}

FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::load(const __nv_fp8_e5m2 *ptr)
{
    data = *((uint2 *)ptr);
}

FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::store(__nv_fp8_e5m2 *ptr) const
{
    *((uint2 *)ptr) = data;
}

FLASHINFER_INLINE void
vec_t<__nv_fp8_e5m2, 8>::memcpy(__nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src)
{
    *((uint2 *)dst) = *((uint2 *)src);
}

// __nv_fp8_e5m2 x 16 or more
// Generic specialization for 16 or more fp8 lanes, carried as an array of
// 128-bit (uint4) words. vec_size must be a multiple of 16.
template <size_t vec_size> struct vec_t<__nv_fp8_e5m2, vec_size>
{
    uint4 data[vec_size / 16];

    FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i)
    {
        return ((__nv_fp8_e5m2 *)data)[i];
    }
    FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const
    {
        return ((const __nv_fp8_e5m2 *)data)[i];
    }
    FLASHINFER_INLINE __nv_fp8_e5m2 *ptr()
    {
        return reinterpret_cast<__nv_fp8_e5m2 *>(&data);
    }
    // Broadcast one fp8 value into every byte lane of every uint4 word.
    FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val)
    {
#pragma unroll
        for (size_t i = 0; i < vec_size / 16; ++i) {
            ((__nv_fp8x4_e5m2 *)(&(data[i].x)))->__x =
                (__nv_fp8x4_storage_t(val.__x) << 24) |
                (__nv_fp8x4_storage_t(val.__x) << 16) |
                (__nv_fp8x4_storage_t(val.__x) << 8) |
                __nv_fp8x4_storage_t(val.__x);
            ((__nv_fp8x4_e5m2 *)(&(data[i].y)))->__x =
                (__nv_fp8x4_storage_t(val.__x) << 24) |
                (__nv_fp8x4_storage_t(val.__x) << 16) |
                (__nv_fp8x4_storage_t(val.__x) << 8) |
                __nv_fp8x4_storage_t(val.__x);
            ((__nv_fp8x4_e5m2 *)(&(data[i].z)))->__x =
                (__nv_fp8x4_storage_t(val.__x) << 24) |
                (__nv_fp8x4_storage_t(val.__x) << 16) |
                (__nv_fp8x4_storage_t(val.__x) << 8) |
                __nv_fp8x4_storage_t(val.__x);
            ((__nv_fp8x4_e5m2 *)(&(data[i].w)))->__x =
                (__nv_fp8x4_storage_t(val.__x) << 24) |
                (__nv_fp8x4_storage_t(val.__x) << 16) |
                (__nv_fp8x4_storage_t(val.__x) << 8) |
                __nv_fp8x4_storage_t(val.__x);
        }
    }
    FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr)
    {
#pragma unroll
        for (size_t i = 0; i < vec_size / 16; ++i) {
            data[i] = ((uint4 *)ptr)[i];
        }
    }
    FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const
    {
#pragma unroll
        for (size_t i = 0; i < vec_size / 16; ++i) {
            ((uint4 *)ptr)[i] = data[i];
        }
    }
    template <typename T>
    FLASHINFER_INLINE void cast_from(const vec_t<T, vec_size> &src)
    {
        cast_from_impl(*this, src);
    }
    template <typename T> FLASHINFER_INLINE void cast_load(const T *ptr)
    {
        cast_load_impl(*this, ptr);
    }
    template <typename T> FLASHINFER_INLINE void cast_store(T *ptr) const
    {
        cast_store_impl(ptr, *this);
    }
    FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst,
                                         const __nv_fp8_e5m2 *src)
    {
#pragma unroll
        for (size_t i = 0; i < vec_size / 16; ++i) {
            ((uint4 *)dst)[i] = ((uint4 *)src)[i];
        }
    }
};

/******************* vec_t<half> *******************/

// half x 1
template <> struct vec_t<half, 1>
{
    half data;

    FLASHINFER_INLINE half &operator[](size_t i)
    {
        return ((half *)(&data))[i];
    }
    FLASHINFER_INLINE const half &operator[](size_t i) const
    {
        return ((const half *)(&data))[i];
    }
    FLASHINFER_INLINE half *ptr() { return reinterpret_cast<half *>(&data); }
    FLASHINFER_INLINE void fill(half val);
    FLASHINFER_INLINE void load(const half *ptr);
    FLASHINFER_INLINE void store(half *ptr) const;
    template <typename T>
    FLASHINFER_INLINE void cast_from(const vec_t<T, 1> &src)
    {
        cast_from_impl(*this, src);
    }
    template <typename T> FLASHINFER_INLINE void cast_load(const T *ptr)
    {
        cast_load_impl(*this, ptr);
    }
    template <typename T> FLASHINFER_INLINE void cast_store(T *ptr) const
    {
        cast_store_impl(ptr, *this);
    }

    FLASHINFER_INLINE static void memcpy(half *dst, const half *src);
};

FLASHINFER_INLINE void vec_t<half, 1>::fill(half val) { data = val; }

FLASHINFER_INLINE void vec_t<half, 1>::load(const half *ptr) { data = *ptr; }

FLASHINFER_INLINE void vec_t<half, 1>::store(half *ptr) const { *ptr = data; }

FLASHINFER_INLINE void vec_t<half, 1>::memcpy(half *dst, const half *src)
{
    *dst = *src;
}

// half x 2
template <> struct vec_t<half, 2>
{
    half2 data;

    FLASHINFER_INLINE half &operator[](size_t i)
    {
        return ((half *)(&data))[i];
    }
    FLASHINFER_INLINE const half &operator[](size_t i) const
    {
        return ((const half *)(&data))[i];
    }
    FLASHINFER_INLINE half *ptr() { return reinterpret_cast<half *>(&data); }
    FLASHINFER_INLINE void fill(half val);
    FLASHINFER_INLINE void load(const half *ptr);
    FLASHINFER_INLINE void store(half *ptr) const;
    template <typename T>
    FLASHINFER_INLINE void cast_from(const vec_t<T, 2> &src)
    {
        cast_from_impl(*this, src);
    }
    template <typename T> FLASHINFER_INLINE void cast_load(const T *ptr)
    {
        cast_load_impl(*this, ptr);
    }
    template <typename T> FLASHINFER_INLINE void cast_store(T *ptr) const
    {
        cast_store_impl(ptr, *this);
    }

    FLASHINFER_INLINE static void memcpy(half *dst, const half *src);
};

FLASHINFER_INLINE void vec_t<half, 2>::fill(half val)
{
    data = make_half2(val, val);
}

FLASHINFER_INLINE void vec_t<half, 2>::load(const half *ptr)
{
    data = *((half2 *)ptr);
}

FLASHINFER_INLINE void vec_t<half, 2>::store(half *ptr) const
{
    *((half2 *)ptr) = data;
}

FLASHINFER_INLINE void vec_t<half, 2>::memcpy(half *dst, const half *src)
{
    *((half2 *)dst) = *((half2 *)src);
}

// half x 4

template <> struct vec_t<half, 4>
{
    // Four halves carried as one 64-bit (uint2) word.
    uint2 data;

    FLASHINFER_INLINE half &operator[](size_t i)
    {
        return ((half *)(&data))[i];
    }
    FLASHINFER_INLINE const half &operator[](size_t i) const
    {
        return ((const half *)(&data))[i];
    }
    FLASHINFER_INLINE half *ptr() { return reinterpret_cast<half *>(&data); }
    FLASHINFER_INLINE void fill(half val);
    FLASHINFER_INLINE void load(const half *ptr);
    FLASHINFER_INLINE void store(half *ptr) const;
    template <typename T>
    FLASHINFER_INLINE void cast_from(const vec_t<T, 4> &src)
    {
        cast_from_impl(*this, src);
    }
    template <typename T> FLASHINFER_INLINE void cast_load(const T *ptr)
    {
        cast_load_impl(*this, ptr);
    }
    template <typename T> FLASHINFER_INLINE void cast_store(T *ptr) const
    {
        cast_store_impl(ptr, *this);
    }
    FLASHINFER_INLINE static void memcpy(half *dst, const half *src);
};

FLASHINFER_INLINE void vec_t<half, 4>::fill(half val)
{
    *(half2 *)(&data.x) = make_half2(val, val);
    *(half2 *)(&data.y) = make_half2(val, val);
}

FLASHINFER_INLINE void vec_t<half, 4>::load(const half *ptr)
{
    data = *((uint2 *)ptr);
}

FLASHINFER_INLINE void vec_t<half, 4>::store(half *ptr) const
{
    *((uint2 *)ptr) = data;
}

FLASHINFER_INLINE void vec_t<half, 4>::memcpy(half *dst, const half *src)
{
    *((uint2 *)dst) = *((uint2 *)src);
}

// half x 8 or more

// Generic specialization: vec_size must be a multiple of 8; data is carried
// as an array of 128-bit (uint4) words.
template <size_t vec_size> struct vec_t<half, vec_size>
{
    uint4 data[vec_size / 8];
    FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)data)[i]; }
    FLASHINFER_INLINE const half &operator[](size_t i) const
    {
        return ((const half *)data)[i];
    }
    FLASHINFER_INLINE half *ptr() { return reinterpret_cast<half *>(&data); }
    FLASHINFER_INLINE void fill(half val)
    {
#pragma unroll
        for (size_t i = 0; i < vec_size / 8; ++i) {
            *(half2 *)(&(data[i].x)) = make_half2(val, val);
            *(half2 *)(&(data[i].y)) = make_half2(val, val);
            *(half2 *)(&(data[i].z)) = make_half2(val, val);
            *(half2 *)(&(data[i].w)) = make_half2(val, val);
        }
    }
    FLASHINFER_INLINE void load(const half *ptr)
    {
#pragma unroll
        for (size_t i = 0; i < vec_size / 8; ++i) {
            data[i] = ((uint4 *)ptr)[i];
        }
    }
    FLASHINFER_INLINE void store(half *ptr) const
    {
#pragma unroll
        for (size_t i = 0; i < vec_size / 8; ++i) {
            ((uint4 *)ptr)[i] = data[i];
        }
    }
    template <typename T>
    FLASHINFER_INLINE void cast_from(const vec_t<T, vec_size> &src)
    {
        cast_from_impl(*this, src);
    }
    template <typename T> FLASHINFER_INLINE void cast_load(const T *ptr)
    {
        cast_load_impl(*this, ptr);
    }
    template <typename T> FLASHINFER_INLINE void cast_store(T *ptr) const
    {
        cast_store_impl(ptr, *this);
    }
    FLASHINFER_INLINE static void memcpy(half *dst, const half *src)
    {
#pragma unroll
        for (size_t i = 0; i < vec_size / 8; ++i) {
            ((uint4 *)dst)[i] = ((uint4 *)src)[i];
        }
    }
};

/******************* vec_t<nv_bfloat16> *******************/

// nv_bfloat16 x 1
template <> struct vec_t<nv_bfloat16, 1>
{
    nv_bfloat16 data;
    FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i)
    {
        return ((nv_bfloat16 *)(&data))[i];
    }
    FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const
    {
        return ((const nv_bfloat16 *)(&data))[i];
    }
    FLASHINFER_INLINE nv_bfloat16 *ptr()
    {
        return reinterpret_cast<nv_bfloat16 *>(&data);
    }
    FLASHINFER_INLINE void fill(nv_bfloat16 val);
    FLASHINFER_INLINE void load(const nv_bfloat16 *ptr);
    FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const;
    template <typename T>
    FLASHINFER_INLINE void cast_from(const vec_t<T, 1> &src)
    {
        cast_from_impl(*this, src);
    }
    template <typename T> FLASHINFER_INLINE void cast_load(const T *ptr)
    {
        cast_load_impl(*this, ptr);
    }
    template <typename T> FLASHINFER_INLINE void cast_store(T *ptr) const
    {
        cast_store_impl(ptr, *this);
    }
    FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst,
                                         const nv_bfloat16 *src);
};

FLASHINFER_INLINE void vec_t<nv_bfloat16, 1>::fill(nv_bfloat16 val)
{
    data = val;
}

FLASHINFER_INLINE void vec_t<nv_bfloat16, 1>::load(const nv_bfloat16 *ptr)
{
    data = *ptr;
}

FLASHINFER_INLINE void vec_t<nv_bfloat16, 1>::store(nv_bfloat16 *ptr) const
{
    *ptr = data;
}

FLASHINFER_INLINE void vec_t<nv_bfloat16, 1>::memcpy(nv_bfloat16 *dst,
                                                     const nv_bfloat16 *src)
{
    *dst = *src;
}

// nv_bfloat16 x 2
template <> struct vec_t<nv_bfloat16, 2>
{
    nv_bfloat162 data;

    FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i)
    {
        return ((nv_bfloat16 *)(&data))[i];
    }
    FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const
    {
        return ((const nv_bfloat16 *)(&data))[i];
    }
    FLASHINFER_INLINE nv_bfloat16 *ptr()
    {
        return reinterpret_cast<nv_bfloat16 *>(&data);
    }
    FLASHINFER_INLINE void fill(nv_bfloat16 val);
    FLASHINFER_INLINE void load(const nv_bfloat16 *ptr);
    FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const;
    template <typename T>
    FLASHINFER_INLINE void cast_from(const vec_t<T, 2> &src)
    {
        cast_from_impl(*this, src);
    }
    template <typename T> FLASHINFER_INLINE void cast_load(const T *ptr)
    {
        cast_load_impl(*this, ptr);
    }
    template <typename T> FLASHINFER_INLINE void cast_store(T *ptr) const
    {
        cast_store_impl(ptr, *this);
    }
    FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst,
                                         const nv_bfloat16 *src);
};

FLASHINFER_INLINE void vec_t<nv_bfloat16, 2>::fill(nv_bfloat16 val)
{
    data = make_bfloat162(val, val);
}

FLASHINFER_INLINE void vec_t<nv_bfloat16, 2>::load(const nv_bfloat16 *ptr)
{
    data = *((nv_bfloat162 *)ptr);
}

FLASHINFER_INLINE void vec_t<nv_bfloat16, 2>::store(nv_bfloat16 *ptr) const
{
    *((nv_bfloat162 *)ptr) = data;
}

FLASHINFER_INLINE void vec_t<nv_bfloat16, 2>::memcpy(nv_bfloat16 *dst,
                                                     const nv_bfloat16 *src)
{
    *((nv_bfloat162 *)dst) = *((nv_bfloat162 *)src);
}

// nv_bfloat16 x 4

template <> struct vec_t<nv_bfloat16, 4>
{
    // Four bf16 lanes carried as one 64-bit (uint2) word.
    uint2 data;

    FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i)
    {
        return ((nv_bfloat16 *)(&data))[i];
    }
    FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const
    {
        return ((const nv_bfloat16 *)(&data))[i];
    }
    FLASHINFER_INLINE nv_bfloat16 *ptr()
    {
        return reinterpret_cast<nv_bfloat16 *>(&data);
    }
    FLASHINFER_INLINE void fill(nv_bfloat16 val);
    FLASHINFER_INLINE void load(const nv_bfloat16 *ptr);
    FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const;
    template <typename T>
    FLASHINFER_INLINE void cast_from(const vec_t<T, 4> &src)
    {
        cast_from_impl(*this, src);
    }
    template <typename T> FLASHINFER_INLINE void cast_load(const T *ptr)
    {
        cast_load_impl(*this, ptr);
    }
    template <typename T> FLASHINFER_INLINE void cast_store(T *ptr) const
    {
        cast_store_impl(ptr, *this);
    }
    FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst,
                                         const nv_bfloat16 *src);
};

FLASHINFER_INLINE void vec_t<nv_bfloat16, 4>::fill(nv_bfloat16 val)
{
    *(nv_bfloat162 *)(&data.x) = make_bfloat162(val, val);
    *(nv_bfloat162 *)(&data.y) = make_bfloat162(val, val);
}

FLASHINFER_INLINE void vec_t<nv_bfloat16, 4>::load(const nv_bfloat16 *ptr)
{
    data = *((uint2 *)ptr);
}

FLASHINFER_INLINE void vec_t<nv_bfloat16, 4>::store(nv_bfloat16 *ptr) const
{
    *((uint2 *)ptr) = data;
}

FLASHINFER_INLINE void vec_t<nv_bfloat16, 4>::memcpy(nv_bfloat16 *dst,
                                                     const nv_bfloat16 *src)
{
    *((uint2 *)dst) = *((uint2 *)src);
}

// nv_bfloat16 x 8 or more

// Generic specialization: vec_size must be a multiple of 8; data is carried
// as an array of 128-bit (uint4) words.
// NOTE: fixed four occurrences of the misspelled "#pragma unoll" — an
// unknown pragma is silently ignored, so the loops below were never unrolled.
template <size_t vec_size> struct vec_t<nv_bfloat16, vec_size>
{
    uint4 data[vec_size / 8];

    FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i)
    {
        return ((nv_bfloat16 *)data)[i];
    }
    FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const
    {
        return ((const nv_bfloat16 *)data)[i];
    }
    FLASHINFER_INLINE nv_bfloat16 *ptr()
    {
        return reinterpret_cast<nv_bfloat16 *>(&data);
    }
    FLASHINFER_INLINE void fill(nv_bfloat16 val)
    {
#pragma unroll
        for (size_t i = 0; i < vec_size / 8; ++i) {
            *(nv_bfloat162 *)(&(data[i].x)) = make_bfloat162(val, val);
            *(nv_bfloat162 *)(&(data[i].y)) = make_bfloat162(val, val);
            *(nv_bfloat162 *)(&(data[i].z)) = make_bfloat162(val, val);
            *(nv_bfloat162 *)(&(data[i].w)) = make_bfloat162(val, val);
        }
    }
    FLASHINFER_INLINE void load(const nv_bfloat16 *ptr)
    {
#pragma unroll
        for (size_t i = 0; i < vec_size / 8; ++i) {
            data[i] = ((uint4 *)ptr)[i];
        }
    }
    FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const
    {
#pragma unroll
        for (size_t i = 0; i < vec_size / 8; ++i) {
            ((uint4 *)ptr)[i] = data[i];
        }
    }
    template <typename T>
    FLASHINFER_INLINE void cast_from(const vec_t<T, vec_size> &src)
    {
        cast_from_impl(*this, src);
    }
    template <typename T> FLASHINFER_INLINE void cast_load(const T *ptr)
    {
        cast_load_impl(*this, ptr);
    }
    template <typename T> FLASHINFER_INLINE void cast_store(T *ptr) const
    {
        cast_store_impl(ptr, *this);
    }
    FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst,
                                         const nv_bfloat16 *src)
    {
#pragma unroll
        for (size_t i = 0; i < vec_size / 8; ++i) {
            ((uint4 *)dst)[i] = ((uint4 *)src)[i];
        }
    }
};

/******************* vec_t<float> *******************/

// float x 1

template <> struct vec_t<float, 1>
{
    float data;

    FLASHINFER_INLINE float &operator[](size_t i)
    {
        return ((float *)(&data))[i];
    }
    FLASHINFER_INLINE const float &operator[](size_t i) const
    {
        return ((const float *)(&data))[i];
    }
    FLASHINFER_INLINE float *ptr() { return reinterpret_cast<float *>(&data); }
    FLASHINFER_INLINE void fill(float val);
    FLASHINFER_INLINE void load(const float *ptr);
    FLASHINFER_INLINE void store(float *ptr) const;
    template <typename T>
    FLASHINFER_INLINE void cast_from(const vec_t<T, 1> &src)
    {
        cast_from_impl(*this, src);
    }
    template <typename T> FLASHINFER_INLINE void cast_load(const T *ptr)
    {
        cast_load_impl(*this, ptr);
    }
    template <typename T> FLASHINFER_INLINE void cast_store(T *ptr) const
    {
        cast_store_impl(ptr, *this);
    }
    FLASHINFER_INLINE static void memcpy(float *dst, const float *src);
};

FLASHINFER_INLINE void vec_t<float, 1>::fill(float val) { data = val; }

FLASHINFER_INLINE void vec_t<float, 1>::load(const float *ptr) { data = *ptr; }

FLASHINFER_INLINE void vec_t<float, 1>::store(float *ptr) const
{
    *ptr = data;
}

FLASHINFER_INLINE void vec_t<float, 1>::memcpy(float *dst, const float *src)
{
    *dst = *src;
}

// float x 2

template <> struct vec_t<float, 2>
{
    float2 data;

    FLASHINFER_INLINE float &operator[](size_t i)
    {
        return ((float *)(&data))[i];
    }
    FLASHINFER_INLINE const float &operator[](size_t i) const
    {
        return ((const float *)(&data))[i];
    }
    FLASHINFER_INLINE float *ptr() { return reinterpret_cast<float *>(&data); }
    FLASHINFER_INLINE void fill(float val);
    FLASHINFER_INLINE void load(const float *ptr);
    FLASHINFER_INLINE void store(float *ptr) const;
    template <typename T>
    FLASHINFER_INLINE void cast_from(const vec_t<T, 2> &src)
    {
        cast_from_impl(*this, src);
    }
    template <typename T> FLASHINFER_INLINE void cast_load(const T *ptr)
    {
        cast_load_impl(*this, ptr);
    }
    template <typename T> FLASHINFER_INLINE void cast_store(T *ptr) const
    {
        cast_store_impl(ptr, *this);
    }
    FLASHINFER_INLINE static void memcpy(float *dst, const float *src);
};

FLASHINFER_INLINE void vec_t<float, 2>::fill(float val)
{
    data = make_float2(val, val);
}

FLASHINFER_INLINE void vec_t<float, 2>::load(const float *ptr)
{
    data = *((float2 *)ptr);
}

FLASHINFER_INLINE void vec_t<float, 2>::store(float *ptr) const
{
    *((float2 *)ptr) = data;
}

FLASHINFER_INLINE void vec_t<float, 2>::memcpy(float *dst, const float *src)
{
    *((float2 *)dst) = *((float2 *)src);
}

// float x 4 or more
// Generic specialization: vec_size must be a multiple of 4; data is carried
// as an array of float4 words.
template <size_t vec_size> struct vec_t<float, vec_size>
{
    float4 data[vec_size / 4];

    FLASHINFER_INLINE float &operator[](size_t i)
    {
        return ((float *)(data))[i];
    }
    FLASHINFER_INLINE const float &operator[](size_t i) const
    {
        return ((const float *)(data))[i];
    }
    FLASHINFER_INLINE float *ptr() { return reinterpret_cast<float *>(&data); }
    FLASHINFER_INLINE void fill(float val)
    {
#pragma unroll
        for (size_t i = 0; i < vec_size / 4; ++i) {
            data[i] = make_float4(val, val, val, val);
        }
    }
    FLASHINFER_INLINE void load(const float *ptr)
    {
#pragma unroll
        for (size_t i = 0; i < vec_size / 4; ++i) {
            data[i] = ((float4 *)ptr)[i];
        }
    }
    FLASHINFER_INLINE
void store(float *ptr) const + { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((float4 *)ptr)[i] = data[i]; + } + } + template + FLASHINFER_INLINE void cast_from(const vec_t &src) + { + cast_from_impl(*this, src); + } + template FLASHINFER_INLINE void cast_load(const T *ptr) + { + cast_load_impl(*this, ptr); + } + template FLASHINFER_INLINE void cast_store(T *ptr) const + { + cast_store_impl(ptr, *this); + } + FLASHINFER_INLINE static void memcpy(float *dst, const float *src) + { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((float4 *)dst)[i] = ((float4 *)src)[i]; + } + } +}; + +} // namespace flashinfer + +#endif // VEC_DTYPES_CUH_ From 8524d22a9d65138902c1d7302bc4f9ac4d5a7084 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Thu, 14 Aug 2025 14:41:18 -0400 Subject: [PATCH 038/109] Prefill compiles... --- .../include/flashinfer/attention/generic/permuted_smem.cuh | 2 +- .../include/flashinfer/attention/generic/prefill.cuh | 4 +--- libflashinfer/include/gpu_iface/backend/hip/mma_hip.h | 2 +- libflashinfer/include/gpu_iface/mma_ops.hpp | 4 ++-- 4 files changed, 5 insertions(+), 7 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh b/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh index ca92b1f7fe..708ba06c1d 100644 --- a/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh @@ -171,7 +171,7 @@ template struct smem_t #if defined(PLATFORM_HIP_DEVICE) auto smem_t_ptr = reinterpret_cast(base + offset); flashinfer::gpu_iface::mma::load_fragment_transpose_4x4_half_registers( - smem_t_ptr, frag); + frag, smem_t_ptr); #else static_assert(false, "Not supported on current platform"); #endif diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index 802440b2eb..77f782609f 100644 --- 
a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -415,7 +415,6 @@ __device__ __forceinline__ void produce_kv_helper_( #endif uint32_t row = lane_idx / WARP_THREAD_COLS; - uint32_t col = lane_idx % WARP_THREAD_COLS; uint32_t kv_idx = kv_idx_base + warp_idx * WARP_THREAD_ROWS + row; // NOTE: NUM_MMA_KV*4/NUM_WARPS_Q = NUM_WARPS_KV*NUM_MMA_KV*4/num_warps static_assert(NUM_MMA_KV * 4 % NUM_WARPS_Q == 0); @@ -1392,8 +1391,7 @@ __device__ __forceinline__ void compute_sfm_v( constexpr uint32_t UPCAST_STRIDE_V = KTraits::UPCAST_STRIDE_V; constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; constexpr uint32_t INT32_ELEMS_PER_THREAD = KTraits::INT32_ELEMS_PER_THREAD; - constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD = - KTraits::NUM_ACCUM_ROWS_PER_THREAD; + constexpr uint32_t V_SMEM_COLUMN_ADVANCE = 16 / KTraits::HALF_ELEMS_PER_THREAD; diff --git a/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h b/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h index 2238406b13..b7bb15a2e0 100644 --- a/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h +++ b/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h @@ -146,7 +146,7 @@ mma_sync_m16n16k16_row_col_f16f16f32(float *C, uint32_t *A, uint32_t *B) /// the registers for a group of four consecuitive threads. template __device__ __forceinline__ void -load_fragment_4x4_half_registers(const T *smem_ptr, uint32_t *R) +load_fragment_4x4_half_registers(uint32_t *R, const T *smem_ptr) { static_assert(std::is_same_v, "Only half type is supported"); // Each thread loads 4 __half values in two 32b registers. 
diff --git a/libflashinfer/include/gpu_iface/mma_ops.hpp b/libflashinfer/include/gpu_iface/mma_ops.hpp index e08a5f82e7..97ae7bf506 100644 --- a/libflashinfer/include/gpu_iface/mma_ops.hpp +++ b/libflashinfer/include/gpu_iface/mma_ops.hpp @@ -44,12 +44,12 @@ load_fragment_transpose(uint32_t *R, const T *smem_ptr, uint32_t stride) mma_detail::load_fragment_transpose(R, smem_ptr, stride); } -#if defined(PLATFORM_HIP_DEVICE) && defined(__gfx942__) +#if defined(PLATFORM_HIP_DEVICE) template __device__ __forceinline__ void load_fragment_transpose_4x4_half_registers(uint32_t *R, const T *smem_ptr) { - static_assert(std::is_same::value, + static_assert(std::is_same::value, "Only __half is supported for the 4x4 register transpose"); mma_detail::load_fragment_4x4_half_registers(R, smem_ptr); } From 931e43b9d8509737a3ec2f799ff686ec0c5bf65d Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Fri, 15 Aug 2025 14:55:23 -0400 Subject: [PATCH 039/109] Fixed logits functions and load_q_global_smem bug. --- .../attention/generic/permuted_smem.cuh | 2 +- .../flashinfer/attention/generic/prefill.cuh | 57 ++++++++++++------- 2 files changed, 37 insertions(+), 22 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh b/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh index 708ba06c1d..1c6dd0b0a6 100644 --- a/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh @@ -38,7 +38,7 @@ using b64_t = uint2; * \brief Compute the number of elements that can be stored in a b128_t. * \tparam T The data type of the elements. 
*/ -template +template constexpr __host__ __device__ __forceinline__ uint32_t upcast_size() { static_assert(VectorWidthBits == 128 || VectorWidthBits == 64, diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index 77f782609f..f020804584 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -164,6 +164,9 @@ struct KernelTraits // CUDA: 4 threads (each thread handles 2 elements from same row group) // CDNA3: 16 threads (each thread handles 1 element from same row group) static constexpr uint32_t THREADS_PER_MATRIX_ROW_SET = 16; + // controls the indexing stride used in logits-related functions + // (logits_transform, logits_mask, and LSE writing). + static constexpr uint32_t LOGITS_INDEX_STRIDE = 4; #else using SmemBasePtrTy = uint4; static constexpr uint32_t NUM_THREADS = NUM_WARPS_Q * NUM_WARPS_KV * 32; @@ -191,6 +194,7 @@ struct KernelTraits // https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-i8-f8 static constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD = 2; static constexpr uint32_t THREADS_PER_MATRIX_ROW_SET = 4; + static constexpr uint32_t LOGITS_INDEX_STRIDE = 8; #endif static constexpr uint32_t UPCAST_STRIDE_Q = HEAD_DIM_QK / upcast_size(); @@ -463,7 +467,6 @@ __device__ __forceinline__ void produce_kv( const dim3 tid = threadIdx) { // NOTE: for fp8, this function doesn't work for head_dim = 64 at the moment - using DTypeKV = typename KTraits::DTypeKV; constexpr uint32_t NUM_WARPS = KTraits::NUM_WARPS; constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q; constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; @@ -516,7 +519,6 @@ __device__ __forceinline__ void page_produce_kv( { // NOTE: for fp8, this function doesn't work for head_dim = 64 at the moment using DType = typename KTraits::DTypeKV; - using IdType = typename KTraits::IdType; 
constexpr SharedMemFillMode fill_mode = produce_v ? SharedMemFillMode::kFillZero : SharedMemFillMode::kNoFill; constexpr uint32_t NUM_WARPS = KTraits::NUM_WARPS; @@ -682,7 +684,7 @@ __device__ __forceinline__ void load_q_global_smem( r); const uint32_t q_idx = q; DTypeQ *q_ptr = q_ptr_base + q * q_stride_n + r * q_stride_h + - col * upcast_size(); + col * upcast_size(); #pragma unroll for (uint32_t mma_do = 0; mma_do < KTraits::NUM_MMA_D_QK / 4; ++mma_do) @@ -693,7 +695,7 @@ __device__ __forceinline__ void load_q_global_smem( q_idx < qo_upper_bound); q_smem_offset_w = q_smem->template advance_offset_by_column< WARP_THREAD_COLS>(q_smem_offset_w, mma_do); - q_ptr += HALF_ELEMS_PER_THREAD * upcast_size(); + q_ptr += HALF_ELEMS_PER_THREAD * upcast_size(); } q_smem_offset_w = q_smem->template advance_offset_by_row::value) { @@ -1143,16 +1152,18 @@ logits_mask(const Params ¶ms, const uint32_t lane_idx = tid.x; constexpr uint32_t NUM_MMA_Q = KTraits::NUM_MMA_Q; constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; - using DTypeQKAccum = typename KTraits::DTypeQKAccum; constexpr MaskMode MASK_MODE = KTraits::MASK_MODE; constexpr uint32_t TPR = KTraits::THREADS_PER_MATRIX_ROW_SET; - uint32_t q[NUM_MMA_Q][2], r[NUM_MMA_Q][2]; + constexpr uint32_t NAPTR = KTraits::NUM_ACCUM_ROWS_PER_THREAD; + constexpr uint32_t LIS = KTraits::LOGITS_INDEX_STRIDE; + + uint32_t q[NUM_MMA_Q][NAPTR], r[NUM_MMA_Q][NAPTR]; #pragma unroll for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { #pragma unroll - for (uint32_t j = 0; j < 2; ++j) { + for (uint32_t j = 0; j < NAPTR; ++j) { group_size.divmod(qo_packed_idx_base + mma_q * 16 + lane_idx / TPR + - 8 * j, + LIS * j, q[mma_q][j], r[mma_q][j]); } } @@ -1166,16 +1177,20 @@ logits_mask(const Params ¶ms, ++reg_id) { #if defined(PLATFORM_HIP_DEVICE) - const uint32_t i = reg_id / 2; + const uint32_t q_idx = q[mma_q][(reg_id % NAPTR)], + kv_idx = kv_idx_base + mma_kv * 16 + + 2 * (lane_idx % TPR) + + 8 * (reg_id / 2) + reg_id % 2; + const uint32_t 
qo_head_idx = + kv_head_idx * group_size + r[mma_q][(reg_id % NAPTR)]; #else - const uint32_t i = reg_id / 4; -#endif const uint32_t q_idx = q[mma_q][(reg_id % 4) / 2], kv_idx = kv_idx_base + mma_kv * 16 + - 2 * (lane_idx % TPR) + 8 * i + - reg_id % 2; + 2 * (lane_idx % TPR) + + 8 * (reg_id / 4) + reg_id % 2; const uint32_t qo_head_idx = kv_head_idx * group_size + r[mma_q][(reg_id % 4) / 2]; +#endif const bool mask = (!(MASK_MODE == MaskMode::kCausal ? (kv_idx + qo_len > kv_len + q_idx || From 05a3472d149169f6f409eb3d10fda020cead97c9 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Fri, 15 Aug 2025 23:55:08 -0400 Subject: [PATCH 040/109] Indexing fixes to normalize_d --- .../flashinfer/attention/generic/prefill.cuh | 24 +++++++++++++++---- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index f020804584..796d2adef6 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -1529,6 +1529,8 @@ __device__ __forceinline__ void normalize_d( float (*d)[KTraits::NUM_ACCUM_ROWS_PER_THREAD]) { using AttentionVariant = typename KTraits::AttentionVariant; + constexpr uint32_t NAPTR = KTraits::NUM_ACCUM_ROWS_PER_THREAD; + if constexpr (AttentionVariant::use_softmax) { float d_rcp[KTraits::NUM_MMA_Q][KTraits::NUM_ACCUM_ROWS_PER_THREAD]; // compute reciprocal of d @@ -1552,9 +1554,15 @@ __device__ __forceinline__ void normalize_d( for (uint32_t reg_id = 0; reg_id < KTraits::HALF_ELEMS_PER_THREAD; ++reg_id) { +#if defined(PLATFORM_HIP_DEVICE) + o_frag[mma_q][mma_d][reg_id] = + o_frag[mma_q][mma_d][reg_id] * + d_rcp[mma_q][reg_id % NAPTR]; +#else o_frag[mma_q][mma_d][reg_id] = o_frag[mma_q][mma_d][reg_id] * d_rcp[mma_q][(reg_id % 4) / 2]; +#endif } } } @@ -1988,6 +1996,10 @@ SinglePrefillWithKVCacheDevice(const Params params, 
KTraits::HALF_ELEMS_PER_THREAD; [[maybe_unused]] constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD = KTraits::NUM_ACCUM_ROWS_PER_THREAD; + [[maybe_unused]] constexpr uint32_t LOGITS_INDEX_STRIDE = + KTraits::LOGITS_INDEX_STRIDE; + [[maybe_unused]] constexpr uint32_t THREADS_PER_MATRIX_ROW_SET = + KTraits::THREADS_PER_MATRIX_ROW_SET; DTypeQ *q = params.q; DTypeKV *k = params.k; @@ -2229,12 +2241,14 @@ SinglePrefillWithKVCacheDevice(const Params params, #pragma unroll for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { #pragma unroll - for (uint32_t j = 0; j < 2; ++j) { + for (uint32_t j = 0; j < NUM_ACCUM_ROWS_PER_THREAD; ++j) + { uint32_t q, r; - group_size.divmod(qo_packed_idx_base + - lane_idx / 4 + j * 8 + - mma_q * 16, - q, r); + group_size.divmod( + qo_packed_idx_base + + lane_idx / THREADS_PER_MATRIX_ROW_SET + + j * LOGITS_INDEX_STRIDE + mma_q * 16, + q, r); const uint32_t qo_head_idx = kv_head_idx * group_size + r; const uint32_t qo_idx = q; From f21615bd44912bdb44adeff8004bf8e6a4a22ad7 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Sat, 16 Aug 2025 00:49:23 -0400 Subject: [PATCH 041/109] Properly fix upcats_size and silence warnings --- .../flashinfer/attention/generic/pos_enc.cuh | 9 ++- .../flashinfer/attention/generic/prefill.cuh | 73 ++++++++++++------- 2 files changed, 51 insertions(+), 31 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/pos_enc.cuh b/libflashinfer/include/flashinfer/attention/generic/pos_enc.cuh index 80d46b96d2..dbfb01bf11 100644 --- a/libflashinfer/include/flashinfer/attention/generic/pos_enc.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/pos_enc.cuh @@ -45,12 +45,13 @@ PosEncodingModeToString(const PosEncodingMode &pos_encoding_mode) __device__ __forceinline__ float get_alibi_slope(uint32_t head_idx, uint32_t num_heads) { - int n = - gpu_iface::math::ptx_exp2((int)gpu_iface::math::ptx_log2(num_heads)); + int n = (int)gpu_iface::math::ptx_exp2( + 
gpu_iface::math::ptx_log2(float(num_heads))); return head_idx < n - ? gpu_iface::math::ptx_exp2(-8. * float(head_idx + 1) / float(n)) + ? gpu_iface::math::ptx_exp2(-8.f * float(head_idx + 1) / + float(n)) : gpu_iface::math::ptx_exp2( - -4. * float((head_idx + 1 - n) * 2 - 1) / float(n)); + -4.f * float((head_idx + 1 - n) * 2 - 1) / float(n)); } /*! diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index 796d2adef6..6e213bc20f 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -411,6 +411,7 @@ __device__ __forceinline__ void produce_kv_helper_( produce_v ? KTraits::NUM_MMA_D_VO : KTraits::NUM_MMA_D_QK; constexpr uint32_t UPCAST_STRIDE = produce_v ? KTraits::UPCAST_STRIDE_V : KTraits::UPCAST_STRIDE_K; + constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH; #if defined(PLATFORM_HIP_DEVICE) constexpr uint32_t COLUMN_RESET_OFFSET = (NUM_MMA_D / 4) * WARP_THREAD_COLS; @@ -431,7 +432,7 @@ __device__ __forceinline__ void produce_kv_helper_( *smem_offset = smem.template advance_offset_by_column( *smem_offset, j); - *gptr += 8 * upcast_size(); + *gptr += 8 * upcast_size(); } kv_idx += NUM_WARPS * WARP_THREAD_ROWS; *smem_offset = @@ -439,7 +440,8 @@ __device__ __forceinline__ void produce_kv_helper_( UPCAST_STRIDE>(*smem_offset) - COLUMN_RESET_OFFSET; *gptr += NUM_WARPS * WARP_THREAD_ROWS * stride_n - - sizeof(DTypeKV) * NUM_MMA_D * upcast_size(); + sizeof(DTypeKV) * NUM_MMA_D * + upcast_size(); } *smem_offset -= KTraits::CTA_TILE_KV * UPCAST_STRIDE; } @@ -528,6 +530,8 @@ __device__ __forceinline__ void page_produce_kv( produce_v ? KTraits::NUM_MMA_D_VO : KTraits::NUM_MMA_D_QK; constexpr uint32_t UPCAST_STRIDE = produce_v ? 
KTraits::UPCAST_STRIDE_V : KTraits::UPCAST_STRIDE_K; + constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH; + const uint32_t warp_idx = get_warp_idx(tid.y, tid.z), lane_idx = tid.x; if constexpr (KTraits::SWIZZLE_MODE_KV == SwizzleMode::k128B) { @@ -545,7 +549,7 @@ __device__ __forceinline__ void page_produce_kv( kv_idx < kv_len); *smem_offset = smem.template advance_offset_by_column<8>(*smem_offset, j); - gptr += 8 * upcast_size(); + gptr += 8 * upcast_size(); } kv_idx += NUM_WARPS * 4; *smem_offset = smem.template advance_offset_by_row(); + col * upcast_size(); #pragma unroll for (uint32_t mma_do = 0; mma_do < KTraits::NUM_MMA_D_QK / 4; ++mma_do) @@ -695,7 +700,8 @@ __device__ __forceinline__ void load_q_global_smem( q_idx < qo_upper_bound); q_smem_offset_w = q_smem->template advance_offset_by_column< WARP_THREAD_COLS>(q_smem_offset_w, mma_do); - q_ptr += HALF_ELEMS_PER_THREAD * upcast_size(); + q_ptr += HALF_ELEMS_PER_THREAD * + upcast_size(); } q_smem_offset_w = q_smem->template advance_offset_by_row(tid.y); const uint32_t lane_idx = tid.x; @@ -1890,9 +1897,10 @@ __device__ __forceinline__ void write_o_reg_gmem( mma_q * 16 + j * 4, q, r); const uint32_t o_idx = q; - DTypeO *o_ptr = - o_ptr_base + q * o_stride_n + r * o_stride_h + - (lane_idx % WARP_THREAD_COLS) * upcast_size(); + DTypeO *o_ptr = o_ptr_base + q * o_stride_n + + r * o_stride_h + + (lane_idx % WARP_THREAD_COLS) * + upcast_size(); #pragma unroll for (uint32_t mma_do = 0; mma_do < KTraits::NUM_MMA_D_VO / 4; ++mma_do) @@ -1900,7 +1908,8 @@ __device__ __forceinline__ void write_o_reg_gmem( if (o_idx < qo_upper_bound) { o_smem->store_vector(o_smem_offset_w, o_ptr); } - o_ptr += WARP_THREAD_COLS * upcast_size(); + o_ptr += WARP_THREAD_COLS * + upcast_size(); o_smem_offset_w = o_smem->template advance_offset_by_column< WARP_THREAD_COLS>(o_smem_offset_w, mma_do); @@ -2000,6 +2009,8 @@ SinglePrefillWithKVCacheDevice(const Params params, KTraits::LOGITS_INDEX_STRIDE; [[maybe_unused]] constexpr 
uint32_t THREADS_PER_MATRIX_ROW_SET = KTraits::THREADS_PER_MATRIX_ROW_SET; + [[maybe_unused]] constexpr uint32_t VECTOR_BIT_WIDTH = + KTraits::VECTOR_BIT_WIDTH; DTypeQ *q = params.q; DTypeKV *k = params.k; @@ -2113,20 +2124,20 @@ SinglePrefillWithKVCacheDevice(const Params params, : chunk_size) / CTA_TILE_KV; - DTypeKV *k_ptr = - k + - (chunk_start + warp_idx * KV_THR_LAYOUT_ROW + - lane_idx / KV_THR_LAYOUT_COL) * - k_stride_n + - kv_head_idx * k_stride_h + - (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(); - DTypeKV *v_ptr = - v + - (chunk_start + warp_idx * KV_THR_LAYOUT_ROW + - lane_idx / KV_THR_LAYOUT_COL) * - v_stride_n + - kv_head_idx * v_stride_h + - (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(); + DTypeKV *k_ptr = k + + (chunk_start + warp_idx * KV_THR_LAYOUT_ROW + + lane_idx / KV_THR_LAYOUT_COL) * + k_stride_n + + kv_head_idx * k_stride_h + + (lane_idx % KV_THR_LAYOUT_COL) * + upcast_size(); + DTypeKV *v_ptr = v + + (chunk_start + warp_idx * KV_THR_LAYOUT_ROW + + lane_idx / KV_THR_LAYOUT_COL) * + v_stride_n + + kv_head_idx * v_stride_h + + (lane_idx % KV_THR_LAYOUT_COL) * + upcast_size(); uint32_t k_smem_offset_r = k_smem.template get_permuted_offset( @@ -2499,6 +2510,8 @@ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKVCacheKernel [[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_COL = KTraits::KV_THR_LAYOUT_COL; [[maybe_unused]] constexpr MaskMode MASK_MODE = KTraits::MASK_MODE; + [[maybe_unused]] constexpr uint32_t VECTOR_BIT_WIDTH = + KTraits::VECTOR_BIT_WIDTH; DTypeQ *q = params.q; IdType *request_indices = params.request_indices; @@ -2675,14 +2688,16 @@ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKVCacheKernel warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL) * k_stride_n + kv_head_idx * k_stride_h + - (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(); + (lane_idx % KV_THR_LAYOUT_COL) * + upcast_size(); DTypeKV *v_ptr = v + (kv_indptr[request_idx] + chunk_start + warp_idx * KV_THR_LAYOUT_ROW + 
lane_idx / KV_THR_LAYOUT_COL) * v_stride_n + kv_head_idx * v_stride_h + - (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(); + (lane_idx % KV_THR_LAYOUT_COL) * + upcast_size(); produce_kv( k_smem, &k_smem_offset_w, &k_ptr, k_stride_n, 0, chunk_size, tid); @@ -2874,6 +2889,8 @@ __device__ __forceinline__ void BatchPrefillWithPagedKVCacheDevice( [[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_COL = KTraits::KV_THR_LAYOUT_COL; [[maybe_unused]] constexpr MaskMode MASK_MODE = KTraits::MASK_MODE; + [[maybe_unused]] constexpr uint32_t VECTOR_BIT_WIDTH = + KTraits::VECTOR_BIT_WIDTH; IdType *request_indices = params.request_indices; IdType *qo_tile_indices = params.qo_tile_indices; @@ -3023,7 +3040,8 @@ __device__ __forceinline__ void BatchPrefillWithPagedKVCacheDevice( page_iter, entry_idx); thr_local_kv_offset[i] = paged_kv.protective_get_kv_offset( page_iter, kv_head_idx, entry_idx, - (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(), + (lane_idx % KV_THR_LAYOUT_COL) * + upcast_size(), last_indptr); } page_produce_kv(k_smem, &k_smem_offset_w, paged_kv, 0, @@ -3077,7 +3095,8 @@ __device__ __forceinline__ void BatchPrefillWithPagedKVCacheDevice( page_iter, entry_idx); thr_local_kv_offset[i] = paged_kv.protective_get_kv_offset( page_iter, kv_head_idx, entry_idx, - (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(), + (lane_idx % KV_THR_LAYOUT_COL) * + upcast_size(), last_indptr); } memory::wait_group<1>(); From db59b654b79556aa1adcb18d2f523bcdee898c48 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Sun, 17 Aug 2025 03:38:20 -0400 Subject: [PATCH 042/109] Add default_prefill_params.cuh to generic --- .../generic/default_prefill_params.cuh | 352 ++++++++++++++++++ 1 file changed, 352 insertions(+) create mode 100644 libflashinfer/include/flashinfer/attention/generic/default_prefill_params.cuh diff --git a/libflashinfer/include/flashinfer/attention/generic/default_prefill_params.cuh b/libflashinfer/include/flashinfer/attention/generic/default_prefill_params.cuh new file mode 100644 
index 0000000000..9d4468267a --- /dev/null +++ b/libflashinfer/include/flashinfer/attention/generic/default_prefill_params.cuh @@ -0,0 +1,352 @@ +/* + * Copyright (c) 2024 by FlashInfer team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef FLASHINFER_PREFILL_PARAMS_CUH_ +#define FLASHINFER_PREFILL_PARAMS_CUH_ + +#include "gpu_iface/gpu_runtime_compat.hpp" + +#include +#include + +#include "page.cuh" + +namespace flashinfer +{ + +template +struct SinglePrefillParams +{ + using DTypeQ = DTypeQ_; + using DTypeKV = DTypeKV_; + using DTypeO = DTypeO_; + using IdType = int32_t; + DTypeQ *q; + DTypeKV *k; + DTypeKV *v; + uint8_t *maybe_custom_mask; + DTypeO *o; + float *lse; + float *maybe_alibi_slopes; + uint_fastdiv group_size; + uint32_t qo_len; + uint32_t kv_len; + uint32_t num_qo_heads; + uint32_t num_kv_heads; + uint32_t q_stride_n; + uint32_t q_stride_h; + uint32_t k_stride_n; + uint32_t k_stride_h; + uint32_t v_stride_n; + uint32_t v_stride_h; + uint32_t head_dim; + int32_t window_left; + float logits_soft_cap; + float sm_scale; + float rope_rcp_scale; + float rope_rcp_theta; + + uint32_t partition_kv; + + __host__ SinglePrefillParams() + : q(nullptr), k(nullptr), v(nullptr), maybe_custom_mask(nullptr), + o(nullptr), lse(nullptr), maybe_alibi_slopes(nullptr), group_size(), + qo_len(0), kv_len(0), num_qo_heads(0), num_kv_heads(0), q_stride_n(0), + q_stride_h(0), k_stride_n(0), k_stride_h(0), v_stride_n(0), + v_stride_h(0), head_dim(0), 
window_left(0), logits_soft_cap(0.0f), + sm_scale(0.0f), rope_rcp_scale(0.0f), rope_rcp_theta(0.0f), + partition_kv(false) + { + } + + __host__ SinglePrefillParams(DTypeQ *q, + DTypeKV *k, + DTypeKV *v, + uint8_t *maybe_custom_mask, + DTypeO *o, + float *lse, + float *maybe_alibi_slopes, + uint32_t num_qo_heads, + uint32_t num_kv_heads, + uint32_t qo_len, + uint32_t kv_len, + uint32_t q_stride_n, + uint32_t q_stride_h, + uint32_t kv_stride_n, + uint32_t kv_stride_h, + uint32_t head_dim, + int32_t window_left, + float logits_soft_cap, + float sm_scale, + float rope_scale, + float rope_theta) + : q(q), k(k), v(v), maybe_custom_mask(maybe_custom_mask), o(o), + lse(lse), maybe_alibi_slopes(maybe_alibi_slopes), + group_size(num_qo_heads / num_kv_heads), num_qo_heads(num_qo_heads), + num_kv_heads(num_kv_heads), qo_len(qo_len), kv_len(kv_len), + q_stride_n(q_stride_n), q_stride_h(q_stride_h), + k_stride_n(kv_stride_n), k_stride_h(kv_stride_h), + v_stride_n(kv_stride_n), v_stride_h(kv_stride_h), head_dim(head_dim), + window_left(window_left), logits_soft_cap(logits_soft_cap), + sm_scale(sm_scale), rope_rcp_scale(1. / rope_scale), + rope_rcp_theta(1. 
/ rope_theta), partition_kv(false) + { + } + + __host__ __device__ __forceinline__ uint32_t + get_qo_len(uint32_t batch_idx) const + { + return qo_len; + } + + __host__ __device__ __forceinline__ uint32_t + get_kv_len(uint32_t batch_idx) const + { + return kv_len; + } +}; + +template +struct BatchPrefillRaggedParams +{ + using DTypeQ = DTypeQ_; + using DTypeKV = DTypeKV_; + using DTypeO = DTypeO_; + using IdType = IdType_; + + DTypeQ *q; + DTypeKV *k; + DTypeKV *v; + uint8_t *maybe_custom_mask; + IdType *q_indptr; + IdType *kv_indptr; + IdType *maybe_mask_indptr; + IdType *maybe_q_rope_offset; // maybe_q_rope_offset is only used for + // fused-rope attention + IdType *maybe_k_rope_offset; // maybe_k_rope_offset is only used for + // fused-rope attention + DTypeO *o; + float *lse; + float *maybe_alibi_slopes; + uint_fastdiv group_size; + uint32_t num_qo_heads; + uint32_t num_kv_heads; + uint32_t q_stride_n; + uint32_t q_stride_h; + uint32_t k_stride_n; + uint32_t k_stride_h; + uint32_t v_stride_n; + uint32_t v_stride_h; + int32_t window_left; + float logits_soft_cap; + float sm_scale; + float rope_rcp_scale; + float rope_rcp_theta; + + IdType *request_indices; + IdType *qo_tile_indices; + IdType *kv_tile_indices; + IdType *merge_indptr; + IdType *o_indptr; + IdType *kv_chunk_size_ptr; + bool *block_valid_mask; + uint32_t max_total_num_rows; + uint32_t *total_num_rows; + uint32_t padded_batch_size; + bool partition_kv; + + __host__ BatchPrefillRaggedParams() + : q(nullptr), k(nullptr), v(nullptr), maybe_custom_mask(nullptr), + q_indptr(nullptr), kv_indptr(nullptr), maybe_mask_indptr(nullptr), + maybe_q_rope_offset(nullptr), maybe_k_rope_offset(nullptr), + o(nullptr), lse(nullptr), maybe_alibi_slopes(nullptr), group_size(), + num_qo_heads(0), num_kv_heads(0), q_stride_n(0), q_stride_h(0), + k_stride_n(0), k_stride_h(0), v_stride_n(0), v_stride_h(0), + window_left(0), logits_soft_cap(0.0f), sm_scale(0.0f), + rope_rcp_scale(0.0f), rope_rcp_theta(0.0f), 
request_indices(nullptr), + qo_tile_indices(nullptr), kv_tile_indices(nullptr), + merge_indptr(nullptr), o_indptr(nullptr), kv_chunk_size_ptr(nullptr), + block_valid_mask(nullptr), max_total_num_rows(0), + total_num_rows(nullptr), padded_batch_size(0), partition_kv(false) + { + } + + __host__ BatchPrefillRaggedParams(DTypeQ *q, + DTypeKV *k, + DTypeKV *v, + uint8_t *maybe_custom_mask, + IdType *q_indptr, + IdType *kv_indptr, + IdType *maybe_mask_indptr, + IdType *maybe_q_rope_offset, + IdType *maybe_k_rope_offset, + DTypeO *o, + float *lse, + float *maybe_alibi_slopes, + uint32_t num_qo_heads, + uint32_t num_kv_heads, + uint32_t q_stride_n, + uint32_t q_stride_h, + uint32_t kv_stride_n, + uint32_t kv_stride_h, + int32_t window_left, + float logits_soft_cap, + float sm_scale, + float rope_scale, + float rope_theta) + : q(q), k(k), v(v), maybe_custom_mask(maybe_custom_mask), + q_indptr(q_indptr), kv_indptr(kv_indptr), + maybe_mask_indptr(maybe_mask_indptr), + maybe_q_rope_offset(maybe_q_rope_offset), + maybe_k_rope_offset(maybe_k_rope_offset), o(o), lse(lse), + maybe_alibi_slopes(maybe_alibi_slopes), + group_size(num_qo_heads / num_kv_heads), num_qo_heads(num_qo_heads), + num_kv_heads(num_kv_heads), q_stride_n(q_stride_n), + q_stride_h(q_stride_h), k_stride_n(kv_stride_n), + k_stride_h(kv_stride_h), v_stride_n(kv_stride_n), + v_stride_h(kv_stride_h), window_left(window_left), + logits_soft_cap(logits_soft_cap), sm_scale(sm_scale), + rope_rcp_scale(1.f / rope_scale), rope_rcp_theta(1.f / rope_theta), + request_indices(nullptr), qo_tile_indices(nullptr), + kv_tile_indices(nullptr), merge_indptr(nullptr), o_indptr(nullptr), + kv_chunk_size_ptr(nullptr), block_valid_mask(nullptr), + max_total_num_rows(0), total_num_rows(nullptr), padded_batch_size(0), + partition_kv(false) + { + } + + __host__ __device__ __forceinline__ uint32_t + get_qo_len(uint32_t batch_idx) const + { + return q_indptr[batch_idx + 1] - q_indptr[batch_idx]; + } + + __host__ __device__ __forceinline__ 
uint32_t + get_kv_len(uint32_t batch_idx) const + { + return kv_indptr[batch_idx + 1] - kv_indptr[batch_idx]; + } +}; + +template +struct BatchPrefillPagedParams +{ + using DTypeQ = DTypeQ_; + using DTypeKV = DTypeKV_; + using DTypeO = DTypeO_; + using IdType = IdType_; + + DTypeQ *q; + paged_kv_t paged_kv; + uint8_t *maybe_custom_mask; + IdType *q_indptr; + IdType *maybe_mask_indptr; + IdType *maybe_q_rope_offset; // maybe_q_rope_offset is only used for + // fused-rope attention + DTypeO *o; + float *lse; + float *maybe_alibi_slopes; + uint_fastdiv group_size; + uint32_t num_qo_heads; + IdType q_stride_n; + IdType q_stride_h; + int32_t window_left; + float logits_soft_cap; + float sm_scale; + float rope_rcp_scale; + float rope_rcp_theta; + + IdType *request_indices; + IdType *qo_tile_indices; + IdType *kv_tile_indices; + IdType *merge_indptr; + IdType *o_indptr; + bool *block_valid_mask; + IdType *kv_chunk_size_ptr; + uint32_t max_total_num_rows; + uint32_t *total_num_rows; + uint32_t padded_batch_size; + bool partition_kv; + + __host__ BatchPrefillPagedParams() + : q(nullptr), paged_kv(), maybe_custom_mask(nullptr), q_indptr(nullptr), + maybe_mask_indptr(nullptr), maybe_q_rope_offset(nullptr), o(nullptr), + lse(nullptr), maybe_alibi_slopes(nullptr), group_size(), + num_qo_heads(0), q_stride_n(0), q_stride_h(0), window_left(0), + logits_soft_cap(0.0f), sm_scale(0.0f), rope_rcp_scale(0.0f), + rope_rcp_theta(0.0f), request_indices(nullptr), + qo_tile_indices(nullptr), kv_tile_indices(nullptr), + merge_indptr(nullptr), o_indptr(nullptr), block_valid_mask(nullptr), + kv_chunk_size_ptr(nullptr), max_total_num_rows(0), + total_num_rows(nullptr), padded_batch_size(0), partition_kv(false) + { + } + + __host__ BatchPrefillPagedParams(DTypeQ *q, + paged_kv_t paged_kv, + uint8_t *maybe_custom_mask, + IdType *q_indptr, + IdType *maybe_mask_indptr, + IdType *maybe_q_rope_offset, + DTypeO *o, + float *lse, + float *maybe_alibi_slopes, + uint32_t num_qo_heads, + IdType 
q_stride_n, + IdType q_stride_h, + int32_t window_left, + float logits_soft_cap, + float sm_scale, + float rope_scale, + float rope_theta) + : q(q), paged_kv(paged_kv), maybe_custom_mask(maybe_custom_mask), + q_indptr(q_indptr), maybe_mask_indptr(maybe_mask_indptr), + maybe_q_rope_offset(maybe_q_rope_offset), o(o), lse(lse), + maybe_alibi_slopes(maybe_alibi_slopes), + group_size(num_qo_heads / paged_kv.num_heads), + num_qo_heads(num_qo_heads), q_stride_n(q_stride_n), + q_stride_h(q_stride_h), window_left(window_left), + logits_soft_cap(logits_soft_cap), sm_scale(sm_scale), + rope_rcp_scale(1.f / rope_scale), rope_rcp_theta(1.f / rope_theta), + request_indices(nullptr), qo_tile_indices(nullptr), + kv_tile_indices(nullptr), merge_indptr(nullptr), o_indptr(nullptr), + block_valid_mask(nullptr), kv_chunk_size_ptr(nullptr), + max_total_num_rows(0), total_num_rows(nullptr), padded_batch_size(0), + partition_kv(false) + { + } + + __host__ __device__ __forceinline__ uint32_t + get_qo_len(uint32_t batch_idx) const + { + return q_indptr[batch_idx + 1] - q_indptr[batch_idx]; + } + + __host__ __device__ __forceinline__ uint32_t + get_kv_len(uint32_t batch_idx) const + { + return paged_kv.get_length(batch_idx); + } +}; + +} // namespace flashinfer + +#endif // FLASHINFER_DECODE_PARAMS_CUH_ From 1c161e5022f54a529e2b19c1e1d5b977682cca83 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Sun, 17 Aug 2025 03:38:49 -0400 Subject: [PATCH 043/109] Initial unit test harness --- .../hip/test_load_q_global_smem_kernel.cpp | 284 ++++++++++++ .../tests/hip/test_single_prefill.cpp | 405 ++++++++++++++++++ libflashinfer/utils/cpu_reference_hip.h | 2 +- .../utils/flashinfer_prefill_ops.hip.h | 166 +++++++ 4 files changed, 856 insertions(+), 1 deletion(-) create mode 100644 libflashinfer/tests/hip/test_load_q_global_smem_kernel.cpp create mode 100644 libflashinfer/tests/hip/test_single_prefill.cpp create mode 100644 libflashinfer/utils/flashinfer_prefill_ops.hip.h diff --git 
a/libflashinfer/tests/hip/test_load_q_global_smem_kernel.cpp b/libflashinfer/tests/hip/test_load_q_global_smem_kernel.cpp new file mode 100644 index 0000000000..9df0feb9ed --- /dev/null +++ b/libflashinfer/tests/hip/test_load_q_global_smem_kernel.cpp @@ -0,0 +1,284 @@ +// test_load_q_global_smem.cpp +#include +#include +#include +#include +#include +#include + +// Include necessary headers +#include "flashinfer/attention/generic/default_prefill_params.cuh" +#include "flashinfer/attention/generic/prefill.cuh" +#include "flashinfer/attention/generic/variants.cuh" +#include "utils/cpu_reference_hip.h" +#include "utils/utils_hip.h" + +using namespace flashinfer; + +// CPU Reference Implementation for Q Loading +template +std::vector cpu_reference_q_smem_layout( + const std::vector &q_global, + size_t qo_len, + size_t num_qo_heads, + size_t head_dim, + size_t q_stride_n, + size_t q_stride_h, + size_t qo_packed_idx_base, + uint32_t group_size, + size_t + smem_height, // Expected shared memory height (16 for single MMA block) + size_t smem_width) // Expected shared memory width (head_dim) +{ + std::vector q_smem_expected(smem_height * smem_width, DTypeQ(0)); + + // Simulate the loading pattern that load_q_global_smem should follow + for (size_t smem_row = 0; smem_row < smem_height; ++smem_row) { + uint32_t q_packed_idx = qo_packed_idx_base + smem_row; + uint32_t q_idx = q_packed_idx / group_size; // Sequence position + uint32_t r = q_packed_idx % group_size; // Head offset within group + + if (q_idx < qo_len) { + for (size_t feat_idx = 0; feat_idx < head_dim; ++feat_idx) { + // Calculate global memory offset + size_t global_offset = + q_idx * q_stride_n + r * q_stride_h + feat_idx; + + // Place in shared memory layout (assuming linear layout for + // test) + size_t smem_offset = smem_row * smem_width + feat_idx; + if (global_offset < q_global.size()) { + q_smem_expected[smem_offset] = q_global[global_offset]; + } + } + } + } + + return q_smem_expected; +} + +uint_fastdiv 
create_group_size_div(uint32_t group_size) +{ + return uint_fastdiv(group_size); +} + +// Test kernel for Q loading +template +__global__ void test_q_loading_kernel( + typename KTraits::DTypeQ *q_global, + typename KTraits::DTypeQ *q_smem_output, // Output: dump of shared memory + uint32_t qo_packed_idx_base, + uint32_t qo_len, + uint32_t q_stride_n, + uint32_t q_stride_h, + uint_fastdiv group_size_div) +{ + // Set up shared memory + extern __shared__ uint8_t smem[]; + typename KTraits::SharedStorage &smem_storage = + reinterpret_cast(smem); + + smem_t q_smem( + smem_storage.q_smem); + + // // Fast integer division for group_size + // uint_fastdiv group_size_div(group_size); + + // Call the function we're testing + load_q_global_smem(qo_packed_idx_base, qo_len, q_global, + q_stride_n, q_stride_h, group_size_div, &q_smem, + threadIdx); + + // Synchronize to ensure loading is complete + __syncthreads(); + + // Copy shared memory contents to global memory for verification + // Only first warp does this to avoid conflicts + if (threadIdx.y == 0 && threadIdx.z == 0) { + const uint32_t lane_idx = threadIdx.x; + const uint32_t total_elements = + KTraits::CTA_TILE_Q * KTraits::HEAD_DIM_QK; + + // Each thread copies a portion of shared memory + for (uint32_t i = lane_idx; i < total_elements; + i += KTraits::NUM_THREADS) + { + // For linear swizzle mode, direct copy + if (i < total_elements) { + q_smem_output[i] = reinterpret_cast( + smem_storage.q_smem)[i]; + } + } + } +} + +// Main test function +template bool test_q_loading_correctness() +{ + std::cout << "Testing Q loading correctness with " << sizeof(DTypeQ) * 8 + << "-bit precision..." 
<< std::endl; + + // Test parameters - small sizes for initial validation + constexpr size_t qo_len = 8; + constexpr size_t num_qo_heads = 8; + constexpr size_t num_kv_heads = 2; + constexpr size_t head_dim = 64; + constexpr uint32_t group_size = num_qo_heads / num_kv_heads; + + // Create test data with known pattern for easier debugging + const size_t q_size = qo_len * num_qo_heads * head_dim; + std::vector q_host(q_size); + + // Fill with simple pattern: row*1000 + col for easier validation + for (size_t i = 0; i < q_size; ++i) { + float val = float(i % 100) / 10.0f; // Values 0.0, 0.1, 0.2, ... 9.9 + q_host[i] = fi::con::explicit_casting(val); + } + + // GPU memory allocation + DTypeQ *q_device, *q_smem_output; + const size_t smem_elements = 16 * head_dim; // Single MMA block + FI_GPU_CALL(hipMalloc(&q_device, q_size * sizeof(DTypeQ))); + FI_GPU_CALL(hipMalloc(&q_smem_output, smem_elements * sizeof(DTypeQ))); + + FI_GPU_CALL(hipMemcpy(q_device, q_host.data(), q_size * sizeof(DTypeQ), + hipMemcpyHostToDevice)); + + // Define kernel traits for CDNA3 + using KTraits = + KernelTraits>; + + // Launch parameters + dim3 block_size(64, 1, 1); // CDNA3: 64 threads per wavefront + dim3 grid_size(1, 1, 1); + size_t shared_mem_size = sizeof(typename KTraits::SharedStorage); + + // Test parameters + const uint32_t qo_packed_idx_base = 0; // Start from beginning + const uint32_t q_stride_n = num_qo_heads * head_dim; + const uint32_t q_stride_h = head_dim; + + std::cout << "Launching kernel with:" << std::endl; + std::cout << " Block size: " << block_size.x << "x" << block_size.y << "x" + << block_size.z << std::endl; + std::cout << " Shared memory: " << shared_mem_size << " bytes" + << std::endl; + std::cout << " Q size: " << q_size << " elements" << std::endl; + + uint_fastdiv group_size_div = create_group_size_div(group_size); + + // Launch test kernel + test_q_loading_kernel<<>>( + q_device, q_smem_output, qo_packed_idx_base, qo_len, q_stride_n, + q_stride_h, 
group_size_div); + + FI_GPU_CALL(hipDeviceSynchronize()); + + // Get results back + std::vector q_smem_actual(smem_elements); + FI_GPU_CALL(hipMemcpy(q_smem_actual.data(), q_smem_output, + smem_elements * sizeof(DTypeQ), + hipMemcpyDeviceToHost)); + + // Generate CPU reference + std::vector q_smem_expected = cpu_reference_q_smem_layout( + q_host, qo_len, num_qo_heads, head_dim, q_stride_n, q_stride_h, + qo_packed_idx_base, group_size, 16, head_dim); + + // Compare results + bool passed = true; + float max_diff = 0.0f; + size_t mismatch_count = 0; + + std::cout << "\nValidation results:" << std::endl; + std::cout << "Comparing " << q_smem_actual.size() << " elements..." + << std::endl; + + for (size_t i = 0; + i < std::min(q_smem_actual.size(), q_smem_expected.size()); ++i) + { + float actual = + fi::con::explicit_casting(q_smem_actual[i]); + float expected = + fi::con::explicit_casting(q_smem_expected[i]); + float diff = std::abs(actual - expected); + max_diff = std::max(max_diff, diff); + + if (!utils::isclose(q_smem_actual[i], q_smem_expected[i], 1e-3f, 1e-4f)) + { + if (mismatch_count < 10) { // Show first 10 mismatches + size_t row = i / head_dim; + size_t col = i % head_dim; + std::cout << "Mismatch at [" << row << "][" << col + << "] (index " << i << "): " + << "expected " << expected << ", got " << actual + << ", diff " << diff << std::endl; + } + mismatch_count++; + passed = false; + } + } + + std::cout << "Max difference: " << max_diff << std::endl; + std::cout << "Total mismatches: " << mismatch_count << " / " + << q_smem_actual.size() << std::endl; + std::cout << "Q loading test: " << (passed ? 
"PASSED" : "FAILED") + << std::endl; + + // Show some sample values for debugging + if (!passed) { + std::cout << "\nFirst 10 expected vs actual values:" << std::endl; + for (size_t i = 0; i < std::min(size_t(10), q_smem_actual.size()); ++i) + { + float actual = + fi::con::explicit_casting(q_smem_actual[i]); + float expected = + fi::con::explicit_casting(q_smem_expected[i]); + std::cout << "[" << i << "] expected: " << expected + << ", actual: " << actual << std::endl; + } + } + + // Cleanup + FI_GPU_CALL(hipFree(q_device)); + FI_GPU_CALL(hipFree(q_smem_output)); + + return passed; +} + +// Main function +int main() +{ + std::cout << "=== FlashInfer Q Loading Component Test ===" << std::endl; + std::cout << "Testing load_q_global_smem function for CDNA3 architecture" + << std::endl; + + // Initialize HIP + hipError_t err = hipSetDevice(0); + if (err != hipSuccess) { + std::cout << "Failed to set HIP device: " << hipGetErrorString(err) + << std::endl; + return 1; + } + + hipDeviceProp_t prop; + FI_GPU_CALL(hipGetDeviceProperties(&prop, 0)); + std::cout << "Running on: " << prop.name << std::endl; + + bool all_passed = true; + + // Test with half precision + std::cout << "\n--- Testing with FP16 ---" << std::endl; + all_passed &= test_q_loading_correctness<__half>(); + + if (all_passed) { + std::cout << "\n✅ All Q loading tests PASSED!" << std::endl; + return 0; + } + else { + std::cout << "\n❌ Some Q loading tests FAILED!" << std::endl; + return 1; + } +} diff --git a/libflashinfer/tests/hip/test_single_prefill.cpp b/libflashinfer/tests/hip/test_single_prefill.cpp new file mode 100644 index 0000000000..e335bf4df2 --- /dev/null +++ b/libflashinfer/tests/hip/test_single_prefill.cpp @@ -0,0 +1,405 @@ +// SPDX - FileCopyrightText : 2023 - 2025 Flashinfer team +// SPDX - FileCopyrightText : 2025 Advanced Micro Devices, Inc. 
+// +// SPDX - License - Identifier : Apache 2.0 + +#include "flashinfer/attention/generic/prefill.cuh" + +#include "../../utils/cpu_reference_hip.h" +#include "../../utils/flashinfer_prefill_ops.hip.h" +#include "../../utils/utils_hip.h" + +#include + +#include + +#define HIP_ENABLE_WARP_SYNC_BUILTINS 1 + +using namespace flashinfer; + +template +void _TestSinglePrefillKernelCorrectness(size_t qo_len, + size_t kv_len, + size_t num_qo_heads, + size_t num_kv_heads, + size_t head_dim, + bool causal, + QKVLayout kv_layout, + PosEncodingMode pos_encoding_mode, + bool use_fp16_qk_reduction, + float rtol = 1e-3, + float atol = 1e-3) +{ + std::vector q(qo_len * num_qo_heads * head_dim); + std::vector k(kv_len * num_kv_heads * head_dim); + std::vector v(kv_len * num_kv_heads * head_dim); + std::vector o(qo_len * num_qo_heads * head_dim); + + utils::vec_normal_(q); + utils::vec_normal_(k); + utils::vec_normal_(v); + utils::vec_zero_(o); + + DTypeQ *q_d; + hipMalloc(&q_d, q.size() * sizeof(DTypeQ)); + hipMemcpy(q_d, q.data(), q.size() * sizeof(DTypeQ), hipMemcpyHostToDevice); + + DTypeKV *k_d; + hipMalloc(&k_d, k.size() * sizeof(DTypeKV)); + hipMemcpy(k_d, k.data(), k.size() * sizeof(DTypeKV), hipMemcpyHostToDevice); + + DTypeKV *v_d; + hipMalloc(&v_d, v.size() * sizeof(DTypeKV)); + hipMemcpy(v_d, v.data(), v.size() * sizeof(DTypeKV), hipMemcpyHostToDevice); + + DTypeO *o_d; + hipMalloc(&o_d, o.size() * sizeof(DTypeO)); + hipMemcpy(o_d, o.data(), o.size() * sizeof(DTypeO), hipMemcpyHostToDevice); + + DTypeO *tmp_d; + hipMalloc(&tmp_d, 16 * 1024 * 1024 * sizeof(DTypeO)); + + hipError_t status = + flashinfer::SinglePrefillWithKVCache( + q_d, k_d, v_d, o_d, tmp_d, + /*lse=*/nullptr, num_qo_heads, num_kv_heads, qo_len, kv_len, + head_dim, causal, kv_layout, pos_encoding_mode, + use_fp16_qk_reduction); + + EXPECT_EQ(status, hipSuccess) + << "SinglePrefillWithKVCache kernel launch failed, error message: " + << hipGetErrorString(status); + + std::vector o_h(o.size()); + 
hipMemcpy(o_h.data(), o_d, o_h.size() * sizeof(DTypeO), + hipMemcpyDeviceToHost); + + // Print the first 10 elements of the output vector for debugging + // std::cout << "Output vector (first 10 elements):"; + // std::cout << "[" << std::endl; + // for (int i = 0; i < 10; ++i) { + // std::cout << fi::con::explicit_casting(o_h[i]) << " "; + // } + // std::cout << "]" << std::endl; + + bool isEmpty = o_h.empty(); + EXPECT_EQ(isEmpty, false) << "Output vector is empty"; + + std::vector att_out; + std::vector o_ref = + cpu_reference::single_mha( + q, k, v, att_out, qo_len, kv_len, num_qo_heads, num_kv_heads, + head_dim, causal, kv_layout, pos_encoding_mode); + size_t num_results_error_atol = 0; + bool nan_detected = false; + + for (size_t i = 0; i < o_ref.size(); ++i) { + float o_h_val = fi::con::explicit_casting(o_h[i]); + float o_ref_val = fi::con::explicit_casting(o_ref[i]); + + if (isnan(o_h_val)) { + nan_detected = true; + } + + num_results_error_atol += + (!utils::isclose(o_ref_val, o_h_val, rtol, atol)); + if (!utils::isclose(o_ref_val, o_h_val, rtol, atol)) { + std::cout << "i=" << i << ", o_ref[i]=" << o_ref_val + << ", o_h[i]=" << o_h_val << std::endl; + } + } + // std::cout<<"Printing att_out vector:\n"; + // for(auto i: att_out) { + // std::cout << i << "\n"; + // } + float result_accuracy = + 1. 
- float(num_results_error_atol) / float(o_ref.size()); + std::cout << "num_qo_heads=" << num_qo_heads + << ", num_kv_heads=" << num_kv_heads << ", qo_len=" << qo_len + << ", kv_len=" << kv_len << ", head_dim=" << head_dim + << ", causal=" << causal + << ", kv_layout=" << QKVLayoutToString(kv_layout) + << ", pos_encoding_mode=" + << PosEncodingModeToString(pos_encoding_mode) + << ", result_accuracy=" << result_accuracy << std::endl; + + EXPECT_GT(result_accuracy, 0.90) << "Result correctness test failed."; + EXPECT_FALSE(nan_detected) << "Nan detected in the result."; + + hipFree(q_d); + hipFree(k_d); + hipFree(v_d); + hipFree(o_d); + hipFree(tmp_d); +} + +// template +// void TestSinglePrefillKernelLongContextCorrectness(bool +// use_fp16_qk_reduction) +// { +// for (size_t qo_len : {1, 31, 63, 127}) { +// for (size_t kv_len : {31717}) { +// for (size_t num_heads : {1}) { +// for (size_t head_dim : {64, 128, 256}) { +// for (bool causal : {false, true}) { +// for (size_t pos_encoding_mode : {0, 1}) { +// for (size_t kv_layout : {0, 1}) { +// _TestSinglePrefillKernelCorrectness< +// DTypeIn, DTypeIn, DTypeO>( +// qo_len, kv_len, num_heads, num_heads, +// head_dim, causal, QKVLayout(kv_layout), +// PosEncodingMode(pos_encoding_mode), +// use_fp16_qk_reduction); +// } +// } +// } +// } +// } +// } +// } +// } +//*********************************************************************** +// The following tests are disabled because we dont support fp8 <-> float +// conversions + +// template +// void TestSinglePrefillFP8KernelLongContextCorrectness(bool +// use_fp16_qk_reduction) { +// for (size_t qo_len : {1, 31, 63, 127}) { +// for (size_t kv_len : {31717}) { +// for (size_t num_heads : {1}) { +// for (size_t head_dim : {64, 128, 256}) { +// for (bool causal : {false, true}) { +// for (size_t pos_encoding_mode : {0}) { +// for (size_t kv_layout : {0, 1}) { +// _TestSinglePrefillKernelCorrectness<__half, DTypeKV, __half>( +// qo_len, kv_len, num_heads, num_heads, 
head_dim, causal, +// QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode), +// use_fp16_qk_reduction); +// } +// } +// } +// } +// } +// } +// } +// } + +// template +// void TestSinglePrefillKernelShortContextCorrectness(bool +// use_fp16_qk_reduction) +// { +// float rtol = std::is_same::value ? 1e-2 : 1e-3; +// float atol = std::is_same::value ? 1e-2 : 1e-3; +// for (size_t qkv_len : {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37}) { +// for (size_t num_qo_heads : {32}) { +// for (size_t num_kv_heads : {4, 8, 32}) { +// for (size_t head_dim : {64, 128, 256}) { +// for (bool causal : {false, true}) { +// for (size_t pos_encoding_mode : {0, 1}) { +// for (size_t kv_layout : {0, 1}) { +// _TestSinglePrefillKernelCorrectness< +// DTypeIn, DTypeIn, DTypeO>( +// qkv_len, qkv_len, num_qo_heads, +// num_kv_heads, head_dim, causal, +// QKVLayout(kv_layout), +// PosEncodingMode(pos_encoding_mode), +// use_fp16_qk_reduction, rtol, atol); +// } +// } +// } +// } +// } +// } +// } +// } + +//*********************************************************************** +// The following tests are disabled because we dont support fp8 <-> float +// conversions + +// template +// void TestSinglePrefillFP8KernelShortContextCorrectness(bool +// use_fp16_qk_reduction) { +// float rtol = 1e-3; +// float atol = 1e-3; +// for (size_t qkv_len : {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37}) { +// for (size_t num_qo_heads : {32}) { +// for (size_t num_kv_heads : {4, 8, 32}) { +// for (size_t head_dim : {64, 128, 256}) { +// for (bool causal : {false, true}) { +// for (size_t pos_encoding_mode : {0}) { +// for (size_t kv_layout : {0, 1}) { +// _TestSinglePrefillKernelCorrectness<__half, DTypeKV, __half>( +// qkv_len, qkv_len, num_qo_heads, num_kv_heads, head_dim, +// causal, QKVLayout(kv_layout), +// PosEncodingMode(pos_encoding_mode), +// use_fp16_qk_reduction, rtol, atol); +// } +// } +// } +// } +// } +// } +// } +// } + +// template +// void TestSinglePrefillKernelCorrectness(bool 
use_fp16_qk_reduction) +// { +// for (size_t qo_len : {399, 400, 401}) { +// for (size_t kv_len : {533, 534, 535}) { +// for (size_t num_heads : {12}) { +// for (size_t head_dim : {64, 128, 256}) { +// for (bool causal : {false, true}) { +// for (size_t pos_encoding_mode : {0, 1}) { +// for (size_t kv_layout : {0, 1}) { +// _TestSinglePrefillKernelCorrectness< +// DTypeIn, DTypeIn, DTypeO>( +// qo_len, kv_len, num_heads, num_heads, +// head_dim, causal, QKVLayout(kv_layout), +// PosEncodingMode(pos_encoding_mode), +// use_fp16_qk_reduction); +// } +// } +// } +// } +// } +// } +// } +// } + +// template +// void TestSinglePrefillFP8KernelCorrectness(bool use_fp16_qk_reduction) +// { +// for (size_t qo_len : {399, 400, 401}) { +// for (size_t kv_len : {533, 534, 535}) { +// for (size_t num_heads : {12}) { +// for (size_t head_dim : {64, 128, 256}) { +// for (bool causal : {false, true}) { +// for (size_t pos_encoding_mode : {0}) { +// for (size_t kv_layout : {0, 1}) { +// _TestSinglePrefillKernelCorrectness< +// __half, DTypeKV, __half>( +// qo_len, kv_len, num_heads, num_heads, +// head_dim, causal, QKVLayout(kv_layout), +// PosEncodingMode(pos_encoding_mode), +// use_fp16_qk_reduction); +// } +// } +// } +// } +// } +// } +// } +// } + +// TEST(FlashInferCorrectnessTest, +// TestSinglePrefillKernelLongContextCorrectnessFP16) +// { +// TestSinglePrefillKernelLongContextCorrectness<__half, __half>(false); +// } + +// TEST(FlashInferCorrectnessTest, +// TestSinglePrefillKernelLongContextCorrectnessFP16QKHalfAccum) +// { +// TestSinglePrefillKernelLongContextCorrectness<__half, __half>(true); +// } + +// TEST(FlashInferCorrectnessTest, +// TestSinglePrefillKernelShortContextCorrectnessFP16) +// { +// TestSinglePrefillKernelShortContextCorrectness<__half, __half>(false); +// } + +// TEST(FlashInferCorrectnessTest, +// TestSinglePrefillKernelShortContextCorrectnessFP16QKHalfAccum) +// { +// TestSinglePrefillKernelShortContextCorrectness<__half, __half>(true); +// } + 
+// TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelCorrectnessTestFP16) +// { +// TestSinglePrefillKernelCorrectness<__half, __half>(false); +// } + +// TEST(FlashInferCorrectnessTest, +// TestSinglePrefillKernelCorrectnessTestFP16QKHalfAccum) +// { +// TestSinglePrefillKernelCorrectness<__half, __half>(true); +// } + +// #ifdef FLASHINFER_ENABLE_BF16 +// TEST(FlashInferCorrectnessTest, +// TestSinglePrefillKernelLongContextCorrectnessBF16) +// { +// TestSinglePrefillKernelLongContextCorrectness<__hip_bfloat16, +// __hip_bfloat16>( +// false); +// } +// TEST(FlashInferCorrectnessTest, +// TestSinglePrefillKernelShortContextCorrectnessBF16) +// { +// TestSinglePrefillKernelShortContextCorrectness<__hip_bfloat16, +// __hip_bfloat16>( +// false); +// } +// TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelCorrectnessTestBF16) +// { +// TestSinglePrefillKernelCorrectness<__hip_bfloat16, +// __hip_bfloat16>(false); +// } +// #endif + +//*********************************************************************** +// The following tests are disabled because we dont support fp8 <-> float +// conversions + +// #ifdef FLASHINFER_ENABLE_FP8_E4M3 +// TEST(FlashInferCorrectnessTest, +// TestSinglePrefillKernelShortContextCorrectnessE4M3) { +// TestSinglePrefillFP8KernelShortContextCorrectness<__nv_fp8_e4m3>(false); +// } +// TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelCorrectnessTestE4M3) { +// TestSinglePrefillFP8KernelCorrectness<__nv_fp8_e4m3>(false); +// } +// TEST(FlashInferCorrectnessTest, +// TestSinglePrefillKernelLongContextCorrectnessE4M3) { +// TestSinglePrefillFP8KernelLongContextCorrectness<__nv_fp8_e4m3>(false); +// } +// #endif + +// #ifdef FLASHINFER_ENABLE_FP8_E5M2 +// TEST(FlashInferCorrectnessTest, +// TestSinglePrefillKernelShortContextCorrectnessE5M2) { +// TestSinglePrefillFP8KernelShortContextCorrectness<__nv_fp8_e5m2>(false); +// } +// TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelCorrectnessTestE5M2) { +// 
TestSinglePrefillFP8KernelCorrectness<__nv_fp8_e5m2>(false); +// } +// TEST(FlashInferCorrectnessTest, +// TestSinglePrefillKernelLongContextCorrectnessE5M2) { +// TestSinglePrefillFP8KernelLongContextCorrectness<__nv_fp8_e5m2>(false); +// } +// #endif + +int main(int argc, char **argv) +{ + // ::testing::InitGoogleTest(&argc, argv); + // return RUN_ALL_TESTS(); + using DTypeIn = __half; + using DTypeO = __half; + bool use_fp16_qk_reduction = false; + size_t qo_len = 399; + size_t kv_len = 533; + size_t num_heads = 1; + size_t head_dim = 64; + bool causal = false; + size_t pos_encoding_mode = 0; + size_t kv_layout = 0; + + _TestSinglePrefillKernelCorrectness( + qo_len, kv_len, num_heads, num_heads, head_dim, causal, + QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode), + use_fp16_qk_reduction); +} diff --git a/libflashinfer/utils/cpu_reference_hip.h b/libflashinfer/utils/cpu_reference_hip.h index 059c0fc2d1..bcbe3a8c5a 100644 --- a/libflashinfer/utils/cpu_reference_hip.h +++ b/libflashinfer/utils/cpu_reference_hip.h @@ -83,7 +83,7 @@ inline std::vector apply_llama_rope(const T *input, rst[k] = cos * fi::con::explicit_casting(input[k]) + sin * permuted_input[k]; } - return std::move(rst); + return rst; } template diff --git a/libflashinfer/utils/flashinfer_prefill_ops.hip.h b/libflashinfer/utils/flashinfer_prefill_ops.hip.h new file mode 100644 index 0000000000..dfc83a8c5c --- /dev/null +++ b/libflashinfer/utils/flashinfer_prefill_ops.hip.h @@ -0,0 +1,166 @@ +// SPDX - FileCopyrightText : 2023 - 2025 Flashinfer team +// SPDX - FileCopyrightText : 2025 Advanced Micro Devices, Inc. 
+// +// SPDX - License - Identifier : Apache 2.0 + +#pragma once + +#include "utils_hip.h" + +#include "flashinfer/attention/generic/allocator.h" +#include "flashinfer/attention/generic/default_prefill_params.cuh" +#include "flashinfer/attention/generic/exception.h" +#include "flashinfer/attention/generic/prefill.cuh" +#include "flashinfer/attention/generic/scheduler.cuh" +#include "flashinfer/attention/generic/variants.cuh" + +#include "gpu_iface/enums.hpp" +#include "gpu_iface/layout.cuh" +#include + +namespace flashinfer +{ + +template +hipError_t SinglePrefillWithKVCacheDispatched(Params params, + typename Params::DTypeO *tmp, + hipStream_t stream); + +template +hipError_t SinglePrefillWithKVCacheCustomMask( + DTypeIn *q, + DTypeIn *k, + DTypeIn *v, + uint8_t *custom_mask, + DTypeO *o, + DTypeO *tmp, + float *lse, + uint32_t num_qo_heads, + uint32_t num_kv_heads, + uint32_t qo_len, + uint32_t kv_len, + uint32_t head_dim, + QKVLayout kv_layout = QKVLayout::kNHD, + PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, + bool use_fp16_qk_reduction = false, + std::optional maybe_sm_scale = std::nullopt, + float rope_scale = 1.f, + float rope_theta = 1e4, + hipStream_t stream = nullptr) +{ + const float sm_scale = + maybe_sm_scale.value_or(1.f / std::sqrt(float(head_dim))); + auto [qo_stride_n, qo_stride_h, kv_stride_n, kv_stride_h] = get_qkv_strides( + kv_layout, kv_len, num_qo_heads, num_kv_heads, head_dim); + DISPATCH_use_fp16_qk_reduction( + use_fp16_qk_reduction, USE_FP16_QK_REDUCTION, + {DISPATCH_head_dim( + head_dim, HEAD_DIM, + {DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, { + using Params = SinglePrefillParams; + using AttentionVariant = DefaultAttention< + /*use_custom_mask=*/true, /*use_sliding_window=*/false, + /*use_logits_soft_cap=*/false, /*use_alibi=*/false>; + Params params(q, k, v, custom_mask, o, lse, + /*alibi_slopes=*/nullptr, num_qo_heads, + num_kv_heads, qo_len, kv_len, qo_stride_n, + qo_stride_h, kv_stride_n, 
kv_stride_h, head_dim, + /*window_left=*/-1, + /*logits_soft_cap=*/0.f, sm_scale, rope_scale, + rope_theta); + return SinglePrefillWithKVCacheDispatched< + HEAD_DIM, HEAD_DIM, POS_ENCODING_MODE, + USE_FP16_QK_REDUCTION, MaskMode::kCustom, AttentionVariant>( + params, tmp, stream); + })})}); + return hipSuccess; +} + +/*! + * \brief FlashAttention prefill hip function for a single request. + * \tparam DTypeIn The data type of input + * \tparam DTypeO The data type of output + * \param q The query tensor. + * \param k The key tensor. + * \param v The value tensor. + * \param o The output tensor. + * \param tmp The temporary storage (only used for cooperative kernel). + * \param lse The logsumexp values. + * \param num_qo_heads The number of query and output heads. + * \param num_kv_heads The number of key and value heads. + * \param qo_len The length of query and output. + * \param kv_len The length of key and value. + * \param head_dim The dimension of each head. + * \param causal Whether to use causal attention. + * \param kv_layout The layout of input and output. + * \param pos_encoding_mode The positional encoding mode. + * \param use_fp16_qk_reduction Whether to allow accumulating q*k^T with fp16. + * \param rope_scale The scaling factor used in RoPE interpolation. + * \param rope_theta The theta used in RoPE. + * \param stream The hip stream to execute the kernel on. 
+ * \return status Indicates whether hip calls are successful + */ +template +hipError_t SinglePrefillWithKVCache( + DTypeQ *q, + DTypeKV *k, + DTypeKV *v, + DTypeO *o, + DTypeO *tmp, + float *lse, + uint32_t num_qo_heads, + uint32_t num_kv_heads, + uint32_t qo_len, + uint32_t kv_len, + uint32_t head_dim, + bool causal = true, + QKVLayout kv_layout = QKVLayout::kNHD, + PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, + bool use_fp16_qk_reduction = false, + std::optional maybe_sm_scale = std::nullopt, + float rope_scale = 1.f, + float rope_theta = 1e4, + hipStream_t stream = nullptr) +{ + const float sm_scale = + maybe_sm_scale.value_or(1.f / std::sqrt(float(head_dim))); + const MaskMode mask_mode = causal ? MaskMode::kCausal : MaskMode::kNone; + auto [qo_stride_n, qo_stride_h, kv_stride_n, kv_stride_h] = get_qkv_strides( + kv_layout, kv_len, num_qo_heads, num_kv_heads, head_dim); + DISPATCH_use_fp16_qk_reduction( + use_fp16_qk_reduction, USE_FP16_QK_REDUCTION, + {DISPATCH_mask_mode( + mask_mode, MASK_MODE, + {DISPATCH_head_dim( + head_dim, HEAD_DIM, + {DISPATCH_pos_encoding_mode( + pos_encoding_mode, POS_ENCODING_MODE, { + using Params = + SinglePrefillParams; + using AttentionVariant = DefaultAttention< + /*use_custom_mask=*/(MASK_MODE == + MaskMode::kCustom), + /*use_sliding_window=*/false, + /*use_logits_soft_cap=*/false, /*use_alibi=*/false>; + Params params(q, k, v, /*custom_mask=*/nullptr, o, lse, + /*alibi_slopes=*/nullptr, num_qo_heads, + num_kv_heads, qo_len, kv_len, qo_stride_n, + qo_stride_h, kv_stride_n, kv_stride_h, + head_dim, + /*window_left=*/-1, + /*logits_soft_cap=*/0.f, sm_scale, + rope_scale, rope_theta); + return SinglePrefillWithKVCacheDispatched< + HEAD_DIM, HEAD_DIM, POS_ENCODING_MODE, + USE_FP16_QK_REDUCTION, MASK_MODE, AttentionVariant, + Params>(params, tmp, stream); + })})})}); + return hipSuccess; +} +} // namespace flashinfer From 9e2e3f57ed510bc8d4cec5d0b297891d00ebebb5 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Sun, 
17 Aug 2025 03:53:10 -0400 Subject: [PATCH 044/109] Updated load_q_global_smem_kernel test --- .../hip/test_load_q_global_smem_kernel.cpp | 32 ++++++++++++------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/libflashinfer/tests/hip/test_load_q_global_smem_kernel.cpp b/libflashinfer/tests/hip/test_load_q_global_smem_kernel.cpp index 9df0feb9ed..44ba713802 100644 --- a/libflashinfer/tests/hip/test_load_q_global_smem_kernel.cpp +++ b/libflashinfer/tests/hip/test_load_q_global_smem_kernel.cpp @@ -92,21 +92,31 @@ __global__ void test_q_loading_kernel( // Synchronize to ensure loading is complete __syncthreads(); - // Copy shared memory contents to global memory for verification - // Only first warp does this to avoid conflicts if (threadIdx.y == 0 && threadIdx.z == 0) { const uint32_t lane_idx = threadIdx.x; - const uint32_t total_elements = - KTraits::CTA_TILE_Q * KTraits::HEAD_DIM_QK; + constexpr uint32_t smem_height = KTraits::CTA_TILE_Q; // 16 + constexpr uint32_t smem_width = KTraits::HEAD_DIM_QK; // 64 + constexpr uint32_t total_elements = smem_height * smem_width; - // Each thread copies a portion of shared memory - for (uint32_t i = lane_idx; i < total_elements; - i += KTraits::NUM_THREADS) + // Each thread copies using proper swizzled access + for (uint32_t linear_idx = lane_idx; linear_idx < total_elements; + linear_idx += KTraits::NUM_THREADS) { - // For linear swizzle mode, direct copy - if (i < total_elements) { - q_smem_output[i] = reinterpret_cast( - smem_storage.q_smem)[i]; + if (linear_idx < total_elements) { + uint32_t row = linear_idx / smem_width; + uint32_t col = linear_idx % smem_width; + uint32_t swizzled_offset = q_smem.template get_permuted_offset< + smem_width / upcast_size()>( + row, col / upcast_size()); + uint32_t element_idx = + col % upcast_size(); + typename KTraits::DTypeQ *smem_ptr = + reinterpret_cast( + q_smem.base + swizzled_offset); + q_smem_output[linear_idx] = smem_ptr[element_idx]; } } } From 
78b559f11732f4049babc4b0e75cde43a0567821 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Mon, 18 Aug 2025 05:51:45 -0400 Subject: [PATCH 045/109] Fixes to test_single_prefill.cpp --- .../generic/default_prefill_params.cuh | 4 +- .../flashinfer/attention/generic/prefill.cuh | 8 ++-- .../tests/hip/test_single_prefill.cpp | 44 ++++++++++--------- .../utils/flashinfer_prefill_ops.hip.h | 4 +- 4 files changed, 32 insertions(+), 28 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/default_prefill_params.cuh b/libflashinfer/include/flashinfer/attention/generic/default_prefill_params.cuh index 9d4468267a..2b558d03cf 100644 --- a/libflashinfer/include/flashinfer/attention/generic/default_prefill_params.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/default_prefill_params.cuh @@ -41,10 +41,10 @@ struct SinglePrefillParams float *lse; float *maybe_alibi_slopes; uint_fastdiv group_size; - uint32_t qo_len; - uint32_t kv_len; uint32_t num_qo_heads; uint32_t num_kv_heads; + uint32_t qo_len; + uint32_t kv_len; uint32_t q_stride_n; uint32_t q_stride_h; uint32_t k_stride_n; diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index 6e213bc20f..00c98fb48f 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -152,8 +152,8 @@ struct KernelTraits static constexpr SwizzleMode SWIZZLE_MODE_KV = SwizzleMode::kLinear; // Presently we use 16x4 thread layout for all cases. - static constexpr uint32_t KV_THR_LAYOUT_ROW = WARP_THREAD_ROWS; - static constexpr uint32_t KV_THR_LAYOUT_COL = WARP_THREAD_COLS; + static constexpr uint32_t KV_THR_LAYOUT_ROW = 16; + static constexpr uint32_t KV_THR_LAYOUT_COL = 4; // The constant is defined based on the matrix layout of the "D/C" // accumulator matrix in a D = A*B+C computation. 
On CDNA3 the D/C matrices // are distributed as four 4x16 bands across the 64 threads. Each thread @@ -402,8 +402,8 @@ __device__ __forceinline__ void produce_kv_helper_( const uint32_t kv_len) { using DTypeKV = typename KTraits::DTypeKV; - constexpr uint32_t WARP_THREAD_COLS = KTraits::WARP_THREAD_COLS; - constexpr uint32_t WARP_THREAD_ROWS = KTraits::WARP_THREAD_ROWS; + constexpr uint32_t WARP_THREAD_COLS = KTraits::KV_THR_LAYOUT_COL; + constexpr uint32_t WARP_THREAD_ROWS = KTraits::KV_THR_LAYOUT_ROW; constexpr uint32_t NUM_WARPS = KTraits::NUM_WARPS; constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q; diff --git a/libflashinfer/tests/hip/test_single_prefill.cpp b/libflashinfer/tests/hip/test_single_prefill.cpp index e335bf4df2..20a50ba911 100644 --- a/libflashinfer/tests/hip/test_single_prefill.cpp +++ b/libflashinfer/tests/hip/test_single_prefill.cpp @@ -3,11 +3,11 @@ // // SPDX - License - Identifier : Apache 2.0 -#include "flashinfer/attention/generic/prefill.cuh" - #include "../../utils/cpu_reference_hip.h" #include "../../utils/flashinfer_prefill_ops.hip.h" #include "../../utils/utils_hip.h" +#include "flashinfer/attention/generic/prefill.cuh" +#include "gpu_iface/gpu_runtime_compat.hpp" #include @@ -41,23 +41,27 @@ void _TestSinglePrefillKernelCorrectness(size_t qo_len, utils::vec_zero_(o); DTypeQ *q_d; - hipMalloc(&q_d, q.size() * sizeof(DTypeQ)); - hipMemcpy(q_d, q.data(), q.size() * sizeof(DTypeQ), hipMemcpyHostToDevice); + FI_GPU_CALL(hipMalloc(&q_d, q.size() * sizeof(DTypeQ))); + FI_GPU_CALL(hipMemcpy(q_d, q.data(), q.size() * sizeof(DTypeQ), + hipMemcpyHostToDevice)); DTypeKV *k_d; - hipMalloc(&k_d, k.size() * sizeof(DTypeKV)); - hipMemcpy(k_d, k.data(), k.size() * sizeof(DTypeKV), hipMemcpyHostToDevice); + FI_GPU_CALL(hipMalloc(&k_d, k.size() * sizeof(DTypeKV))); + FI_GPU_CALL(hipMemcpy(k_d, k.data(), k.size() * sizeof(DTypeKV), + hipMemcpyHostToDevice)); DTypeKV *v_d; - hipMalloc(&v_d, 
v.size() * sizeof(DTypeKV)); - hipMemcpy(v_d, v.data(), v.size() * sizeof(DTypeKV), hipMemcpyHostToDevice); + FI_GPU_CALL(hipMalloc(&v_d, v.size() * sizeof(DTypeKV))); + FI_GPU_CALL(hipMemcpy(v_d, v.data(), v.size() * sizeof(DTypeKV), + hipMemcpyHostToDevice)); DTypeO *o_d; - hipMalloc(&o_d, o.size() * sizeof(DTypeO)); - hipMemcpy(o_d, o.data(), o.size() * sizeof(DTypeO), hipMemcpyHostToDevice); + FI_GPU_CALL(hipMalloc(&o_d, o.size() * sizeof(DTypeO))); + FI_GPU_CALL(hipMemcpy(o_d, o.data(), o.size() * sizeof(DTypeO), + hipMemcpyHostToDevice)); DTypeO *tmp_d; - hipMalloc(&tmp_d, 16 * 1024 * 1024 * sizeof(DTypeO)); + FI_GPU_CALL(hipMalloc(&tmp_d, 16 * 1024 * 1024 * sizeof(DTypeO))); hipError_t status = flashinfer::SinglePrefillWithKVCache( @@ -71,8 +75,8 @@ void _TestSinglePrefillKernelCorrectness(size_t qo_len, << hipGetErrorString(status); std::vector o_h(o.size()); - hipMemcpy(o_h.data(), o_d, o_h.size() * sizeof(DTypeO), - hipMemcpyDeviceToHost); + FI_GPU_CALL(hipMemcpy(o_h.data(), o_d, o_h.size() * sizeof(DTypeO), + hipMemcpyDeviceToHost)); // Print the first 10 elements of the output vector for debugging // std::cout << "Output vector (first 10 elements):"; @@ -88,8 +92,8 @@ void _TestSinglePrefillKernelCorrectness(size_t qo_len, std::vector att_out; std::vector o_ref = cpu_reference::single_mha( - q, k, v, att_out, qo_len, kv_len, num_qo_heads, num_kv_heads, - head_dim, causal, kv_layout, pos_encoding_mode); + q, k, v, qo_len, kv_len, num_qo_heads, num_kv_heads, head_dim, + causal, kv_layout, pos_encoding_mode); size_t num_results_error_atol = 0; bool nan_detected = false; @@ -126,11 +130,11 @@ void _TestSinglePrefillKernelCorrectness(size_t qo_len, EXPECT_GT(result_accuracy, 0.90) << "Result correctness test failed."; EXPECT_FALSE(nan_detected) << "Nan detected in the result."; - hipFree(q_d); - hipFree(k_d); - hipFree(v_d); - hipFree(o_d); - hipFree(tmp_d); + FI_GPU_CALL(hipFree(q_d)); + FI_GPU_CALL(hipFree(k_d)); + FI_GPU_CALL(hipFree(v_d)); + 
FI_GPU_CALL(hipFree(o_d)); + FI_GPU_CALL(hipFree(tmp_d)); } // template diff --git a/libflashinfer/utils/flashinfer_prefill_ops.hip.h b/libflashinfer/utils/flashinfer_prefill_ops.hip.h index dfc83a8c5c..db4a2694e5 100644 --- a/libflashinfer/utils/flashinfer_prefill_ops.hip.h +++ b/libflashinfer/utils/flashinfer_prefill_ops.hip.h @@ -59,7 +59,7 @@ hipError_t SinglePrefillWithKVCacheCustomMask( auto [qo_stride_n, qo_stride_h, kv_stride_n, kv_stride_h] = get_qkv_strides( kv_layout, kv_len, num_qo_heads, num_kv_heads, head_dim); DISPATCH_use_fp16_qk_reduction( - use_fp16_qk_reduction, USE_FP16_QK_REDUCTION, + static_cast(use_fp16_qk_reduction), USE_FP16_QK_REDUCTION, {DISPATCH_head_dim( head_dim, HEAD_DIM, {DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, { @@ -134,7 +134,7 @@ hipError_t SinglePrefillWithKVCache( auto [qo_stride_n, qo_stride_h, kv_stride_n, kv_stride_h] = get_qkv_strides( kv_layout, kv_len, num_qo_heads, num_kv_heads, head_dim); DISPATCH_use_fp16_qk_reduction( - use_fp16_qk_reduction, USE_FP16_QK_REDUCTION, + static_cast(use_fp16_qk_reduction), USE_FP16_QK_REDUCTION, {DISPATCH_mask_mode( mask_mode, MASK_MODE, {DISPATCH_head_dim( From 86a9780063d4f8a3bfb6ebe85faf66ddb659dcd3 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Mon, 18 Aug 2025 06:46:25 -0400 Subject: [PATCH 046/109] Initial stubs for compute_qk --- .../tests/hip/test_single_prefill.cpp | 8 + libflashinfer/utils/compute_qk_stub.cu | 187 ++++++++++++++++++ libflashinfer/utils/cpu_reference_hip.h | 51 +++++ 3 files changed, 246 insertions(+) create mode 100644 libflashinfer/utils/compute_qk_stub.cu diff --git a/libflashinfer/tests/hip/test_single_prefill.cpp b/libflashinfer/tests/hip/test_single_prefill.cpp index 20a50ba911..c8ada17889 100644 --- a/libflashinfer/tests/hip/test_single_prefill.cpp +++ b/libflashinfer/tests/hip/test_single_prefill.cpp @@ -17,6 +17,14 @@ using namespace flashinfer; +template +void _TestComputeQKCorrectness() +{ + std::vector q(qo_len * 
num_qo_heads * head_dim); + std::vector k(kv_len * num_kv_heads * head_dim); + std::vector o(qo_len * num_qo_heads * head_dim); +} + template void _TestSinglePrefillKernelCorrectness(size_t qo_len, size_t kv_len, diff --git a/libflashinfer/utils/compute_qk_stub.cu b/libflashinfer/utils/compute_qk_stub.cu new file mode 100644 index 0000000000..061340eda6 --- /dev/null +++ b/libflashinfer/utils/compute_qk_stub.cu @@ -0,0 +1,187 @@ +// SPDX-FileCopyrightText: 2025 Advanced Micro Devices, Inc. +// SPDX-License-Identifier: Apache-2.0 + +#include "flashinfer/attention/generic/prefill.cuh" +#include "gpu_iface/gpu_runtime_compat.hpp" + +using namespace flashinfer; + +template +__global__ void ComputeQKStubKernel(typename KTraits::DTypeQ *q, + typename KTraits::DTypeKV *k, + float *qk_scores_output, + uint32_t qo_len, + uint32_t kv_len, + uint32_t num_qo_heads, + uint32_t num_kv_heads, + uint32_t q_stride_n, + uint32_t q_stride_h, + uint32_t k_stride_n, + uint32_t k_stride_h, + uint_fastdiv group_size) +{ + using DTypeQ = typename KTraits::DTypeQ; + using DTypeKV = typename KTraits::DTypeKV; + using DTypeQKAccum = typename KTraits::DTypeQKAccum; + + extern __shared__ uint8_t smem[]; + typename KTraits::SharedStorage &smem_storage = + reinterpret_cast(smem); + + // Initialize shared memory objects + smem_t q_smem( + smem_storage.q_smem); + smem_t k_smem( + smem_storage.k_smem); + + const uint32_t lane_idx = threadIdx.x; + const uint32_t warp_idx = get_warp_idx(threadIdx.y, threadIdx.z); + const uint32_t kv_head_idx = blockIdx.z; + + // 1. Load Q into shared memory (same as SinglePrefillWithKVCacheDevice) + const uint32_t qo_packed_idx_base = (blockIdx.x * KTraits::NUM_WARPS_Q + + get_warp_idx_q(threadIdx.y)) * + KTraits::NUM_MMA_Q * 16; + + DTypeQ *q_ptr_base = q + (kv_head_idx * group_size) * q_stride_h; + + load_q_global_smem(qo_packed_idx_base, qo_len, q_ptr_base, + q_stride_n, q_stride_h, group_size, &q_smem, + threadIdx); + + // 2. 
Load K into shared memory (same as SinglePrefillWithKVCacheDevice) + DTypeKV *k_ptr = k + + (warp_idx * KTraits::KV_THR_LAYOUT_ROW + + lane_idx / KTraits::KV_THR_LAYOUT_COL) * + k_stride_n + + kv_head_idx * k_stride_h + + (lane_idx % KTraits::KV_THR_LAYOUT_COL) * + upcast_size(); + + uint32_t k_smem_offset_w = + k_smem.template get_permuted_offset( + warp_idx * KTraits::KV_THR_LAYOUT_ROW + + lane_idx / KTraits::KV_THR_LAYOUT_COL, + lane_idx % KTraits::KV_THR_LAYOUT_COL); + + produce_kv( + k_smem, &k_smem_offset_w, &k_ptr, k_stride_n, 0, kv_len, threadIdx); + + __syncthreads(); + + // 3. Set up fragment offsets for compute_qk (same as + // SinglePrefillWithKVCacheDevice) + uint32_t q_smem_offset_r = + q_smem.template get_permuted_offset( + get_warp_idx_q(threadIdx.y) * KTraits::NUM_MMA_Q * 16 + + lane_idx % 16, + lane_idx / 16); + + uint32_t k_smem_offset_r = + k_smem.template get_permuted_offset( + get_warp_idx_kv(threadIdx.z) * KTraits::NUM_MMA_KV * 16 + + KTraits::HALF_ELEMS_PER_THREAD * (lane_idx / 16) + + lane_idx % KTraits::HALF_ELEMS_PER_THREAD, + (lane_idx % 16) / KTraits::HALF_ELEMS_PER_THREAD); + + // 4. Call compute_qk (the function we want to test) + DTypeQKAccum s_frag[KTraits::NUM_MMA_Q][KTraits::NUM_MMA_KV] + [KTraits::HALF_ELEMS_PER_THREAD]; + + compute_qk(&q_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, + s_frag); + + // 5. 
Extract attention scores from s_frag to global memory + // Simple extraction for validation - use first warp only + if (get_warp_idx_q(threadIdx.y) == 0 && + get_warp_idx_kv(threadIdx.z) == 0) + { + // Map from MMA fragment layout to sequence indices + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { + for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { + for (uint32_t reg_id = 0; + reg_id < KTraits::HALF_ELEMS_PER_THREAD; ++reg_id) + { + // Calculate global Q,K indices from fragment indices +#if defined(PLATFORM_HIP_DEVICE) + uint32_t q_idx = + mma_q * 16 + + reg_id % KTraits::NUM_ACCUM_ROWS_PER_THREAD; + uint32_t kv_idx = + mma_kv * 16 + + 2 * (lane_idx % KTraits::THREADS_PER_MATRIX_ROW_SET) + + 8 * (reg_id / 2) + reg_id % 2; +#else + uint32_t q_idx = mma_q * 16 + (reg_id % 4) / 2; + uint32_t kv_idx = + mma_kv * 16 + 8 * (reg_id / 4) + reg_id % 2; +#endif + + if (q_idx < qo_len && kv_idx < kv_len) { + uint32_t output_idx = q_idx * kv_len + kv_idx; + qk_scores_output[output_idx] = + float(s_frag[mma_q][mma_kv][reg_id]); + } + } + } + } + } +} + +template +hipError_t ComputeQKStub(DTypeQ *q, + DTypeKV *k, + float *qk_scores_output, + uint32_t qo_len, + uint32_t kv_len, + uint32_t num_qo_heads, + uint32_t num_kv_heads, + uint32_t head_dim, + hipStream_t stream = nullptr) +{ + // Use same KernelTraits selection as SinglePrefillWithKVCache + constexpr uint32_t NUM_MMA_D_QK = 4; // head_dim=64 -> 64/16=4 + constexpr uint32_t NUM_MMA_D_VO = 4; + constexpr uint32_t CTA_TILE_Q = 16; + constexpr uint32_t NUM_MMA_Q = 1; + constexpr uint32_t NUM_MMA_KV = 1; + constexpr uint32_t NUM_WARPS_Q = 1; + constexpr uint32_t NUM_WARPS_KV = 1; + + using KTraits = + KernelTraits>; + + // Launch configuration (same pattern as SinglePrefillWithKVCache) + const uint32_t group_size = num_qo_heads / num_kv_heads; + const uint_fastdiv group_size_fastdiv(group_size); + + dim3 block_size(KTraits::NUM_THREADS, 1, 1); + dim3 grid_size(1, 1, num_kv_heads); + size_t 
shared_mem_size = sizeof(typename KTraits::SharedStorage); + + const uint32_t q_stride_n = num_qo_heads * head_dim; + const uint32_t q_stride_h = head_dim; + const uint32_t k_stride_n = num_kv_heads * head_dim; + const uint32_t k_stride_h = head_dim; + + ComputeQKStubKernel + <<>>( + q, k, qk_scores_output, qo_len, kv_len, num_qo_heads, num_kv_heads, + q_stride_n, q_stride_h, k_stride_n, k_stride_h, group_size_fastdiv); + + return hipGetLastError(); +} + +// Explicit instantiation for common types +template hipError_t ComputeQKStub<__half, __half>(__half *, + __half *, + float *, + uint32_t, + uint32_t, + uint32_t, + uint32_t, + uint32_t, + hipStream_t); diff --git a/libflashinfer/utils/cpu_reference_hip.h b/libflashinfer/utils/cpu_reference_hip.h index bcbe3a8c5a..922ee1532e 100644 --- a/libflashinfer/utils/cpu_reference_hip.h +++ b/libflashinfer/utils/cpu_reference_hip.h @@ -9,6 +9,7 @@ #include "flashinfer/attention/generic/page.cuh" #include "flashinfer/attention/generic/pos_enc.cuh" +#include "flashinfer/attention/generic/tensor_info.cuh" // For tensor_info_t #include "utils_hip.h" @@ -86,6 +87,56 @@ inline std::vector apply_llama_rope(const T *input, return rst; } +template +std::vector compute_qk(const std::vector &q, + const std::vector &k, + size_t qo_len, + size_t kv_len, + size_t num_qo_heads, + size_t num_kv_heads, + size_t head_dim, + QKVLayout kv_layout = QKVLayout::kHND) +{ + + assert(num_qo_heads % num_kv_heads == 0); + assert(q.size() == qo_len * num_qo_heads * head_dim); + assert(k.size() == kv_len * num_kv_heads * head_dim); + + std::vector qk_scores(qo_len * num_qo_heads * kv_len); + + DISPATCH_head_dim(head_dim, HEAD_DIM, { + tensor_info_t info(qo_len, kv_len, num_qo_heads, num_kv_heads, + kv_layout, HEAD_DIM); + + for (size_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) + { + const size_t kv_head_idx = qo_head_idx / info.get_group_size(); + + for (size_t q_idx = 0; q_idx < qo_len; ++q_idx) { + for (size_t kv_idx = 0; kv_idx < 
kv_len; ++kv_idx) { + float qk_score = 0.0f; + + // Pure Q*K^T - NO scaling (matching HIP compute_qk) + for (size_t feat_idx = 0; feat_idx < head_dim; ++feat_idx) { + qk_score += fi::con::explicit_casting( + q[info.get_q_elem_offset( + q_idx, qo_head_idx, feat_idx)]) * + fi::con::explicit_casting( + k[info.get_kv_elem_offset( + kv_idx, kv_head_idx, feat_idx)]); + } + + size_t output_idx = + qo_head_idx * qo_len * kv_len + q_idx * kv_len + kv_idx; + qk_scores[output_idx] = qk_score; + } + } + } + }); + + return std::move(qk_scores); +} + template std::vector single_mha(const std::vector &q, From 6003a3f6dc45ef0552d57967ec592e1400b97ad8 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Mon, 18 Aug 2025 13:31:52 -0400 Subject: [PATCH 047/109] Standalone tester for compute_qk --- .../flashinfer/attention/generic/prefill.cuh | 4 - .../hip/test_load_q_global_smem_kernel.cpp | 41 +-- .../tests/hip/test_single_prefill.cpp | 189 +++++++++- libflashinfer/utils/compute_qk_stub.cu | 187 ---------- libflashinfer/utils/compute_qk_stub.cuh | 332 ++++++++++++++++++ libflashinfer/utils/cpu_reference_hip.h | 3 +- .../utils/flashinfer_prefill_ops.hip.h | 73 ++++ 7 files changed, 606 insertions(+), 223 deletions(-) delete mode 100644 libflashinfer/utils/compute_qk_stub.cu create mode 100644 libflashinfer/utils/compute_qk_stub.cuh diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index 00c98fb48f..f73eb1f55b 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -980,10 +980,6 @@ __device__ __forceinline__ void compute_qk( } else { #if defined(PLATFORM_HIP_DEVICE) - // TODO: We need to validate the layout of K. Whether a - // transposed load is needed or whether K is pre-transposed. 
- // k_smem->load_fragment_4x4_transposed(*k_smem_offset_r, - // b_frag); k_smem->load_fragment(*k_smem_offset_r, b_frag); #else k_smem->load_fragment(*k_smem_offset_r, b_frag); diff --git a/libflashinfer/tests/hip/test_load_q_global_smem_kernel.cpp b/libflashinfer/tests/hip/test_load_q_global_smem_kernel.cpp index 44ba713802..c57628fbf7 100644 --- a/libflashinfer/tests/hip/test_load_q_global_smem_kernel.cpp +++ b/libflashinfer/tests/hip/test_load_q_global_smem_kernel.cpp @@ -17,18 +17,17 @@ using namespace flashinfer; // CPU Reference Implementation for Q Loading template -std::vector cpu_reference_q_smem_layout( - const std::vector &q_global, - size_t qo_len, - size_t num_qo_heads, - size_t head_dim, - size_t q_stride_n, - size_t q_stride_h, - size_t qo_packed_idx_base, - uint32_t group_size, - size_t - smem_height, // Expected shared memory height (16 for single MMA block) - size_t smem_width) // Expected shared memory width (head_dim) +std::vector +cpu_reference_q_smem_layout(const std::vector &q_global, + size_t qo_len, + size_t num_qo_heads, + size_t head_dim, + size_t q_stride_n, + size_t q_stride_h, + size_t qo_packed_idx_base, + uint32_t group_size, + size_t smem_height, + size_t smem_width) { std::vector q_smem_expected(smem_height * smem_width, DTypeQ(0)); @@ -64,14 +63,13 @@ uint_fastdiv create_group_size_div(uint32_t group_size) // Test kernel for Q loading template -__global__ void test_q_loading_kernel( - typename KTraits::DTypeQ *q_global, - typename KTraits::DTypeQ *q_smem_output, // Output: dump of shared memory - uint32_t qo_packed_idx_base, - uint32_t qo_len, - uint32_t q_stride_n, - uint32_t q_stride_h, - uint_fastdiv group_size_div) +__global__ void test_q_loading_kernel(typename KTraits::DTypeQ *q_global, + typename KTraits::DTypeQ *q_smem_output, + uint32_t qo_packed_idx_base, + uint32_t qo_len, + uint32_t q_stride_n, + uint32_t q_stride_h, + uint_fastdiv group_size_div) { // Set up shared memory extern __shared__ uint8_t smem[]; @@ -81,9 
+79,6 @@ __global__ void test_q_loading_kernel( smem_t q_smem( smem_storage.q_smem); - // // Fast integer division for group_size - // uint_fastdiv group_size_div(group_size); - // Call the function we're testing load_q_global_smem(qo_packed_idx_base, qo_len, q_global, q_stride_n, q_stride_h, group_size_div, &q_smem, diff --git a/libflashinfer/tests/hip/test_single_prefill.cpp b/libflashinfer/tests/hip/test_single_prefill.cpp index c8ada17889..96b1ea782c 100644 --- a/libflashinfer/tests/hip/test_single_prefill.cpp +++ b/libflashinfer/tests/hip/test_single_prefill.cpp @@ -18,11 +18,150 @@ using namespace flashinfer; template -void _TestComputeQKCorrectness() +void _TestComputeQKCorrectness(size_t qo_len, + size_t kv_len, + size_t num_qo_heads, + size_t num_kv_heads, + size_t head_dim, + bool causal, + QKVLayout kv_layout, + PosEncodingMode pos_encoding_mode, + bool use_fp16_qk_reduction, + float rtol = 1e-3, + float atol = 1e-3) { + std::cout << "Testing compute_qk: qo_len=" << qo_len + << ", kv_len=" << kv_len << ", num_qo_heads=" << num_qo_heads + << ", num_kv_heads=" << num_kv_heads << ", head_dim=" << head_dim + << std::endl; + + // Generate test data (same as original test) std::vector q(qo_len * num_qo_heads * head_dim); std::vector k(kv_len * num_kv_heads * head_dim); - std::vector o(qo_len * num_qo_heads * head_dim); + std::vector v(kv_len * num_kv_heads * + head_dim); // Still needed for params + std::vector o(qo_len * num_qo_heads * + head_dim); // Still needed for params + + utils::vec_normal_(q); + utils::vec_normal_(k); + utils::vec_normal_(v); // Initialize even though we won't use it + utils::vec_zero_(o); + + // GPU memory allocation (same pattern as original) + DTypeQ *q_d; + FI_GPU_CALL(hipMalloc(&q_d, q.size() * sizeof(DTypeQ))); + FI_GPU_CALL(hipMemcpy(q_d, q.data(), q.size() * sizeof(DTypeQ), + hipMemcpyHostToDevice)); + + DTypeKV *k_d; + FI_GPU_CALL(hipMalloc(&k_d, k.size() * sizeof(DTypeKV))); + FI_GPU_CALL(hipMemcpy(k_d, k.data(), k.size() * 
sizeof(DTypeKV), + hipMemcpyHostToDevice)); + + DTypeKV *v_d; + FI_GPU_CALL(hipMalloc(&v_d, v.size() * sizeof(DTypeKV))); + FI_GPU_CALL(hipMemcpy(v_d, v.data(), v.size() * sizeof(DTypeKV), + hipMemcpyHostToDevice)); + + DTypeO *o_d; + FI_GPU_CALL(hipMalloc(&o_d, o.size() * sizeof(DTypeO))); + FI_GPU_CALL(hipMemcpy(o_d, o.data(), o.size() * sizeof(DTypeO), + hipMemcpyHostToDevice)); + + DTypeO *tmp_d; + FI_GPU_CALL(hipMalloc(&tmp_d, 16 * 1024 * 1024 * sizeof(DTypeO))); + + // Allocate output buffer for QK scores + const size_t qk_output_size = qo_len * kv_len * num_qo_heads; + float *qk_scores_d; + FI_GPU_CALL(hipMalloc(&qk_scores_d, qk_output_size * sizeof(float))); + + // Call ComputeQKStubCaller instead of SinglePrefillWithKVCache + hipError_t status = + flashinfer::ComputeQKStubCaller( + q_d, k_d, v_d, o_d, tmp_d, + /*lse=*/nullptr, qk_scores_d, // Add qk_scores_d parameter + num_qo_heads, num_kv_heads, qo_len, kv_len, head_dim, causal, + kv_layout, pos_encoding_mode, use_fp16_qk_reduction); + + EXPECT_EQ(status, hipSuccess) + << "ComputeQKStubCaller kernel launch failed, error message: " + << hipGetErrorString(status); + + // Get GPU QK scores + std::vector gpu_qk_scores(qk_output_size); + FI_GPU_CALL(hipMemcpy(gpu_qk_scores.data(), qk_scores_d, + qk_output_size * sizeof(float), + hipMemcpyDeviceToHost)); + + // Check if GPU output is not empty + bool isEmpty = gpu_qk_scores.empty(); + EXPECT_EQ(isEmpty, false) << "GPU QK scores vector is empty"; + + // Compute CPU reference using our cpu_reference::compute_qk + std::vector cpu_qk_scores = cpu_reference::compute_qk( + q, k, qo_len, kv_len, num_qo_heads, num_kv_heads, head_dim, kv_layout); + + // Validate results (same pattern as original test) + size_t num_results_error_atol = 0; + bool nan_detected = false; + + // Compare element-by-element + size_t comparison_size = + std::min(gpu_qk_scores.size(), cpu_qk_scores.size()); + for (size_t i = 0; i < comparison_size; ++i) { + float gpu_val = gpu_qk_scores[i]; + 
float cpu_val = cpu_qk_scores[i]; + + if (isnan(gpu_val)) { + nan_detected = true; + } + + if (!utils::isclose(cpu_val, gpu_val, rtol, atol)) { + num_results_error_atol++; + if (num_results_error_atol <= 10) + { // Only print first 10 mismatches + std::cout << "QK mismatch at i=" << i << ", cpu_val=" << cpu_val + << ", gpu_val=" << gpu_val << std::endl; + } + } + } + + // Calculate and report accuracy + float result_accuracy = + 1.0f - float(num_results_error_atol) / float(comparison_size); + + std::cout << "compute_qk results: num_qo_heads=" << num_qo_heads + << ", num_kv_heads=" << num_kv_heads << ", qo_len=" << qo_len + << ", kv_len=" << kv_len << ", head_dim=" << head_dim + << ", causal=" << causal + << ", kv_layout=" << QKVLayoutToString(kv_layout) + << ", pos_encoding_mode=" + << PosEncodingModeToString(pos_encoding_mode) + << ", qk_accuracy=" << result_accuracy << " (" + << num_results_error_atol << "/" << comparison_size + << " mismatches)" << std::endl; + + // Print some sample values for debugging + std::cout << "Sample QK scores (first 10): GPU vs CPU" << std::endl; + for (size_t i = 0; i < std::min(size_t(10), comparison_size); ++i) { + std::cout << " [" << i << "] GPU=" << gpu_qk_scores[i] + << ", CPU=" << cpu_qk_scores[i] << std::endl; + } + + // Assertions (slightly relaxed for initial testing) + EXPECT_GT(result_accuracy, 0.80) + << "compute_qk accuracy too low"; // Start with 80% + EXPECT_FALSE(nan_detected) << "NaN detected in compute_qk results"; + + // Cleanup + FI_GPU_CALL(hipFree(q_d)); + FI_GPU_CALL(hipFree(k_d)); + FI_GPU_CALL(hipFree(v_d)); + FI_GPU_CALL(hipFree(o_d)); + FI_GPU_CALL(hipFree(tmp_d)); + FI_GPU_CALL(hipFree(qk_scores_d)); } template @@ -395,23 +534,59 @@ void _TestSinglePrefillKernelCorrectness(size_t qo_len, // } // #endif +// int main(int argc, char **argv) +// { +// // ::testing::InitGoogleTest(&argc, argv); +// // return RUN_ALL_TESTS(); +// using DTypeIn = __half; +// using DTypeO = __half; +// bool use_fp16_qk_reduction 
= false; +// size_t qo_len = 399; +// size_t kv_len = 533; +// size_t num_heads = 1; +// size_t head_dim = 64; +// bool causal = false; +// size_t pos_encoding_mode = 0; +// size_t kv_layout = 0; + +// _TestSinglePrefillKernelCorrectness( +// qo_len, kv_len, num_heads, num_heads, head_dim, causal, +// QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode), +// use_fp16_qk_reduction); +// } + int main(int argc, char **argv) { - // ::testing::InitGoogleTest(&argc, argv); - // return RUN_ALL_TESTS(); + // Test compute_qk first with simple parameters + std::cout << "=== Testing compute_qk function ===" << std::endl; using DTypeIn = __half; using DTypeO = __half; bool use_fp16_qk_reduction = false; + bool causal = false; + size_t pos_encoding_mode = 0; + size_t kv_layout = 0; + + // Start with small dimensions for easier debugging + _TestComputeQKCorrectness( + 16, // qo_len - small for debugging + 32, // kv_len + 1, // num_qo_heads - single head + 1, // num_kv_heads - single head + 64, // head_dim + causal, QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode), + use_fp16_qk_reduction); + + std::cout << "\n=== Testing full single prefill ===" << std::endl; + // Your existing test... size_t qo_len = 399; size_t kv_len = 533; size_t num_heads = 1; size_t head_dim = 64; - bool causal = false; - size_t pos_encoding_mode = 0; - size_t kv_layout = 0; _TestSinglePrefillKernelCorrectness( qo_len, kv_len, num_heads, num_heads, head_dim, causal, QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode), use_fp16_qk_reduction); + + return 0; } diff --git a/libflashinfer/utils/compute_qk_stub.cu b/libflashinfer/utils/compute_qk_stub.cu deleted file mode 100644 index 061340eda6..0000000000 --- a/libflashinfer/utils/compute_qk_stub.cu +++ /dev/null @@ -1,187 +0,0 @@ -// SPDX-FileCopyrightText: 2025 Advanced Micro Devices, Inc. 
-// SPDX-License-Identifier: Apache-2.0 - -#include "flashinfer/attention/generic/prefill.cuh" -#include "gpu_iface/gpu_runtime_compat.hpp" - -using namespace flashinfer; - -template -__global__ void ComputeQKStubKernel(typename KTraits::DTypeQ *q, - typename KTraits::DTypeKV *k, - float *qk_scores_output, - uint32_t qo_len, - uint32_t kv_len, - uint32_t num_qo_heads, - uint32_t num_kv_heads, - uint32_t q_stride_n, - uint32_t q_stride_h, - uint32_t k_stride_n, - uint32_t k_stride_h, - uint_fastdiv group_size) -{ - using DTypeQ = typename KTraits::DTypeQ; - using DTypeKV = typename KTraits::DTypeKV; - using DTypeQKAccum = typename KTraits::DTypeQKAccum; - - extern __shared__ uint8_t smem[]; - typename KTraits::SharedStorage &smem_storage = - reinterpret_cast(smem); - - // Initialize shared memory objects - smem_t q_smem( - smem_storage.q_smem); - smem_t k_smem( - smem_storage.k_smem); - - const uint32_t lane_idx = threadIdx.x; - const uint32_t warp_idx = get_warp_idx(threadIdx.y, threadIdx.z); - const uint32_t kv_head_idx = blockIdx.z; - - // 1. Load Q into shared memory (same as SinglePrefillWithKVCacheDevice) - const uint32_t qo_packed_idx_base = (blockIdx.x * KTraits::NUM_WARPS_Q + - get_warp_idx_q(threadIdx.y)) * - KTraits::NUM_MMA_Q * 16; - - DTypeQ *q_ptr_base = q + (kv_head_idx * group_size) * q_stride_h; - - load_q_global_smem(qo_packed_idx_base, qo_len, q_ptr_base, - q_stride_n, q_stride_h, group_size, &q_smem, - threadIdx); - - // 2. 
Load K into shared memory (same as SinglePrefillWithKVCacheDevice) - DTypeKV *k_ptr = k + - (warp_idx * KTraits::KV_THR_LAYOUT_ROW + - lane_idx / KTraits::KV_THR_LAYOUT_COL) * - k_stride_n + - kv_head_idx * k_stride_h + - (lane_idx % KTraits::KV_THR_LAYOUT_COL) * - upcast_size(); - - uint32_t k_smem_offset_w = - k_smem.template get_permuted_offset( - warp_idx * KTraits::KV_THR_LAYOUT_ROW + - lane_idx / KTraits::KV_THR_LAYOUT_COL, - lane_idx % KTraits::KV_THR_LAYOUT_COL); - - produce_kv( - k_smem, &k_smem_offset_w, &k_ptr, k_stride_n, 0, kv_len, threadIdx); - - __syncthreads(); - - // 3. Set up fragment offsets for compute_qk (same as - // SinglePrefillWithKVCacheDevice) - uint32_t q_smem_offset_r = - q_smem.template get_permuted_offset( - get_warp_idx_q(threadIdx.y) * KTraits::NUM_MMA_Q * 16 + - lane_idx % 16, - lane_idx / 16); - - uint32_t k_smem_offset_r = - k_smem.template get_permuted_offset( - get_warp_idx_kv(threadIdx.z) * KTraits::NUM_MMA_KV * 16 + - KTraits::HALF_ELEMS_PER_THREAD * (lane_idx / 16) + - lane_idx % KTraits::HALF_ELEMS_PER_THREAD, - (lane_idx % 16) / KTraits::HALF_ELEMS_PER_THREAD); - - // 4. Call compute_qk (the function we want to test) - DTypeQKAccum s_frag[KTraits::NUM_MMA_Q][KTraits::NUM_MMA_KV] - [KTraits::HALF_ELEMS_PER_THREAD]; - - compute_qk(&q_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, - s_frag); - - // 5. 
Extract attention scores from s_frag to global memory - // Simple extraction for validation - use first warp only - if (get_warp_idx_q(threadIdx.y) == 0 && - get_warp_idx_kv(threadIdx.z) == 0) - { - // Map from MMA fragment layout to sequence indices - for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { - for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { - for (uint32_t reg_id = 0; - reg_id < KTraits::HALF_ELEMS_PER_THREAD; ++reg_id) - { - // Calculate global Q,K indices from fragment indices -#if defined(PLATFORM_HIP_DEVICE) - uint32_t q_idx = - mma_q * 16 + - reg_id % KTraits::NUM_ACCUM_ROWS_PER_THREAD; - uint32_t kv_idx = - mma_kv * 16 + - 2 * (lane_idx % KTraits::THREADS_PER_MATRIX_ROW_SET) + - 8 * (reg_id / 2) + reg_id % 2; -#else - uint32_t q_idx = mma_q * 16 + (reg_id % 4) / 2; - uint32_t kv_idx = - mma_kv * 16 + 8 * (reg_id / 4) + reg_id % 2; -#endif - - if (q_idx < qo_len && kv_idx < kv_len) { - uint32_t output_idx = q_idx * kv_len + kv_idx; - qk_scores_output[output_idx] = - float(s_frag[mma_q][mma_kv][reg_id]); - } - } - } - } - } -} - -template -hipError_t ComputeQKStub(DTypeQ *q, - DTypeKV *k, - float *qk_scores_output, - uint32_t qo_len, - uint32_t kv_len, - uint32_t num_qo_heads, - uint32_t num_kv_heads, - uint32_t head_dim, - hipStream_t stream = nullptr) -{ - // Use same KernelTraits selection as SinglePrefillWithKVCache - constexpr uint32_t NUM_MMA_D_QK = 4; // head_dim=64 -> 64/16=4 - constexpr uint32_t NUM_MMA_D_VO = 4; - constexpr uint32_t CTA_TILE_Q = 16; - constexpr uint32_t NUM_MMA_Q = 1; - constexpr uint32_t NUM_MMA_KV = 1; - constexpr uint32_t NUM_WARPS_Q = 1; - constexpr uint32_t NUM_WARPS_KV = 1; - - using KTraits = - KernelTraits>; - - // Launch configuration (same pattern as SinglePrefillWithKVCache) - const uint32_t group_size = num_qo_heads / num_kv_heads; - const uint_fastdiv group_size_fastdiv(group_size); - - dim3 block_size(KTraits::NUM_THREADS, 1, 1); - dim3 grid_size(1, 1, num_kv_heads); - size_t 
shared_mem_size = sizeof(typename KTraits::SharedStorage); - - const uint32_t q_stride_n = num_qo_heads * head_dim; - const uint32_t q_stride_h = head_dim; - const uint32_t k_stride_n = num_kv_heads * head_dim; - const uint32_t k_stride_h = head_dim; - - ComputeQKStubKernel - <<>>( - q, k, qk_scores_output, qo_len, kv_len, num_qo_heads, num_kv_heads, - q_stride_n, q_stride_h, k_stride_n, k_stride_h, group_size_fastdiv); - - return hipGetLastError(); -} - -// Explicit instantiation for common types -template hipError_t ComputeQKStub<__half, __half>(__half *, - __half *, - float *, - uint32_t, - uint32_t, - uint32_t, - uint32_t, - uint32_t, - hipStream_t); diff --git a/libflashinfer/utils/compute_qk_stub.cuh b/libflashinfer/utils/compute_qk_stub.cuh new file mode 100644 index 0000000000..6c8da0f75e --- /dev/null +++ b/libflashinfer/utils/compute_qk_stub.cuh @@ -0,0 +1,332 @@ +// SPDX-FileCopyrightText: 2025 Advanced Micro Devices, Inc. +// SPDX-License-Identifier: Apache-2.0 + +#include "flashinfer/attention/generic/default_prefill_params.cuh" +#include "flashinfer/attention/generic/prefill.cuh" +#include "gpu_iface/gpu_runtime_compat.hpp" + +using namespace flashinfer; + +template +__device__ __forceinline__ void +ComputeQKStubKernelDevice(const Params params, + typename KTraits::SharedStorage &smem_storage, + float *qk_scores_output, + const dim3 tid = threadIdx, + const uint32_t bx = blockIdx.x, + const uint32_t chunk_idx = blockIdx.y, + const uint32_t kv_head_idx = blockIdx.z, + const uint32_t num_chunks = gridDim.y, + const uint32_t num_kv_heads = gridDim.z) +{ + using DTypeKV = typename Params::DTypeKV; + using DTypeQ = typename Params::DTypeQ; + using DTypeQKAccum = typename KTraits::DTypeQKAccum; + + [[maybe_unused]] constexpr uint32_t NUM_MMA_Q = KTraits::NUM_MMA_Q; + [[maybe_unused]] constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; + [[maybe_unused]] constexpr uint32_t NUM_MMA_D_QK = KTraits::NUM_MMA_D_QK; + [[maybe_unused]] constexpr uint32_t 
HEAD_DIM_QK = KTraits::HEAD_DIM_QK; + [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_Q = + KTraits::UPCAST_STRIDE_Q; + [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_K = + KTraits::UPCAST_STRIDE_K; + [[maybe_unused]] constexpr uint32_t CTA_TILE_Q = KTraits::CTA_TILE_Q; + [[maybe_unused]] constexpr uint32_t CTA_TILE_KV = KTraits::CTA_TILE_KV; + [[maybe_unused]] constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q; + [[maybe_unused]] constexpr uint32_t NUM_WARPS_KV = KTraits::NUM_WARPS_KV; + [[maybe_unused]] constexpr SwizzleMode SWIZZLE_MODE_Q = + KTraits::SWIZZLE_MODE_Q; + [[maybe_unused]] constexpr SwizzleMode SWIZZLE_MODE_KV = + KTraits::SWIZZLE_MODE_KV; + [[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_ROW = + KTraits::KV_THR_LAYOUT_ROW; + [[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_COL = + KTraits::KV_THR_LAYOUT_COL; + [[maybe_unused]] constexpr uint32_t HALF_ELEMS_PER_THREAD = + KTraits::HALF_ELEMS_PER_THREAD; + [[maybe_unused]] constexpr uint32_t VECTOR_BIT_WIDTH = + KTraits::VECTOR_BIT_WIDTH; + + DTypeQ *q = params.q; + DTypeKV *k = params.k; + + const uint32_t qo_len = params.qo_len; + const uint32_t kv_len = params.kv_len; + + const uint32_t q_stride_n = params.q_stride_n; + const uint32_t q_stride_h = params.q_stride_h; + const uint32_t k_stride_n = params.k_stride_n; + const uint32_t k_stride_h = params.k_stride_h; + const uint_fastdiv &group_size = params.group_size; + + static_assert(sizeof(DTypeQ) == 2); + const uint32_t lane_idx = tid.x, + warp_idx = get_warp_idx(tid.y, tid.z); + const uint32_t chunk_start = 0; + const uint32_t chunk_size = kv_len; + + auto block = cg::this_thread_block(); + DTypeQKAccum s_frag[NUM_MMA_Q][NUM_MMA_KV][HALF_ELEMS_PER_THREAD]; + + // cooperative fetch q fragment from gmem to reg + const uint32_t qo_packed_idx_base = + (bx * NUM_WARPS_Q + get_warp_idx_q(tid.y)) * NUM_MMA_Q * 16; + smem_t qo_smem( + smem_storage.q_smem); + DTypeQ *q_ptr_base = q + (kv_head_idx * group_size) * q_stride_h; + + uint32_t 
q_smem_offset_r = + qo_smem.template get_permuted_offset( + get_warp_idx_q(tid.y) * NUM_MMA_Q * 16 + lane_idx % 16, + lane_idx / 16); + + load_q_global_smem(qo_packed_idx_base, qo_len, q_ptr_base, + q_stride_n, q_stride_h, group_size, &qo_smem, + tid); + + memory::commit_group(); + smem_t k_smem( + smem_storage.k_smem); + DTypeKV *k_ptr = k + + (chunk_start + warp_idx * KV_THR_LAYOUT_ROW + + lane_idx / KV_THR_LAYOUT_COL) * + k_stride_n + + kv_head_idx * k_stride_h + + (lane_idx % KV_THR_LAYOUT_COL) * + upcast_size(); + + uint32_t k_smem_offset_r = + k_smem.template get_permuted_offset( + get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + + HALF_ELEMS_PER_THREAD * (lane_idx / 16) + + lane_idx % HALF_ELEMS_PER_THREAD, + (lane_idx % 16) / HALF_ELEMS_PER_THREAD), + k_smem_offset_w = + k_smem.template get_permuted_offset( + warp_idx * KV_THR_LAYOUT_ROW + + lane_idx / KV_THR_LAYOUT_COL, + lane_idx % KV_THR_LAYOUT_COL); + produce_kv( + k_smem, &k_smem_offset_w, &k_ptr, k_stride_n, 0, chunk_size, tid); + memory::commit_group(); + + memory::wait_group<1>(); + block.sync(); + // compute attention score + compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, + s_frag); + memory::wait_group<0>(); + block.sync(); + + // Extract QK scores from s_frag to global memory + if (get_warp_idx_q(tid.y) == 0 && + get_warp_idx_kv(tid.z) == 0) + { + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { + for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) { + for (uint32_t reg_id = 0; reg_id < HALF_ELEMS_PER_THREAD; + ++reg_id) + { + // Map from MMA fragment layout to sequence indices + + // CDNA3 mapping + uint32_t q_idx = + mma_q * 16 + + reg_id % KTraits::NUM_ACCUM_ROWS_PER_THREAD; + uint32_t kv_idx = + mma_kv * 16 + + 2 * (lane_idx % KTraits::THREADS_PER_MATRIX_ROW_SET) + + 8 * (reg_id / 2) + reg_id % 2; + + if (q_idx < qo_len && kv_idx < kv_len) { + // Match CPU layout: [qo_head_idx][q_idx][kv_idx] + uint32_t qo_head_idx = + kv_head_idx * + group_size; // Simple for single 
head case + uint32_t output_idx = qo_head_idx * qo_len * kv_len + + q_idx * kv_len + kv_idx; + qk_scores_output[output_idx] = + float(s_frag[mma_q][mma_kv][reg_id]); + } + } + } + } + } +} + +template +__global__ __launch_bounds__(KTraits::NUM_THREADS) void ComputeQKStubKernel( + const __grid_constant__ Params params, + float *qk_scores_output) +{ + extern __shared__ uint8_t smem[]; + auto &smem_storage = + reinterpret_cast(smem); + ComputeQKStubKernelDevice(params, smem_storage, qk_scores_output); +} + +template +gpuError_t ComputeQKStubDispatched(Params params, + typename Params::DTypeO *tmp, + float *qk_scores_output, + gpuStream_t stream) +{ + using DTypeQ = typename Params::DTypeQ; + using DTypeKV = typename Params::DTypeKV; + using DTypeO = typename Params::DTypeO; + const uint32_t num_qo_heads = params.num_qo_heads; + const uint32_t num_kv_heads = params.num_kv_heads; + const uint32_t qo_len = params.qo_len; + const uint32_t kv_len = params.kv_len; + if (kv_len < qo_len && MASK_MODE == MaskMode::kCausal) { + std::ostringstream err_msg; + err_msg << "When mask_mode is set to MaskMode::kCausal, kv_len must be " + "greater than or equal to qo_len, got kv_len" + << kv_len << " and qo_len " << qo_len; + FLASHINFER_ERROR(err_msg.str()); + } + + const uint32_t group_size = num_qo_heads / num_kv_heads; + constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM_QK / 16; + constexpr uint32_t NUM_MMA_D_VO = HEAD_DIM_VO / 16; + int64_t packed_qo_len = qo_len * group_size; + uint32_t cta_tile_q = FA2DetermineCtaTileQ(packed_qo_len, HEAD_DIM_VO); + + DISPATCH_CTA_TILE_Q(cta_tile_q, CTA_TILE_Q, { + constexpr uint32_t NUM_WARPS_Q = get_num_warps_q(CTA_TILE_Q); + constexpr uint32_t NUM_WARPS_KV = get_num_warps_kv(CTA_TILE_Q); + constexpr uint32_t NUM_MMA_Q = get_num_mma_q(CTA_TILE_Q); + + using DTypeQKAccum = + typename std::conditional, + half, float>::type; + + int dev_id = 0; + FI_GPU_CALL(gpuGetDevice(&dev_id)); + int max_smem_per_sm = getMaxSharedMemPerMultiprocessor(dev_id); + // we 
expect each sm execute two threadblocks + const int num_ctas_per_sm = + max_smem_per_sm >= 2 * (CTA_TILE_Q * HEAD_DIM_QK * sizeof(DTypeQ) + + (HEAD_DIM_QK + HEAD_DIM_VO) * 16 * + NUM_WARPS_KV * sizeof(DTypeKV)) + ? 2 + : 1; + const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm; + + const uint32_t max_num_mma_kv_reg = + (HEAD_DIM_VO >= 128 && NUM_MMA_Q == 2 && + POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && + !USE_FP16_QK_REDUCTION) + ? 2 + : (8 / NUM_MMA_Q); + const uint32_t max_num_mma_kv_smem = + (max_smem_per_threadblock - + CTA_TILE_Q * HEAD_DIM_QK * sizeof(DTypeQ)) / + ((HEAD_DIM_QK + HEAD_DIM_VO) * 16 * NUM_WARPS_KV * sizeof(DTypeKV)); + + // control NUM_MMA_KV for maximum warp occupancy + DISPATCH_NUM_MMA_KV( + min(max_num_mma_kv_smem, max_num_mma_kv_reg), NUM_MMA_KV, { + using KTraits = + KernelTraits; + if constexpr (KTraits::IsInvalid()) { + // Invalid configuration, skip + std::ostringstream err_msg; + err_msg << "FlashInfer Internal Error: Invalid " + "configuration : NUM_MMA_Q=" + << NUM_MMA_Q << " NUM_MMA_D_QK=" << NUM_MMA_D_QK + << " NUM_MMA_D_VO=" << NUM_MMA_D_VO + << " NUM_MMA_KV=" << NUM_MMA_KV + << " NUM_WARPS_Q=" << NUM_WARPS_Q + << " NUM_WARPS_KV=" << NUM_WARPS_KV + << " please create an issue " + "(https://github.com/flashinfer-ai/flashinfer/" + "issues)" + " and report the issue to the developers."; + FLASHINFER_ERROR(err_msg.str()); + } + else { + constexpr uint32_t num_threads = + (NUM_WARPS_Q * NUM_WARPS_KV) * WARP_SIZE; + auto kernel = ComputeQKStubKernel; + size_t smem_size = sizeof(typename KTraits::SharedStorage); + FI_GPU_CALL(gpuFuncSetAttribute( + kernel, gpuFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); + int num_blocks_per_sm = 0; + int num_sm = 0; + FI_GPU_CALL(gpuDeviceGetAttribute( + &num_sm, gpuDevAttrMultiProcessorCount, dev_id)); + FI_GPU_CALL(gpuOccupancyMaxActiveBlocksPerMultiprocessor( + &num_blocks_per_sm, kernel, num_threads, smem_size)); + uint32_t max_num_kv_chunks = + 
(num_blocks_per_sm * num_sm) / + (num_kv_heads * + ceil_div(qo_len * group_size, CTA_TILE_Q)); + uint32_t num_chunks; + if (max_num_kv_chunks > 0) { + uint32_t chunk_size = + max(ceil_div(kv_len, max_num_kv_chunks), 256); + num_chunks = ceil_div(kv_len, chunk_size); + } + else { + num_chunks = 0; + } + + if (num_chunks <= 1 || tmp == nullptr) { + // Enough parallelism, do not split-kv + params.partition_kv = false; + void *args[] = {(void *)¶ms, + (void *)&qk_scores_output}; + dim3 nblks(ceil_div(qo_len * group_size, CTA_TILE_Q), 1, + num_kv_heads); + dim3 nthrs(WARP_SIZE, NUM_WARPS_Q, NUM_WARPS_KV); + FI_GPU_CALL(gpuLaunchKernel((void *)kernel, nblks, + nthrs, args, smem_size, + stream)); + } + else { + // Use cooperative groups to increase occupancy + params.partition_kv = true; + float *tmp_lse = + (float *)(tmp + num_chunks * qo_len * num_qo_heads * + HEAD_DIM_VO); + auto o = params.o; + auto lse = params.lse; + params.o = tmp; + params.lse = tmp_lse; + void *args[] = {(void *)¶ms}; + dim3 nblks(ceil_div(qo_len * group_size, CTA_TILE_Q), + num_chunks, num_kv_heads); + dim3 nthrs(WARP_SIZE, NUM_WARPS_Q, NUM_WARPS_KV); + FI_GPU_CALL(gpuLaunchKernel((void *)kernel, nblks, + nthrs, args, smem_size, + stream)); + if constexpr (AttentionVariant::use_softmax) { + FI_GPU_CALL(MergeStates( + tmp, tmp_lse, o, lse, num_chunks, qo_len, + num_qo_heads, HEAD_DIM_VO, stream)); + } + else { + FI_GPU_CALL(AttentionSum(tmp, o, num_chunks, qo_len, + num_qo_heads, HEAD_DIM_VO, + stream)); + } + } + } + }) + }); + return gpuSuccess; +} diff --git a/libflashinfer/utils/cpu_reference_hip.h b/libflashinfer/utils/cpu_reference_hip.h index 922ee1532e..7883907164 100644 --- a/libflashinfer/utils/cpu_reference_hip.h +++ b/libflashinfer/utils/cpu_reference_hip.h @@ -9,7 +9,6 @@ #include "flashinfer/attention/generic/page.cuh" #include "flashinfer/attention/generic/pos_enc.cuh" -#include "flashinfer/attention/generic/tensor_info.cuh" // For tensor_info_t #include "utils_hip.h" @@ -134,7 
+133,7 @@ std::vector compute_qk(const std::vector &q, } }); - return std::move(qk_scores); + return qk_scores; } template diff --git a/libflashinfer/utils/flashinfer_prefill_ops.hip.h b/libflashinfer/utils/flashinfer_prefill_ops.hip.h index db4a2694e5..f4d4fabc3c 100644 --- a/libflashinfer/utils/flashinfer_prefill_ops.hip.h +++ b/libflashinfer/utils/flashinfer_prefill_ops.hip.h @@ -7,6 +7,7 @@ #include "utils_hip.h" +#include "compute_qk_stub.cuh" #include "flashinfer/attention/generic/allocator.h" #include "flashinfer/attention/generic/default_prefill_params.cuh" #include "flashinfer/attention/generic/exception.h" @@ -21,6 +22,18 @@ namespace flashinfer { +// template +// hipError_t ComputeQKStubDispatched(Params params, +// typename Params::DTypeO *tmp, +// float *qk_scores_output, +// hipStream_t stream); + template +hipError_t +ComputeQKStubCaller(DTypeQ *q, + DTypeKV *k, + DTypeKV *v, + DTypeO *o, + DTypeO *tmp, + float *lse, + float *qk_scores_output, + uint32_t num_qo_heads, + uint32_t num_kv_heads, + uint32_t qo_len, + uint32_t kv_len, + uint32_t head_dim, + bool causal = true, + QKVLayout kv_layout = QKVLayout::kNHD, + PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, + bool use_fp16_qk_reduction = false, + std::optional maybe_sm_scale = std::nullopt, + float rope_scale = 1.f, + float rope_theta = 1e4, + hipStream_t stream = nullptr) +{ + const float sm_scale = + maybe_sm_scale.value_or(1.f / std::sqrt(float(head_dim))); + const MaskMode mask_mode = causal ? 
MaskMode::kCausal : MaskMode::kNone; + auto [qo_stride_n, qo_stride_h, kv_stride_n, kv_stride_h] = get_qkv_strides( + kv_layout, kv_len, num_qo_heads, num_kv_heads, head_dim); + DISPATCH_use_fp16_qk_reduction( + static_cast(use_fp16_qk_reduction), USE_FP16_QK_REDUCTION, + {DISPATCH_mask_mode( + mask_mode, MASK_MODE, + {DISPATCH_head_dim( + head_dim, HEAD_DIM, + {DISPATCH_pos_encoding_mode( + pos_encoding_mode, POS_ENCODING_MODE, { + using Params = + SinglePrefillParams; + using AttentionVariant = DefaultAttention< + /*use_custom_mask=*/(MASK_MODE == + MaskMode::kCustom), + /*use_sliding_window=*/false, + /*use_logits_soft_cap=*/false, /*use_alibi=*/false>; + Params params(q, k, v, /*custom_mask=*/nullptr, o, lse, + /*alibi_slopes=*/nullptr, num_qo_heads, + num_kv_heads, qo_len, kv_len, qo_stride_n, + qo_stride_h, kv_stride_n, kv_stride_h, + head_dim, + /*window_left=*/-1, + /*logits_soft_cap=*/0.f, sm_scale, + rope_scale, rope_theta); + return ComputeQKStubDispatched< + HEAD_DIM, HEAD_DIM, POS_ENCODING_MODE, + USE_FP16_QK_REDUCTION, MASK_MODE, AttentionVariant, + Params>(params, tmp, qk_scores_output, stream); + })})})}); + return hipSuccess; +} + } // namespace flashinfer From f37d407acb7ba8cb62483dfe0aa94b9019e36de0 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Tue, 19 Aug 2025 12:50:10 -0400 Subject: [PATCH 048/109] Testing Q read logic --- compile_test.sh | 13 +++ .../flashinfer/attention/generic/prefill.cuh | 24 +++++ .../tests/hip/test_single_prefill.cpp | 92 ++++++++++--------- libflashinfer/utils/cpu_reference_hip.h | 9 ++ .../utils/flashinfer_prefill_ops.hip.h | 12 --- 5 files changed, 97 insertions(+), 53 deletions(-) create mode 100644 compile_test.sh diff --git a/compile_test.sh b/compile_test.sh new file mode 100644 index 0000000000..60ae60e96d --- /dev/null +++ b/compile_test.sh @@ -0,0 +1,13 @@ +amdclang++ -x hip \ + -std=c++17 \ + -I/home/AMD/diptodeb/devel/flashinfer/libflashinfer/include \ + -I/home/AMD/diptodeb/devel/flashinfer/libflashinfer 
\ + -I${CONDA_PREFIX}/include \ + -Wall \ + -DHIP_ENABLE_WARP_SYNC_BUILTINS=1 \ + -L${CONDA_PREFIX}/lib \ + -lgtest \ + -DDebug \ + -Wl,-rpath=${CONDA_PREFIX}/lib \ + libflashinfer/tests/hip/test_single_prefill.cpp \ + --offload-arch=gfx942 diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index f73eb1f55b..598e76dbb3 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -2163,6 +2163,30 @@ SinglePrefillWithKVCacheDevice(const Params params, v_smem, &v_smem_offset_w, &v_ptr, v_stride_n, 0, chunk_size, tid); memory::commit_group(); +#if Debug + for (auto mma_q = 0ul; mma_q < 4; ++mma_q) { + uint32_t a_frag[KTraits::INT32_ELEMS_PER_THREAD]; + qo_smem.load_fragment(q_smem_offset_r, a_frag); + int global_idx = (blockIdx.z * gridDim.y * gridDim.x + + blockIdx.y * gridDim.x + blockIdx.x) * + (blockDim.z * blockDim.y * blockDim.x) + + (threadIdx.z * blockDim.y * blockDim.x + + threadIdx.y * blockDim.x + threadIdx.x); + if (global_idx == 0) { + auto frag_T = reinterpret_cast<__half *>(a_frag); + printf("DEBUG: Q Frag in permuted_smem for mma_q %lu \n", + mma_q); + for (auto i = 0ul; i < 4; ++i) { + printf("%f ", (float)(*(frag_T + i))); + } + printf("\n"); + } + + q_smem_offset_r = qo_smem.template advance_offset_by_column<4>( + q_smem_offset_r, 0); + } +#endif + #pragma unroll 1 for (uint32_t iter = 0; iter < num_iterations; ++iter) { memory::wait_group<1>(); diff --git a/libflashinfer/tests/hip/test_single_prefill.cpp b/libflashinfer/tests/hip/test_single_prefill.cpp index 96b1ea782c..44a3b1825e 100644 --- a/libflashinfer/tests/hip/test_single_prefill.cpp +++ b/libflashinfer/tests/hip/test_single_prefill.cpp @@ -77,6 +77,14 @@ void _TestComputeQKCorrectness(size_t qo_len, float *qk_scores_d; FI_GPU_CALL(hipMalloc(&qk_scores_d, qk_output_size * sizeof(float))); + std::cout << "Debug: Kernel 
launch parameters:" << std::endl; + std::cout << " qo_len=" << qo_len << ", kv_len=" << kv_len << std::endl; + std::cout << " num_qo_heads=" << num_qo_heads + << ", num_kv_heads=" << num_kv_heads << std::endl; + std::cout << " head_dim=" << head_dim << std::endl; + std::cout << " qk_output_size=" << qk_output_size << std::endl; + std::cout << " Launching ComputeQKStubCaller..." << std::endl; + // Call ComputeQKStubCaller instead of SinglePrefillWithKVCache hipError_t status = flashinfer::ComputeQKStubCaller( @@ -85,6 +93,8 @@ void _TestComputeQKCorrectness(size_t qo_len, num_qo_heads, num_kv_heads, qo_len, kv_len, head_dim, causal, kv_layout, pos_encoding_mode, use_fp16_qk_reduction); + std::cout << " Kernel launch status: " << hipGetErrorString(status) + << std::endl; EXPECT_EQ(status, hipSuccess) << "ComputeQKStubCaller kernel launch failed, error message: " << hipGetErrorString(status); @@ -534,59 +544,59 @@ void _TestSinglePrefillKernelCorrectness(size_t qo_len, // } // #endif -// int main(int argc, char **argv) -// { -// // ::testing::InitGoogleTest(&argc, argv); -// // return RUN_ALL_TESTS(); -// using DTypeIn = __half; -// using DTypeO = __half; -// bool use_fp16_qk_reduction = false; -// size_t qo_len = 399; -// size_t kv_len = 533; -// size_t num_heads = 1; -// size_t head_dim = 64; -// bool causal = false; -// size_t pos_encoding_mode = 0; -// size_t kv_layout = 0; - -// _TestSinglePrefillKernelCorrectness( -// qo_len, kv_len, num_heads, num_heads, head_dim, causal, -// QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode), -// use_fp16_qk_reduction); -// } - int main(int argc, char **argv) { - // Test compute_qk first with simple parameters - std::cout << "=== Testing compute_qk function ===" << std::endl; + // ::testing::InitGoogleTest(&argc, argv); + // return RUN_ALL_TESTS(); using DTypeIn = __half; using DTypeO = __half; bool use_fp16_qk_reduction = false; - bool causal = false; - size_t pos_encoding_mode = 0; - size_t kv_layout = 0; - - // Start 
with small dimensions for easier debugging - _TestComputeQKCorrectness( - 16, // qo_len - small for debugging - 32, // kv_len - 1, // num_qo_heads - single head - 1, // num_kv_heads - single head - 64, // head_dim - causal, QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode), - use_fp16_qk_reduction); - - std::cout << "\n=== Testing full single prefill ===" << std::endl; - // Your existing test... size_t qo_len = 399; size_t kv_len = 533; size_t num_heads = 1; size_t head_dim = 64; + bool causal = false; + size_t pos_encoding_mode = 0; + size_t kv_layout = 0; _TestSinglePrefillKernelCorrectness( qo_len, kv_len, num_heads, num_heads, head_dim, causal, QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode), use_fp16_qk_reduction); - - return 0; } + +// int main(int argc, char **argv) +// { +// // Test compute_qk first with simple parameters +// std::cout << "=== Testing compute_qk function ===" << std::endl; +// using DTypeIn = __half; +// using DTypeO = __half; +// bool use_fp16_qk_reduction = false; +// bool causal = false; +// size_t pos_encoding_mode = 0; +// size_t kv_layout = 0; + +// // Start with small dimensions for easier debugging +// _TestComputeQKCorrectness( +// 16, // qo_len - small for debugging +// 32, // kv_len +// 1, // num_qo_heads - single head +// 1, // num_kv_heads - single head +// 64, // head_dim +// causal, QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode), +// use_fp16_qk_reduction); + +// std::cout << "\n=== Testing full single prefill ===" << std::endl; +// // Your existing test... 
+// size_t qo_len = 399; +// size_t kv_len = 533; +// size_t num_heads = 1; +// size_t head_dim = 64; + +// _TestSinglePrefillKernelCorrectness( +// qo_len, kv_len, num_heads, num_heads, head_dim, causal, +// QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode), +// use_fp16_qk_reduction); + +// return 0; +// } diff --git a/libflashinfer/utils/cpu_reference_hip.h b/libflashinfer/utils/cpu_reference_hip.h index 7883907164..fadfefcb89 100644 --- a/libflashinfer/utils/cpu_reference_hip.h +++ b/libflashinfer/utils/cpu_reference_hip.h @@ -14,6 +14,7 @@ #include #include +#include namespace cpu_reference { @@ -162,6 +163,14 @@ single_mha(const std::vector &q, DISPATCH_head_dim(head_dim, HEAD_DIM, { tensor_info_t info(qo_len, kv_len, num_qo_heads, num_kv_heads, kv_layout, HEAD_DIM); +#if Debug + std::cout << "DEBUG Q (CPU): " << '\n'; + for (auto i = 0ul; i < 64; ++i) { + // q[info.get_q_elem_offset(q_idx, qo_head_idx, feat_idx) + std::cout << (float)q[info.get_q_elem_offset(0, 0, i)] << " "; + } + std::cout << std::endl; +#endif for (size_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) { const size_t kv_head_idx = qo_head_idx / info.get_group_size(); diff --git a/libflashinfer/utils/flashinfer_prefill_ops.hip.h b/libflashinfer/utils/flashinfer_prefill_ops.hip.h index f4d4fabc3c..38b58d4f80 100644 --- a/libflashinfer/utils/flashinfer_prefill_ops.hip.h +++ b/libflashinfer/utils/flashinfer_prefill_ops.hip.h @@ -22,18 +22,6 @@ namespace flashinfer { -// template -// hipError_t ComputeQKStubDispatched(Params params, -// typename Params::DTypeO *tmp, -// float *qk_scores_output, -// hipStream_t stream); - template Date: Fri, 22 Aug 2025 09:56:51 -0400 Subject: [PATCH 049/109] Updated produce_kv --- .../flashinfer/attention/generic/prefill.cuh | 232 ++++++++++++------ .../tests/hip/test_produce_kv_kernel.cpp | 0 libflashinfer/utils/cpu_reference_hip.h | 15 +- 3 files changed, 171 insertions(+), 76 deletions(-) create mode 100644 
libflashinfer/tests/hip/test_produce_kv_kernel.cpp diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index 598e76dbb3..bcbf7ba433 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -141,8 +141,8 @@ struct KernelTraits using SmemBasePtrTy = uint2; static constexpr uint32_t NUM_THREADS = NUM_WARPS_Q * NUM_WARPS_KV * 64; - static constexpr uint32_t WARP_THREAD_COLS = 16; static constexpr uint32_t WARP_THREAD_ROWS = 4; + static constexpr uint32_t WARP_THREAD_COLS = 16; static constexpr uint32_t HALF_ELEMS_PER_THREAD = 4; static constexpr uint32_t INT32_ELEMS_PER_THREAD = 2; static constexpr uint32_t VECTOR_BIT_WIDTH = HALF_ELEMS_PER_THREAD * 16; @@ -152,8 +152,8 @@ struct KernelTraits static constexpr SwizzleMode SWIZZLE_MODE_KV = SwizzleMode::kLinear; // Presently we use 16x4 thread layout for all cases. - static constexpr uint32_t KV_THR_LAYOUT_ROW = 16; - static constexpr uint32_t KV_THR_LAYOUT_COL = 4; + static constexpr uint32_t KV_THR_LAYOUT_ROW = WARP_THREAD_ROWS; + static constexpr uint32_t KV_THR_LAYOUT_COL = WARP_THREAD_COLS; // The constant is defined based on the matrix layout of the "D/C" // accumulator matrix in a D = A*B+C computation. On CDNA3 the D/C matrices // are distributed as four 4x16 bands across the 64 threads. 
Each thread @@ -170,8 +170,8 @@ struct KernelTraits #else using SmemBasePtrTy = uint4; static constexpr uint32_t NUM_THREADS = NUM_WARPS_Q * NUM_WARPS_KV * 32; - constexpr uint32_t WARP_THREAD_COLS = 8; constexpr uint32_t WARP_THREAD_ROWS = 4; + constexpr uint32_t WARP_THREAD_COLS = 8; constexpr uint32_t HALF_ELEMS_PER_THREAD = 8; constexpr uint32_t INT32_ELEMS_PER_THREAD = 4; constexpr uint32_t VECTOR_BIT_WIDTH = HALF_ELEMS_PER_THREAD * 16; @@ -391,7 +391,62 @@ q_frag_apply_llama_rope_with_pos(T *x_first_half, } template -__device__ __forceinline__ void produce_kv_helper_( +__device__ __forceinline__ void produce_kv_impl_cuda_( + uint32_t warp_idx, + uint32_t lane_idx, + smem_t smem, + uint32_t *smem_offset, + typename KTraits::DTypeKV **gptr, + const uint32_t stride_n, + const uint32_t kv_idx_base, + const uint32_t kv_len) +{ + if constexpr (KTraits::SWIZZLE_MODE_KV == SwizzleMode::k128B) { + uint32_t kv_idx = kv_idx_base + warp_idx * 4 + lane_idx / 8; + // NOTE: NUM_MMA_KV * 4 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 4 / + // num_warps + static_assert(NUM_MMA_KV * 4 % NUM_WARPS_Q == 0); +#pragma unroll + for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) { +#pragma unroll + for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(DTypeKV)); ++j) { + smem.load_128b_async(*smem_offset, *gptr, + kv_idx < kv_len); + *smem_offset = + smem.template advance_offset_by_column<8>(*smem_offset, j); + *gptr += 8 * upcast_size(); + } + kv_idx += NUM_WARPS * 4; + *smem_offset = smem.template advance_offset_by_row( + *smem_offset) - + sizeof(DTypeKV) * NUM_MMA_D; + *gptr += NUM_WARPS * 4 * stride_n - + sizeof(DTypeKV) * NUM_MMA_D * upcast_size(); + } + *smem_offset -= CTA_TILE_KV * UPCAST_STRIDE; + } + else { + uint32_t kv_idx = kv_idx_base + warp_idx * 8 + lane_idx / 4; + // NOTE: NUM_MMA_KV * 2 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 2 / + // num_warps + static_assert(NUM_MMA_KV * 2 % NUM_WARPS_Q == 0); +#pragma unroll + for (uint32_t i = 0; i < NUM_MMA_KV * 2 / 
NUM_WARPS_Q; ++i) { + smem.load_128b_async(*smem_offset, *gptr, + kv_idx < kv_len); + *smem_offset = smem.template advance_offset_by_row( + *smem_offset); + kv_idx += NUM_WARPS * 8; + *gptr += NUM_WARPS * 8 * stride_n; + } + *smem_offset -= KTraits::CTA_TILE_KV * UPCAST_STRIDE; + } +} + +template +__device__ __forceinline__ void produce_kv_impl_cdna3_( uint32_t warp_idx, uint32_t lane_idx, smem_t smem, @@ -401,9 +456,10 @@ __device__ __forceinline__ void produce_kv_helper_( const uint32_t kv_idx_base, const uint32_t kv_len) { + static_assert(KTraits::SWIZZLE_MODE_KV == SwizzleMode::kLinear); using DTypeKV = typename KTraits::DTypeKV; - constexpr uint32_t WARP_THREAD_COLS = KTraits::KV_THR_LAYOUT_COL; - constexpr uint32_t WARP_THREAD_ROWS = KTraits::KV_THR_LAYOUT_ROW; + constexpr uint32_t KV_THR_LAYOUT_COL = KTraits::KV_THR_LAYOUT_COL; // 16 + constexpr uint32_t KV_THR_LAYOUT_ROW = KTraits::KV_THR_LAYOUT_ROW; // 4 constexpr uint32_t NUM_WARPS = KTraits::NUM_WARPS; constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q; @@ -412,35 +468,70 @@ __device__ __forceinline__ void produce_kv_helper_( constexpr uint32_t UPCAST_STRIDE = produce_v ? 
KTraits::UPCAST_STRIDE_V : KTraits::UPCAST_STRIDE_K; constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH; + constexpr uint32_t HALF_ELEMS_PER_THREAD = + KTraits::HALF_ELEMS_PER_THREAD; // 4 + + // CDNA3-specific constants + constexpr uint32_t SEQUENCES_PER_MMA_TILE = 16; + constexpr uint32_t SEQUENCES_PER_THREAD_GROUP = KV_THR_LAYOUT_ROW; // 4 + constexpr uint32_t THREAD_GROUPS_PER_MMA_TILE = + SEQUENCES_PER_MMA_TILE / SEQUENCES_PER_THREAD_GROUP; // 4 + constexpr uint32_t FEATURE_CHUNKS_PER_THREAD_GROUP = + NUM_MMA_D / HALF_ELEMS_PER_THREAD; // NUM_MMA_D/4 + constexpr uint32_t COLUMN_RESET_OFFSET = + FEATURE_CHUNKS_PER_THREAD_GROUP * KV_THR_LAYOUT_COL; -#if defined(PLATFORM_HIP_DEVICE) - constexpr uint32_t COLUMN_RESET_OFFSET = (NUM_MMA_D / 4) * WARP_THREAD_COLS; -#else - constexpr uint32_t COLUMN_RESET_OFFSET = sizeof(DTypeKV) * NUM_MMA_D; -#endif + uint32_t row = lane_idx / KV_THR_LAYOUT_COL; + uint32_t kv_idx = kv_idx_base + warp_idx * KV_THR_LAYOUT_ROW + row; - uint32_t row = lane_idx / WARP_THREAD_COLS; - uint32_t kv_idx = kv_idx_base + warp_idx * WARP_THREAD_ROWS + row; // NOTE: NUM_MMA_KV*4/NUM_WARPS_Q = NUM_WARPS_KV*NUM_MMA_KV*4/num_warps static_assert(NUM_MMA_KV * 4 % NUM_WARPS_Q == 0); + +#pragma unroll + for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) + { // MMA tile iterations + + // CDNA3: Load complete 16×HEAD_DIM tile per i iteration #pragma unroll - for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) { + for (uint32_t k = 0; k < THREAD_GROUPS_PER_MMA_TILE; ++k) + { // 4 sequence groups #pragma unroll - for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(DTypeKV)); ++j) { - smem.template load_vector_async(*smem_offset, *gptr, - kv_idx < kv_len); - *smem_offset = - smem.template advance_offset_by_column( - *smem_offset, j); - *gptr += 8 * upcast_size(); + for (uint32_t j = 0; j < FEATURE_CHUNKS_PER_THREAD_GROUP; ++j) + { // Feature chunks + smem.template load_vector_async(*smem_offset, *gptr, + kv_idx < kv_len); + + // 
Advance to next feature chunk (same sequence group) + *smem_offset = + smem.template advance_offset_by_column( + *smem_offset, j); + *gptr += KV_THR_LAYOUT_COL * + upcast_size(); + } + + // Advance to next sequence group within same MMA tile + if (k < THREAD_GROUPS_PER_MMA_TILE - 1) + { // Don't advance after last group + kv_idx += NUM_WARPS * KV_THR_LAYOUT_ROW; + *smem_offset = + smem.template advance_offset_by_row< + NUM_WARPS * KV_THR_LAYOUT_ROW, UPCAST_STRIDE>( + *smem_offset) - + COLUMN_RESET_OFFSET; + *gptr += NUM_WARPS * KV_THR_LAYOUT_ROW * stride_n - + FEATURE_CHUNKS_PER_THREAD_GROUP * KV_THR_LAYOUT_COL * + upcast_size(); + } } - kv_idx += NUM_WARPS * WARP_THREAD_ROWS; + + // Final advance to next MMA tile + kv_idx += NUM_WARPS * KV_THR_LAYOUT_ROW; *smem_offset = - smem.template advance_offset_by_row(*smem_offset) - COLUMN_RESET_OFFSET; - *gptr += NUM_WARPS * WARP_THREAD_ROWS * stride_n - - sizeof(DTypeKV) * NUM_MMA_D * + *gptr += NUM_WARPS * KV_THR_LAYOUT_ROW * stride_n - + FEATURE_CHUNKS_PER_THREAD_GROUP * KV_THR_LAYOUT_COL * upcast_size(); } *smem_offset -= KTraits::CTA_TILE_KV * UPCAST_STRIDE; @@ -477,35 +568,15 @@ __device__ __forceinline__ void produce_kv( const uint32_t warp_idx = get_warp_idx(tid.y, tid.z), lane_idx = tid.x; - if constexpr (KTraits::SWIZZLE_MODE_KV == SwizzleMode::k128B) { - produce_kv_helper_( - warp_idx, lane_idx, smem, smem_offset, gptr, stride_n, kv_idx_base, - kv_len); - } #if defined(PLATFORM_HIP_DEVICE) - else if constexpr (KTraits::SWIZZLE_MODE_KV == SwizzleMode::kLinear) { - produce_kv_helper_( - warp_idx, lane_idx, smem, smem_offset, gptr, stride_n, kv_idx_base, - kv_len); - } + produce_kv_impl_cdna3_( + warp_idx, lane_idx, smem, smem_offset, gptr, stride_n, kv_idx_base, + kv_len); +#elif defined(PLATFORM_CUDA_DEVICE) + produce_kv_impl_cuda_( + warp_idx, lane_idx, smem, smem_offset, gptr, stride_n, kv_idx_base, + kv_len); #endif - else { - uint32_t kv_idx = kv_idx_base + warp_idx * 8 + lane_idx / 4; - // NOTE: NUM_MMA_KV * 
2 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 2 / - // num_warps - static_assert(NUM_MMA_KV * 2 % NUM_WARPS_Q == 0); -#pragma unroll - for (uint32_t i = 0; i < NUM_MMA_KV * 2 / NUM_WARPS_Q; ++i) { - smem.template load_vector_async(*smem_offset, *gptr, - kv_idx < kv_len); - *smem_offset = smem.template advance_offset_by_row( - *smem_offset); - kv_idx += NUM_WARPS * 8; - *gptr += NUM_WARPS * 8 * stride_n; - } - *smem_offset -= KTraits::CTA_TILE_KV * UPCAST_STRIDE; - } } template @@ -2164,26 +2235,43 @@ SinglePrefillWithKVCacheDevice(const Params params, memory::commit_group(); #if Debug - for (auto mma_q = 0ul; mma_q < 4; ++mma_q) { - uint32_t a_frag[KTraits::INT32_ELEMS_PER_THREAD]; - qo_smem.load_fragment(q_smem_offset_r, a_frag); - int global_idx = (blockIdx.z * gridDim.y * gridDim.x + - blockIdx.y * gridDim.x + blockIdx.x) * - (blockDim.z * blockDim.y * blockDim.x) + - (threadIdx.z * blockDim.y * blockDim.x + - threadIdx.y * blockDim.x + threadIdx.x); - if (global_idx == 0) { - auto frag_T = reinterpret_cast<__half *>(a_frag); - printf("DEBUG: Q Frag in permuted_smem for mma_q %lu \n", - mma_q); - for (auto i = 0ul; i < 4; ++i) { - printf("%f ", (float)(*(frag_T + i))); - } - printf("\n"); + int global_idx = (blockIdx.z * gridDim.y * gridDim.x + + blockIdx.y * gridDim.x + blockIdx.x) * + (blockDim.z * blockDim.y * blockDim.x) + + (threadIdx.z * blockDim.y * blockDim.x + + threadIdx.y * blockDim.x + threadIdx.x); + // for (auto mma_q = 0ul; mma_q < 4; ++mma_q) { + // uint32_t a_frag[KTraits::INT32_ELEMS_PER_THREAD]; + // qo_smem.load_fragment(q_smem_offset_r, a_frag); + // int global_idx = (blockIdx.z * gridDim.y * gridDim.x + + // blockIdx.y * gridDim.x + blockIdx.x) * + // (blockDim.z * blockDim.y * blockDim.x) + + // (threadIdx.z * blockDim.y * blockDim.x + + // threadIdx.y * blockDim.x + threadIdx.x); + // if (global_idx == 0) { + // auto frag_T = reinterpret_cast<__half *>(a_frag); + // printf("DEBUG: Q Frag in permuted_smem for mma_q %lu \n", + // mma_q); + 
// for (auto i = 0ul; i < 4; ++i) { + // printf("%f ", (float)(*(frag_T + i))); + // } + // printf("\n"); + // } + + // q_smem_offset_r = qo_smem.template advance_offset_by_column<4>( + // q_smem_offset_r, 0); + // } + uint32_t b_frag[KTraits::INT32_ELEMS_PER_THREAD]; + k_smem.load_fragment(k_smem_offset_r, b_frag); + + if (global_idx == 4) { + auto frag_T = reinterpret_cast<__half *>(b_frag); + // printf("DEBUG: K Frag in permuted_smem for mma_kv %lu \n", + // mma_kv); + for (auto i = 0ul; i < 4; ++i) { + printf("%f ", (float)(*(frag_T + i))); } - - q_smem_offset_r = qo_smem.template advance_offset_by_column<4>( - q_smem_offset_r, 0); + printf("\n"); } #endif diff --git a/libflashinfer/tests/hip/test_produce_kv_kernel.cpp b/libflashinfer/tests/hip/test_produce_kv_kernel.cpp new file mode 100644 index 0000000000..e69de29bb2 diff --git a/libflashinfer/utils/cpu_reference_hip.h b/libflashinfer/utils/cpu_reference_hip.h index fadfefcb89..a4d2f6b82a 100644 --- a/libflashinfer/utils/cpu_reference_hip.h +++ b/libflashinfer/utils/cpu_reference_hip.h @@ -164,10 +164,17 @@ single_mha(const std::vector &q, tensor_info_t info(qo_len, kv_len, num_qo_heads, num_kv_heads, kv_layout, HEAD_DIM); #if Debug - std::cout << "DEBUG Q (CPU): " << '\n'; - for (auto i = 0ul; i < 64; ++i) { - // q[info.get_q_elem_offset(q_idx, qo_head_idx, feat_idx) - std::cout << (float)q[info.get_q_elem_offset(0, 0, i)] << " "; + // std::cout << "DEBUG Q (CPU): " << '\n'; + // for (auto i = 0ul; i < 64; ++i) { + // // q[info.get_q_elem_offset(q_idx, qo_head_idx, feat_idx) + // std::cout << (float)q[info.get_q_elem_offset(0, 0, i)] << " "; + // } + // std::cout << std::endl; + + std::cout << "DEBUG K (CPU): " << '\n'; + for (auto i = 0ul; i < 4; ++i) { + // k[info.get_kv_elem_offset(kv_idx, kv_head_idx, feat_idx) + std::cout << (float)k[info.get_kv_elem_offset(4, 0, i)] << " "; } std::cout << std::endl; #endif From 1d60d18dc1499fd17a18662b08b6cec8e0f07c5b Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: 
Fri, 22 Aug 2025 10:51:00 -0400 Subject: [PATCH 050/109] Fix compiler warnings. --- .../flashinfer/attention/generic/prefill.cuh | 30 +++++++++++-------- 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index bcbf7ba433..c62a5a1367 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -401,6 +401,16 @@ __device__ __forceinline__ void produce_kv_impl_cuda_( const uint32_t kv_idx_base, const uint32_t kv_len) { + using DTypeKV = typename KTraits::DTypeKV; + constexpr uint32_t NUM_WARPS = KTraits::NUM_WARPS; + constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; + constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q; + constexpr uint32_t NUM_MMA_D = + produce_v ? KTraits::NUM_MMA_D_VO : KTraits::NUM_MMA_D_QK; + constexpr uint32_t UPCAST_STRIDE = + produce_v ? 
KTraits::UPCAST_STRIDE_V : KTraits::UPCAST_STRIDE_K; + constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH; + if constexpr (KTraits::SWIZZLE_MODE_KV == SwizzleMode::k128B) { uint32_t kv_idx = kv_idx_base + warp_idx * 4 + lane_idx / 8; // NOTE: NUM_MMA_KV * 4 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 4 / @@ -410,11 +420,11 @@ __device__ __forceinline__ void produce_kv_impl_cuda_( for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) { #pragma unroll for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(DTypeKV)); ++j) { - smem.load_128b_async(*smem_offset, *gptr, - kv_idx < kv_len); + smem.template load_128b_async(*smem_offset, *gptr, + kv_idx < kv_len); *smem_offset = smem.template advance_offset_by_column<8>(*smem_offset, j); - *gptr += 8 * upcast_size(); + *gptr += 8 * upcast_size(); } kv_idx += NUM_WARPS * 4; *smem_offset = smem.template advance_offset_by_row(); + sizeof(DTypeKV) * NUM_MMA_D * + upcast_size(); } - *smem_offset -= CTA_TILE_KV * UPCAST_STRIDE; + *smem_offset -= KTraits::CTA_TILE_KV * UPCAST_STRIDE; } else { uint32_t kv_idx = kv_idx_base + warp_idx * 8 + lane_idx / 4; @@ -433,8 +444,8 @@ __device__ __forceinline__ void produce_kv_impl_cuda_( static_assert(NUM_MMA_KV * 2 % NUM_WARPS_Q == 0); #pragma unroll for (uint32_t i = 0; i < NUM_MMA_KV * 2 / NUM_WARPS_Q; ++i) { - smem.load_128b_async(*smem_offset, *gptr, - kv_idx < kv_len); + smem.template load_128b_async(*smem_offset, *gptr, + kv_idx < kv_len); *smem_offset = smem.template advance_offset_by_row( *smem_offset); @@ -560,11 +571,6 @@ __device__ __forceinline__ void produce_kv( const dim3 tid = threadIdx) { // NOTE: for fp8, this function doesn't work for head_dim = 64 at the moment - constexpr uint32_t NUM_WARPS = KTraits::NUM_WARPS; - constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q; - constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; - constexpr uint32_t UPCAST_STRIDE = - produce_v ? 
KTraits::UPCAST_STRIDE_V : KTraits::UPCAST_STRIDE_K; const uint32_t warp_idx = get_warp_idx(tid.y, tid.z), lane_idx = tid.x; From ead9a210dd59a8772143f3bd05e13b1e7d230d4e Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Fri, 22 Aug 2025 15:00:55 -0400 Subject: [PATCH 051/109] Fix k_smem_offset_rcalc --- .../flashinfer/attention/generic/prefill.cuh | 47 ++++++++++++------- libflashinfer/utils/cpu_reference_hip.h | 11 +++-- 2 files changed, 39 insertions(+), 19 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index c62a5a1367..439972e24b 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -2212,13 +2212,21 @@ SinglePrefillWithKVCacheDevice(const Params params, (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(); +#if defined(PLATFORM_HIP_DEVICE) uint32_t k_smem_offset_r = - k_smem.template get_permuted_offset( - get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + - HALF_ELEMS_PER_THREAD * (lane_idx / 16) + - lane_idx % HALF_ELEMS_PER_THREAD, - (lane_idx % 16) / HALF_ELEMS_PER_THREAD), - v_smem_offset_r = + k_smem.template get_permuted_offset( + get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + + lane_idx % 16, + (lane_idx / 16)); +#elif defined(PLATFORM_CUDA_DEVICE) + uint32_t k_smem_offset_r = + k_smem.template get_permuted_offset( + get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + + 8 * (lane_idx / 16) + lane_idx % 8, + (lane_idx % 16) / 8); +#endif + + uint32_t v_smem_offset_r = v_smem.template get_permuted_offset( get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + lane_idx % 16, @@ -2267,17 +2275,24 @@ SinglePrefillWithKVCacheDevice(const Params params, // q_smem_offset_r = qo_smem.template advance_offset_by_column<4>( // q_smem_offset_r, 0); // } - uint32_t b_frag[KTraits::INT32_ELEMS_PER_THREAD]; - k_smem.load_fragment(k_smem_offset_r, b_frag); - - if (global_idx == 4) { - auto frag_T = 
reinterpret_cast<__half *>(b_frag); - // printf("DEBUG: K Frag in permuted_smem for mma_kv %lu \n", - // mma_kv); - for (auto i = 0ul; i < 4; ++i) { - printf("%f ", (float)(*(frag_T + i))); + if (global_idx == 0) { + + for (auto j = 0; j < 64; ++j) { + uint32_t k_smem_offset_r_test = + k_smem.template get_permuted_offset( + get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + + j % 16, + (j / 16)); + uint32_t b_frag[KTraits::INT32_ELEMS_PER_THREAD]; + k_smem.load_fragment(k_smem_offset_r_test, b_frag); + auto frag_T = reinterpret_cast<__half *>(b_frag); + // printf("DEBUG: K Frag in permuted_smem for mma_kv %lu \n", + // mma_kv); + for (auto i = 0ul; i < 4; ++i) { + printf("%f ", (float)(*(frag_T + i))); + } + printf("\n"); } - printf("\n"); } #endif diff --git a/libflashinfer/utils/cpu_reference_hip.h b/libflashinfer/utils/cpu_reference_hip.h index a4d2f6b82a..5a473408ca 100644 --- a/libflashinfer/utils/cpu_reference_hip.h +++ b/libflashinfer/utils/cpu_reference_hip.h @@ -172,9 +172,14 @@ single_mha(const std::vector &q, // std::cout << std::endl; std::cout << "DEBUG K (CPU): " << '\n'; - for (auto i = 0ul; i < 4; ++i) { - // k[info.get_kv_elem_offset(kv_idx, kv_head_idx, feat_idx) - std::cout << (float)k[info.get_kv_elem_offset(4, 0, i)] << " "; + for (auto j = 0ul; j < 16; ++j) { + for (auto i = 0ul; i < 64; ++i) { + // k[info.get_kv_elem_offset(kv_idx, kv_head_idx, feat_idx) + // std::cout << (float)k[info.get_kv_elem_offset(15, 0, j * 4 + + // i)] + std::cout << (float)k[info.get_kv_elem_offset(j, 0, i)] << " "; + } + std::cout << '\n'; } std::cout << std::endl; #endif From 89cfce4c680bf36acd3aa4df812f179bb2b628ca Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Tue, 26 Aug 2025 07:34:24 -0400 Subject: [PATCH 052/109] Fix init_rope_freq. 
--- .../flashinfer/attention/generic/prefill.cuh | 257 +++++++++++++++--- 1 file changed, 213 insertions(+), 44 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index 439972e24b..1ec2d7f106 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -343,13 +343,19 @@ q_frag_apply_llama_rope(T *x_first_half, // --------- // 2 3 | 6 7 #if defined(PLATFORM_HIP_DEVICE) - uint32_t i = reg_id / 2, j = reg_id % 2; + // // Same sequence for all 4 features + // uint32_t i = 0; + // Direct mapping to frequency array + uint32_t freq_idx = reg_id; + // Same position for this thread's sequence + uint32_t position = qo_packed_offset; #else uint32_t i = ((reg_id % 4) / 2), j = (reg_id / 4); + uint32_t freq_idx = 2 * j + reg_id % 2; + uint32_t position = qo_packed_offset + 8 * i; #endif - __sincosf(float((qo_packed_offset + 8 * i) / group_size) * - rope_freq[2 * j + reg_id % 2], - &sin, &cos); + __sincosf(float(position / group_size) * rope_freq[freq_idx], &sin, + &cos); tmp = x_first_half[reg_id]; x_first_half[reg_id] = (tmp * cos - (float)x_second_half[reg_id] * sin); x_second_half[reg_id] = @@ -656,6 +662,41 @@ __device__ __forceinline__ void page_produce_kv( } } +__device__ __forceinline__ uint32_t get_feature_index(uint32_t j) +{ +#if defined(PLATFORM_HIP_DEVICE) + // CDNA3 A-matrix MMA tile to thread mapping for a 64-thread wavefront: + // Each group of 16 threads handles the same four consecutive features for + // different sequences: + // T0-T15: Features [0,1,2,3] for sequences 0-15 respectively + // T16-T31: Features [4,5,6,7] for sequences 0-15 respectively + // T32-T47: Features [8,9,10,11] for sequences 0-15 respectively + // T48-T63: Features [12,13,14,15] for sequences 0-15 respectively + // + uint32_t feature_index = (mma_d * 16 + (lane_idx / 4) + j) % (HEAD_DIM / 2); +#elif 
defined(PLATFORM_CUDA_DEVICE) + // CUDA A-matrix MMA tile to thread mapping for a 32 thread warp: + // Each group of four consecutive threads map four different features for + // the same sequence. + // T0: {0,1,8,9}, T1: {2,3,10,11}, T2: {4,5,12,13}, T3: {6,7,14,15} + // + // The pattern repeats across 8 rows with each row mapped to a set of four + // consecutive threads. + // row 0 --> T0, T1, T2, T3 + // row 1 --> T4, T5, T6, T7 + // ... + // row 7 --> T28, T29, T30, T31. + // The full data to thread mapping repeats again for the next set of 16 + // rows. Thereby, forming a 16x16 MMA tile dubdivided into four 8x8 + // quadrants. + uint32_t feature_index = + ((mma_d * 16 + (j / 2) * 8 + (lane_idx % 4) * 2 + (j % 2)) % + (HEAD_DIM / 2)); +#endif + + return feature_index; +} + template __device__ __forceinline__ void init_rope_freq(float (*rope_freq)[4], @@ -666,9 +707,6 @@ init_rope_freq(float (*rope_freq)[4], constexpr uint32_t HEAD_DIM = KTraits::NUM_MMA_D_QK * 16; const uint32_t lane_idx = tid_x; - constexpr uint32_t THREADS_PER_ROW = KTraits::THREADS_PER_MATRIX_ROW_SET; - constexpr uint32_t ELEMS_PER_THREAD = 8 / THREADS_PER_ROW; - #pragma unroll for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO / 2; ++mma_d) { #pragma unroll @@ -676,12 +714,7 @@ init_rope_freq(float (*rope_freq)[4], rope_freq[mma_d][j] = rope_rcp_scale * __powf(rope_rcp_theta, - float(2 * - ((mma_d * 16 + (j / 2) * 8 + - (lane_idx % THREADS_PER_ROW) * ELEMS_PER_THREAD + - (j % 2)) % - (HEAD_DIM / 2))) / - float(HEAD_DIM)); + float(2 * get_feature_index(j)) / float(HEAD_DIM)); } } } @@ -807,19 +840,26 @@ __device__ __forceinline__ void q_smem_inplace_apply_rotary( uint32_t q_frag_local[2][KTraits::INT32_ELEMS_PER_THREAD]; static_assert(KTraits::NUM_MMA_D_QK % 4 == 0, "NUM_MMA_D_QK must be a multiple of 4"); +#if defined(PLATFORM_HIP_DEVICE) + constexpr uint32_t LAST_HALF_OFFSET = KTraits::NUM_MMA_D_QK * 2; + constexpr uint32_t FIRST_HALF_OFFSET = KTraits::NUM_MMA_D_QK; + const uint32_t 
SEQ_ID = lane_idx % 16; +#elif defined(PLATFORM_CUDA_DEVICE) + constexpr uint32_t LAST_HALF_OFFSET = KTraits::NUM_MMA_D_QK; + constexpr uint32_t FIRST_HALF_OFFSET = KTraits::NUM_MMA_D_QK / 2; + const uint32_t SEQ_ID = lane_idx / KTraits::THREADS_PER_MATRIX_ROW_SET; +#endif #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { uint32_t q_smem_offset_r_first_half = *q_smem_offset_r; #pragma unroll - for (uint32_t mma_di = 0; mma_di < KTraits::NUM_MMA_D_QK / 2; - ++mma_di) - { + for (uint32_t mma_di = 0; mma_di < FIRST_HALF_OFFSET; ++mma_di) { q_smem->template load_fragment(q_smem_offset_r_first_half, q_frag_local[0]); uint32_t q_smem_offset_r_last_half = - q_smem->template advance_offset_by_column< - KTraits::NUM_MMA_D_QK>(q_smem_offset_r_first_half, 0); + q_smem->template advance_offset_by_column( + q_smem_offset_r_first_half, 0); q_smem->template load_fragment(q_smem_offset_r_last_half, q_frag_local[1]); q_frag_apply_llama_ropetemplate store_fragment(q_smem_offset_r_last_half, q_frag_local[1]); q_smem->template store_fragment(q_smem_offset_r_first_half, q_frag_local[0]); q_smem_offset_r_first_half = - q_smem->template advance_offset_by_column<2>( - q_smem_offset_r_first_half, mma_di); + q_smem + ->template advance_offset_by_column( + q_smem_offset_r_first_half, mma_di); } *q_smem_offset_r += 16 * UPCAST_STRIDE_Q; } @@ -2254,19 +2294,45 @@ SinglePrefillWithKVCacheDevice(const Params params, (blockDim.z * blockDim.y * blockDim.x) + (threadIdx.z * blockDim.y * blockDim.x + threadIdx.y * blockDim.x + threadIdx.x); + + if (global_idx == 0) { + printf("partition_kv : %d\n", partition_kv); + printf("kv_len : %d\n", kv_len); + printf("max_chunk_size : %d\n", max_chunk_size); + printf("chunk_end : %d\n", chunk_end); + printf("chunk_start : %d\n", chunk_start); + } + // Test Q + // if (global_idx == 0) { + // uint32_t q_smem_offset_r_debug; + // //for (auto i = 0; i < 4; ++i) { + // for (auto j = 0; j < 16; ++j) { + // uint32_t q_smem_offset_r_debug 
= + // qo_smem.template + // get_permuted_offset( + // get_warp_idx_q(tid.y) * NUM_MMA_Q * 16 + // + (j) % 16, (j) / 16); + // uint32_t a_frag[KTraits::INT32_ELEMS_PER_THREAD]; + // k_smem.load_fragment(q_smem_offset_r_debug, a_frag); + // auto frag_T = reinterpret_cast<__half *>(a_frag); + // for (auto i = 0ul; i < 4; ++i) { + // printf("%f ", (float)(*(frag_T + i))); + // } + // printf("\n"); + // } + // // q_smem_offset_r_debug = qo_smem.template + // advance_offset_by_column<4>( + // // q_smem_offset_r_debug, 0); + // // } + // } + // for (auto mma_q = 0ul; mma_q < 4; ++mma_q) { // uint32_t a_frag[KTraits::INT32_ELEMS_PER_THREAD]; // qo_smem.load_fragment(q_smem_offset_r, a_frag); - // int global_idx = (blockIdx.z * gridDim.y * gridDim.x + - // blockIdx.y * gridDim.x + blockIdx.x) * - // (blockDim.z * blockDim.y * blockDim.x) + - // (threadIdx.z * blockDim.y * blockDim.x + - // threadIdx.y * blockDim.x + threadIdx.x); // if (global_idx == 0) { // auto frag_T = reinterpret_cast<__half *>(a_frag); // printf("DEBUG: Q Frag in permuted_smem for mma_q %lu \n", - // mma_q); - // for (auto i = 0ul; i < 4; ++i) { + // mma_q); for (auto i = 0ul; i < 4; ++i) { // printf("%f ", (float)(*(frag_T + i))); // } // printf("\n"); @@ -2275,25 +2341,128 @@ SinglePrefillWithKVCacheDevice(const Params params, // q_smem_offset_r = qo_smem.template advance_offset_by_column<4>( // q_smem_offset_r, 0); // } + + uint32_t a_frag[KTraits::INT32_ELEMS_PER_THREAD]; + qo_smem.load_fragment(q_smem_offset_r, a_frag); if (global_idx == 0) { + auto frag_T = reinterpret_cast<__half *>(a_frag); + printf("DEBUG: Q Frag \n"); + for (auto i = 0ul; i < 4; ++i) { + printf("%f ", (float)(*(frag_T + i))); + } + printf("\n"); + } - for (auto j = 0; j < 64; ++j) { - uint32_t k_smem_offset_r_test = - k_smem.template get_permuted_offset( - get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + - j % 16, - (j / 16)); - uint32_t b_frag[KTraits::INT32_ELEMS_PER_THREAD]; - k_smem.load_fragment(k_smem_offset_r_test, b_frag); 
- auto frag_T = reinterpret_cast<__half *>(b_frag); - // printf("DEBUG: K Frag in permuted_smem for mma_kv %lu \n", - // mma_kv); - for (auto i = 0ul; i < 4; ++i) { - printf("%f ", (float)(*(frag_T + i))); - } - printf("\n"); + memory::wait_group<0>(); + block.sync(); + q_smem_inplace_apply_rotary(qo_packed_idx_base, qo_len, kv_len, + group_size, &qo_smem, + &q_smem_offset_r, rope_freq, tid); + block.sync(); + + qo_smem.load_fragment(q_smem_offset_r, a_frag); + if (global_idx == 0) { + auto frag_T = reinterpret_cast<__half *>(a_frag); + printf("DEBUG: LLAMA Rope transformed Q Frag \n"); + for (auto i = 0ul; i < 4; ++i) { + printf("%f ", (float)(*(frag_T + i))); } + printf("\n"); } + + // // Test K loads + // if (global_idx == 0) { + + // for (auto j = 0; j < 64; ++j) { + // uint32_t k_smem_offset_r_test = + // k_smem.template get_permuted_offset( + // get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + + // j % 16, + // (j / 16)); + // uint32_t b_frag[KTraits::INT32_ELEMS_PER_THREAD]; + // k_smem.load_fragment(k_smem_offset_r_test, b_frag); + // auto frag_T = reinterpret_cast<__half *>(b_frag); + // // printf("DEBUG: K Frag in permuted_smem for mma_kv %lu \n", + // // mma_kv); + // for (auto i = 0ul; i < 4; ++i) { + // printf("%f ", (float)(*(frag_T + i))); + // } + // printf("\n"); + // } + // } + + // if (global_idx == 0) { + // printf("DEBUG Q ORIGINAL (HIP):\n"); + + // for (uint32_t seq_idx = 0; seq_idx < 16; ++seq_idx) { + // printf("Q[%u] original: ", seq_idx); + + // // Load all feature groups for this sequence + // for (uint32_t feat_group = 0; feat_group < NUM_MMA_D_QK; + // ++feat_group) { + // uint32_t feat_offset = qo_smem.template + // get_permuted_offset( + // seq_idx, feat_group * HALF_ELEMS_PER_THREAD); + + // uint32_t q_frag[KTraits::INT32_ELEMS_PER_THREAD]; + // qo_smem.load_fragment(feat_offset, q_frag); + // auto frag_T = reinterpret_cast<__half *>(q_frag); + + // // Print 4 features from this group + // for (auto feat = 0ul; feat < 
HALF_ELEMS_PER_THREAD; + // ++feat) { + // printf("%f ", (float)(*(frag_T + feat))); + // } + // } + // printf("\n"); + // } + // } + + // memory::wait_group<0>(); + // block.sync(); + // q_smem_inplace_apply_rotary( + // qo_packed_idx_base, qo_len, kv_len, group_size, &qo_smem, + // &q_smem_offset_r, rope_freq, tid); + // block.sync(); + + // // Debug: Print Q fragments after RoPE + // if (global_idx == 0) { + // printf("DEBUG Q LLAMA ROPE (HIP):\n"); + + // // Reset q_smem_offset_r to start + // uint32_t q_smem_offset_r_debug = + // qo_smem.template get_permuted_offset( + // get_warp_idx_q(tid.y) * NUM_MMA_Q * 16 + + // lane_idx % 16, lane_idx / 16); + + // for (uint32_t seq_idx = 0; seq_idx < 16; ++seq_idx) { + // // Calculate offset for this sequence + // uint32_t seq_offset = qo_smem.template + // get_permuted_offset( + // seq_idx, 0); + + // printf("Q[%u] after RoPE: ", seq_idx); + + // // Load all feature groups for this sequence + // for (uint32_t feat_group = 0; feat_group < NUM_MMA_D_QK; + // ++feat_group) { + // uint32_t feat_offset = qo_smem.template + // get_permuted_offset( + // seq_idx, feat_group * HALF_ELEMS_PER_THREAD); + + // uint32_t q_frag[KTraits::INT32_ELEMS_PER_THREAD]; + // qo_smem.load_fragment(feat_offset, q_frag); + // auto frag_T = reinterpret_cast<__half *>(q_frag); + + // // Print 4 features from this group + // for (auto feat = 0ul; feat < HALF_ELEMS_PER_THREAD; + // ++feat) { + // printf("%f ", (float)(*(frag_T + feat))); + // } + // } + // printf("\n"); + // } + // } #endif #pragma unroll 1 From 712b8ed57aa9f211b498b3cea03b27fa9e5d8218 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Wed, 27 Aug 2025 06:40:26 -0400 Subject: [PATCH 053/109] Debug --- .../flashinfer/attention/generic/prefill.cuh | 135 ++++++---- .../tests/hip/test_apply_llama_rope.cpp | 230 ++++++++++++++++++ .../tests/hip/test_load_q_global_smem.cpp | 218 +++-------------- .../tests/hip/test_load_q_global_smem_v1.cpp | 193 +++++++++++++++ ...nel.cpp => 
test_load_q_global_smem_v2.cpp} | 0 .../tests/hip/test_single_prefill.cpp | 2 +- libflashinfer/utils/cpu_reference_hip.h | 100 +++++++- 7 files changed, 627 insertions(+), 251 deletions(-) create mode 100644 libflashinfer/tests/hip/test_apply_llama_rope.cpp create mode 100644 libflashinfer/tests/hip/test_load_q_global_smem_v1.cpp rename libflashinfer/tests/hip/{test_load_q_global_smem_kernel.cpp => test_load_q_global_smem_v2.cpp} (100%) diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index 1ec2d7f106..a8226d9da2 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -336,26 +336,47 @@ q_frag_apply_llama_rope(T *x_first_half, const uint32_t qo_packed_offset, const uint_fastdiv group_size) { + +#if Debug + int global_idx = (blockIdx.z * gridDim.y * gridDim.x + + blockIdx.y * gridDim.x + blockIdx.x) * + (blockDim.z * blockDim.y * blockDim.x) + + (threadIdx.z * blockDim.y * blockDim.x + + threadIdx.y * blockDim.x + threadIdx.x); + + if (global_idx == 0) { + printf("=== Q_FRAG_APPLY_LLAMA_ROPE DEBUG ===\n"); + printf("qo_packed_offset=%u, group_size=%u, HALF_ELEMS_PER_THREAD=%u\n", + qo_packed_offset, (uint32_t)group_size, HALF_ELEMS_PER_THREAD); + printf("Input frequencies: %f %f %f %f\n", rope_freq[0], rope_freq[1], + rope_freq[2], rope_freq[3]); + } +#endif + #pragma unroll for (uint32_t reg_id = 0; reg_id < HALF_ELEMS_PER_THREAD; ++reg_id) { float cos, sin, tmp; - // 0 1 | 4 5 - // --------- - // 2 3 | 6 7 #if defined(PLATFORM_HIP_DEVICE) - // // Same sequence for all 4 features - // uint32_t i = 0; - // Direct mapping to frequency array uint32_t freq_idx = reg_id; - // Same position for this thread's sequence uint32_t position = qo_packed_offset; #else + // 0 1 | 4 5 + // --------- + // 2 3 | 6 7 uint32_t i = ((reg_id % 4) / 2), j = (reg_id / 4); uint32_t freq_idx = 2 * j + reg_id % 
2; uint32_t position = qo_packed_offset + 8 * i; #endif __sincosf(float(position / group_size) * rope_freq[freq_idx], &sin, &cos); +#if Debug + if (global_idx == 0) { + printf("reg_id=%u: freq_idx=%u, position=%u, angle=%f\n", reg_id, + freq_idx, position, + float(position / group_size) * rope_freq[freq_idx]); + } +#endif + tmp = x_first_half[reg_id]; x_first_half[reg_id] = (tmp * cos - (float)x_second_half[reg_id] * sin); x_second_half[reg_id] = @@ -662,7 +683,10 @@ __device__ __forceinline__ void page_produce_kv( } } -__device__ __forceinline__ uint32_t get_feature_index(uint32_t j) +template +__device__ __forceinline__ uint32_t get_feature_index(uint32_t mma_d, + uint32_t lane_idx, + uint32_t j) { #if defined(PLATFORM_HIP_DEVICE) // CDNA3 A-matrix MMA tile to thread mapping for a 64-thread wavefront: @@ -693,7 +717,6 @@ __device__ __forceinline__ uint32_t get_feature_index(uint32_t j) ((mma_d * 16 + (j / 2) * 8 + (lane_idx % 4) * 2 + (j % 2)) % (HEAD_DIM / 2)); #endif - return feature_index; } @@ -711,10 +734,11 @@ init_rope_freq(float (*rope_freq)[4], for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO / 2; ++mma_d) { #pragma unroll for (uint32_t j = 0; j < 4; ++j) { + uint32_t feature_index = + get_feature_index(mma_d, lane_idx, j); + float freq_base = float(2 * feature_index) / float(HEAD_DIM); rope_freq[mma_d][j] = - rope_rcp_scale * - __powf(rope_rcp_theta, - float(2 * get_feature_index(j)) / float(HEAD_DIM)); + rope_rcp_scale * __powf(rope_rcp_theta, freq_base); } } } @@ -834,55 +858,60 @@ __device__ __forceinline__ void q_smem_inplace_apply_rotary( float (*rope_freq)[4], const dim3 tid = threadIdx) { - if (get_warp_idx_kv(tid.z) == 0) { - constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q; - const uint32_t lane_idx = tid.x; - uint32_t q_frag_local[2][KTraits::INT32_ELEMS_PER_THREAD]; - static_assert(KTraits::NUM_MMA_D_QK % 4 == 0, - "NUM_MMA_D_QK must be a multiple of 4"); + if (get_warp_idx_kv(tid.z) != 0) + return; + + constexpr uint32_t 
UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q; + constexpr uint32_t COL_ADVANCE_TO_NEXT = + 16 / KTraits::HALF_ELEMS_PER_THREAD; + #if defined(PLATFORM_HIP_DEVICE) - constexpr uint32_t LAST_HALF_OFFSET = KTraits::NUM_MMA_D_QK * 2; - constexpr uint32_t FIRST_HALF_OFFSET = KTraits::NUM_MMA_D_QK; - const uint32_t SEQ_ID = lane_idx % 16; + constexpr uint32_t COL_ADVANCE_TO_LAST_HALF = KTraits::NUM_MMA_D_QK * 2; #elif defined(PLATFORM_CUDA_DEVICE) - constexpr uint32_t LAST_HALF_OFFSET = KTraits::NUM_MMA_D_QK; - constexpr uint32_t FIRST_HALF_OFFSET = KTraits::NUM_MMA_D_QK / 2; - const uint32_t SEQ_ID = lane_idx / KTraits::THREADS_PER_MATRIX_ROW_SET; + constexpr uint32_t COL_ADVANCE_TO_LAST_HALF = KTraits::NUM_MMA_D_QK; #endif + const uint32_t lane_idx = tid.x; + uint32_t q_frag_local[2][KTraits::INT32_ELEMS_PER_THREAD]; + static_assert(KTraits::NUM_MMA_D_QK % 4 == 0, + "NUM_MMA_D_QK must be a multiple of 4"); #pragma unroll - for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { - uint32_t q_smem_offset_r_first_half = *q_smem_offset_r; + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { + uint32_t q_smem_offset_r_first_half = *q_smem_offset_r; +#if defined(PLATFORM_HIP_DEVICE) + const uint32_t seq_id = q_packed_idx + kv_len * group_size - + qo_len * group_size + mma_q * 16 + + lane_idx % 16; +#elif defined(PLATFORM_CUDA_DEVICE) + const uint32_t seq_id = q_packed_idx + kv_len * group_size - + qo_len * group_size + mma_q * 16 + lane_idx / 4; +#endif #pragma unroll - for (uint32_t mma_di = 0; mma_di < FIRST_HALF_OFFSET; ++mma_di) { - q_smem->template load_fragment(q_smem_offset_r_first_half, - q_frag_local[0]); - uint32_t q_smem_offset_r_last_half = - q_smem->template advance_offset_by_column( - q_smem_offset_r_first_half, 0); - q_smem->template load_fragment(q_smem_offset_r_last_half, - q_frag_local[1]); - q_frag_apply_llama_rope( - (typename KTraits::DTypeQ *)q_frag_local[0], - (typename KTraits::DTypeQ *)q_frag_local[1], - rope_freq[mma_di], - 
q_packed_idx + kv_len * group_size - qo_len * group_size + - mma_q * 16 + SEQ_ID, - group_size); - q_smem->template store_fragment(q_smem_offset_r_last_half, - q_frag_local[1]); - q_smem->template store_fragment(q_smem_offset_r_first_half, - q_frag_local[0]); - q_smem_offset_r_first_half = - q_smem - ->template advance_offset_by_column( - q_smem_offset_r_first_half, mma_di); - } - *q_smem_offset_r += 16 * UPCAST_STRIDE_Q; + for (uint32_t mma_di = 0; mma_di < KTraits::NUM_MMA_D_QK / 2; ++mma_di) + { + q_smem->template load_fragment(q_smem_offset_r_first_half, + q_frag_local[0]); + uint32_t q_smem_offset_r_last_half = + q_smem->template advance_offset_by_column< + COL_ADVANCE_TO_LAST_HALF>(q_smem_offset_r_first_half, 0); + q_smem->template load_fragment(q_smem_offset_r_last_half, + q_frag_local[1]); + q_frag_apply_llama_rope( + (typename KTraits::DTypeQ *)q_frag_local[0], + (typename KTraits::DTypeQ *)q_frag_local[1], rope_freq[mma_di], + seq_id, group_size); + q_smem->template store_fragment(q_smem_offset_r_last_half, + q_frag_local[1]); + q_smem->template store_fragment(q_smem_offset_r_first_half, + q_frag_local[0]); + q_smem_offset_r_first_half = + q_smem->template advance_offset_by_column( + q_smem_offset_r_first_half, mma_di); } - *q_smem_offset_r -= KTraits::NUM_MMA_Q * 16 * UPCAST_STRIDE_Q; + *q_smem_offset_r += 16 * UPCAST_STRIDE_Q; } + *q_smem_offset_r -= KTraits::NUM_MMA_Q * 16 * UPCAST_STRIDE_Q; } template diff --git a/libflashinfer/tests/hip/test_apply_llama_rope.cpp b/libflashinfer/tests/hip/test_apply_llama_rope.cpp new file mode 100644 index 0000000000..f6cc629c39 --- /dev/null +++ b/libflashinfer/tests/hip/test_apply_llama_rope.cpp @@ -0,0 +1,230 @@ +// SPDX - FileCopyrightText : 2025 Advanced Micro Devices, Inc. 
+// +// SPDX - License - Identifier : Apache 2.0 + +#include "../../utils/cpu_reference_hip.h" +#include "../../utils/utils_hip.h" +#include "flashinfer/attention/generic/prefill.cuh" +#include "gpu_iface/gpu_runtime_compat.hpp" +#include +#include +#include + +namespace +{ +using QParamType = std::tuple; + +template struct TestKernelTraits +{ + static constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM / 16; + static constexpr uint32_t NUM_MMA_D_VO = HEAD_DIM / 16; +}; + +template +__global__ void test_init_rope_freq_kernel(float *output_freq, + float rope_rcp_scale, + float rope_rcp_theta) +{ + using KTraits = TestKernelTraits; + + // Allocate local frequency array + float rope_freq[KTraits::NUM_MMA_D_VO / 2][4]; // [2][4] for HEAD_DIM=64 + + // Call the init_rope_freq function from prefill.cuh + flashinfer::init_rope_freq(rope_freq, rope_rcp_scale, + rope_rcp_theta, threadIdx.x); + + // Write frequencies to their correct feature indices + const uint32_t lane_idx = threadIdx.x; + if (lane_idx < 64) { // Only write for valid threads + for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO / 2; ++mma_d) { + for (uint32_t j = 0; j < 4; ++j) { + // Calculate the actual feature index this frequency corresponds + // to + uint32_t feature_idx = + flashinfer::get_feature_index(mma_d, lane_idx, j); + + // Write frequency to the correct feature index in global array + if (feature_idx < HEAD_DIM) { + output_freq[feature_idx] = rope_freq[mma_d][j]; + if (feature_idx + HEAD_DIM / 2 < HEAD_DIM) { + output_freq[feature_idx + HEAD_DIM / 2] = + rope_freq[mma_d][j]; + } + } + } + } + } +} + +template +class LLamaRopeTestFixture : public ::testing::TestWithParam +{ +protected: + uint32_t qo_len, num_qo_heads, head_dim; + std::vector q; + + LLamaRopeTestFixture() + { + const auto ¶ms = GetParam(); + qo_len = std::get<0>(params); + num_qo_heads = std::get<1>(params); + head_dim = std::get<2>(params); + q.resize(qo_len * num_qo_heads * head_dim); + } + + void SetUp() override { 
utils::vec_normal_(q); } + + void TearDown() override {} + + std::vector apply_cpu_rope(size_t offset, + float rope_scale = 1.0f, + float rope_theta = 10000.0f) + { + return cpu_reference::apply_llama_rope(q.data(), head_dim, offset, + rope_scale, rope_theta); + } + + std::vector get_cpu_rope_frequencies(float rope_scale = 1.0f, + float rope_theta = 10000.0f) + { + std::vector frequencies(head_dim); + + for (size_t k = 0; k < head_dim; ++k) { + // Extract ONLY the frequency calculation (without position/offset) + float freq_base = float(2 * (k % (head_dim / 2))) / float(head_dim); + float frequency = + (1.0f / rope_scale) / std::pow(rope_theta, freq_base); + frequencies[k] = frequency; + } + + return frequencies; + } + + std::vector get_gpu_rope_frequencies(float rope_scale = 1.0f, + float rope_theta = 10000.0f) + { + // Convert to reciprocal values as expected by GPU kernel + float rope_rcp_scale = 1.0f / rope_scale; + float rope_rcp_theta = 1.0f / rope_theta; + + // Allocate GPU memory for output (one frequency per feature) + float *d_output_freq; + size_t output_size = head_dim * sizeof(float); + FI_GPU_CALL(hipMalloc(&d_output_freq, output_size)); + FI_GPU_CALL(hipMemset(d_output_freq, 0, output_size)); + + // Launch kernel with 64 threads + dim3 grid(1); + dim3 block(64); + + if (head_dim == 64) { + test_init_rope_freq_kernel<64><<>>( + d_output_freq, rope_rcp_scale, rope_rcp_theta); + } + + FI_GPU_CALL(hipDeviceSynchronize()); + + // Copy all frequencies back + std::vector gpu_frequencies(head_dim); + FI_GPU_CALL(hipMemcpy(gpu_frequencies.data(), d_output_freq, + output_size, hipMemcpyDeviceToHost)); + + FI_GPU_CALL(hipFree(d_output_freq)); + return gpu_frequencies; + } + + std::vector> + apply_cpu_rope_all_sequences(size_t kv_len = 1000, + float rope_scale = 1.0f, + float rope_theta = 10000.0f) + { + std::vector> results; + + DISPATCH_head_dim(head_dim, HEAD_DIM, { + using namespace flashinfer; + tensor_info_t info(qo_len, kv_len, num_qo_heads, num_qo_heads, 
+ QKVLayout::kHND, HEAD_DIM); + + // Apply RoPE to all sequences and heads + for (size_t qo_head_idx = 0; qo_head_idx < num_qo_heads; + ++qo_head_idx) + { + for (size_t q_idx = 0; q_idx < qo_len; ++q_idx) { + size_t offset = q_idx + kv_len - qo_len; + + // Apply RoPE to this specific Q sequence/head + auto q_rotary_local = cpu_reference::apply_llama_rope_debug( + q.data() + + info.get_q_elem_offset(q_idx, qo_head_idx, 0), + head_dim, offset, rope_scale, rope_theta); + + results.push_back(std::move(q_rotary_local)); + } + } + }); + + return results; + } +}; + +using LLamaRopeTestWithFP16 = LLamaRopeTestFixture<__half>; +} // namespace + +// Wrapper to validate freq application +// call q_smem_inplace_apply_rotary and copy back results to CPU. + +// Test 1. Copy CPU Q matrix to GPU call freq init validator +// launch kernel + +// Test 2. Copy CPU Q matrix to GPU call freq apply validator +// launch kernel + +TEST_P(LLamaRopeTestWithFP16, TestInitRopeFreq) +{ + constexpr float RELATIVE_EPSILON = 1e-6f; + size_t num_mismatches = 0; + auto cpu_frequencies = this->get_cpu_rope_frequencies(); + auto gpu_frequencies = this->get_gpu_rope_frequencies(); + + // Print side-by-side comparison for easier visual inspection + std::cout << "\nSide-by-side comparison:\n"; + std::cout << "Index\tCPU\t\tGPU\t\tDifference\n"; + std::cout << "-----\t---\t\t---\t\t----------\n"; + + for (size_t i = 0; i < std::min(16u, this->head_dim); ++i) { + float diff = std::abs(cpu_frequencies[i] - gpu_frequencies[i]); + std::cout << i << "\t" << cpu_frequencies[i] << "\t\t" + << gpu_frequencies[i] << "\t\t" << diff << std::endl; + } + + ASSERT_EQ(cpu_frequencies.size(), this->head_dim); + ASSERT_EQ(gpu_frequencies.size(), this->head_dim); + + for (auto i = 0ul; i < cpu_frequencies.size(); ++i) { + auto diff = std::abs(cpu_frequencies[i] - gpu_frequencies[i]); + if (diff >= RELATIVE_EPSILON) { + std::cout << "Diff : " << diff << " at feature index " << i << " " + << "cpu_frequencies[i]: " << 
cpu_frequencies[i] << " " + << "gpu_frequencies[i]: " << gpu_frequencies[i] << '\n'; + ++num_mismatches; + } + } + + ASSERT_EQ(num_mismatches, 0); +} + +TEST_P(LLamaRopeTestWithFP16, VectorSizeIsCorrect) +{ + const auto ¶ms = GetParam(); + size_t expected_size = + std::get<0>(params) * std::get<1>(params) * std::get<2>(params); + ASSERT_EQ(this->q.size(), expected_size); +} + +INSTANTIATE_TEST_SUITE_P( + LLamaRopeTestWithFP16, + LLamaRopeTestWithFP16, + ::testing::Values( + std::make_tuple(256, 1, 64) // qo_len=256, num_qo_heads=1, head_dim=64 + )); diff --git a/libflashinfer/tests/hip/test_load_q_global_smem.cpp b/libflashinfer/tests/hip/test_load_q_global_smem.cpp index c9d8c3840a..bedaf3afb7 100644 --- a/libflashinfer/tests/hip/test_load_q_global_smem.cpp +++ b/libflashinfer/tests/hip/test_load_q_global_smem.cpp @@ -1,193 +1,43 @@ -#include -#include -#include -#include +// SPDX - FileCopyrightText : 2025 Advanced Micro Devices, Inc. +// +// SPDX - License - Identifier : Apache 2.0 + +#include +#include +#include +#include +#include #include -// Constants for MI300 -constexpr uint32_t WARP_STEP_SIZE = 16; // 16 threads per warp row -constexpr uint32_t QUERY_ELEMS_PER_THREAD = - 4; // Each thread loads 4 fp16 elements -constexpr uint32_t WARP_THREAD_ROWS = 4; // 4 rows of threads in a warp +#include "flashinfer/attention/generic/default_prefill_params.cuh" +#include "flashinfer/attention/generic/prefill.cuh" +#include "flashinfer/attention/generic/variants.cuh" +#include "utils/cpu_reference_hip.h" +#include "utils/utils_hip.h" // vec_normal_ -// Simplified linear shared memory operations (CPU implementation) -template -uint32_t get_permuted_offset_linear(uint32_t row, uint32_t col) +namespace { - return row * stride + col; -} - -template -uint32_t advance_offset_by_column_linear(uint32_t offset, uint32_t step_idx) -{ - return offset + step_size; -} - -template -uint32_t advance_offset_by_row_linear(uint32_t offset) -{ - return offset + step_size * row_stride; -} 
- -// CPU-based offset pattern verification with configurable NUM_MMA_Q -template -void SimulateOffsetPattern(std::vector &thread_ids_at_offsets) -{ - // Constants derived from HEAD_DIM - constexpr uint32_t UPCAST_STRIDE_Q = HEAD_DIM / QUERY_ELEMS_PER_THREAD; - constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM / 16; - constexpr uint32_t COLUMN_RESET_OFFSET = - (NUM_MMA_D_QK / 4) * WARP_STEP_SIZE; - constexpr uint32_t grid_width = - (HEAD_DIM / QUERY_ELEMS_PER_THREAD); // 16 for 64, 32 for 128 - constexpr uint32_t grid_height = - 16 * NUM_MMA_Q; // 16 for NUM_MMA_Q=1, 32 for NUM_MMA_Q=2 - - // Initialize with -1 (unwritten) - thread_ids_at_offsets.assign(grid_height * grid_width, -1); - - // Simulate each thread - for (uint32_t tid = 0; tid < 64; tid++) { - uint32_t row = tid / WARP_STEP_SIZE; // 0-3 for 64 threads - uint32_t col = tid % WARP_STEP_SIZE; // 0-15 - - // Calculate initial offset using linear addressing - uint32_t q_smem_offset_w = - get_permuted_offset_linear(row, col); - - // Main loop structure from load_q_global_smem - for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { - for (uint32_t j = 0; j < 4; ++j) { - // Calculate sequence index - const uint32_t seq_idx = row + mma_q * 16 + j; - - for (uint32_t mma_do = 0; mma_do < NUM_MMA_D_QK / 4; ++mma_do) { - // Record which thread wrote to this offset - if (q_smem_offset_w < grid_height * grid_width) - { // Safety check - thread_ids_at_offsets[q_smem_offset_w] = tid; - } - else { - printf("ERROR by tid: %d, offset: %d\n", tid, - q_smem_offset_w); - } - - // Advance to next column within same row - q_smem_offset_w = - advance_offset_by_column_linear( - q_smem_offset_w, mma_do); - } - - // Advance to next sequence (row) with adjustment back to first - // column - q_smem_offset_w = advance_offset_by_row_linear( - q_smem_offset_w) - - COLUMN_RESET_OFFSET; - } - } - } -} - -// Helper function to run the test with configurable NUM_MMA_Q -template void RunOffsetTest() +constexpr uint32_t qo_len = 64; +constexpr 
uint32_t num_qo_heads = 1; +constexpr uint32_t head_dim = 64; +} // namespace + +// CPU reference implementation that creates a Q matrix with a kNHD layout and +// initializes. +void initialize_cpu_q() { - constexpr uint32_t grid_width = - (HEAD_DIM / QUERY_ELEMS_PER_THREAD); // 16 for 64, 32 for 128 - constexpr uint32_t grid_height = - 16 * NUM_MMA_Q; // 16 for NUM_MMA_Q=1, 32 for NUM_MMA_Q=2 - - printf("\n=== Testing offset calculations with HEAD_DIM = %u, NUM_MMA_Q = " - "%u ===\n", - HEAD_DIM, NUM_MMA_Q); - - // Host array to store thread IDs at each offset - std::vector thread_ids(grid_height * grid_width, -1); - - // Run CPU simulation of offset pattern - SimulateOffsetPattern(thread_ids); - - // Print the grid of thread IDs (potentially truncated for readability) - printf("Thread IDs writing to each offset (%dx%d grid):\n", grid_height, - grid_width); - - // Column headers - printf(" "); - for (int c = 0; c < grid_width; c++) { - printf("%3d ", c); - if (c == 15 && grid_width > 16) - printf("| "); // Divider between first and second half - } - printf("\n +"); - for (int c = 0; c < grid_width; c++) { - printf("----"); - if (c == 15 && grid_width > 16) - printf("+"); // Divider between first and second half - } - printf("\n"); - - // Print quadrants with clear separation - for (int r = 0; r < grid_height; r++) { - printf("%2d | ", r); - for (int c = 0; c < grid_width; c++) { - int thread_id = thread_ids[r * grid_width + c]; - if (thread_id >= 0) { - printf("%3d ", thread_id); - } - else { - printf(" . 
"); // Dot for unwritten positions - } - if (c == 15 && grid_width > 16) - printf("| "); // Divider between first and second half - } - printf("\n"); - - // Add horizontal divider between first and second block of sequences - if (r == 15 && NUM_MMA_Q > 1) { - printf(" +"); - for (int c = 0; c < grid_width; c++) { - printf("----"); - if (c == 15 && grid_width > 16) - printf("+"); // Intersection divider - } - printf("\n"); - } - } - - // Check for unwritten positions - int unwritten = 0; - for (int i = 0; i < grid_height * grid_width; i++) { - if (thread_ids[i] == -1) { - unwritten++; - } - } - - // Print statistics - printf("\nStatistics:\n"); - printf("- Positions written: %d/%d (%.1f%%)\n", - grid_height * grid_width - unwritten, grid_height * grid_width, - 100.0f * (grid_height * grid_width - unwritten) / - (grid_height * grid_width)); - printf("- Unwritten positions: %d/%d (%.1f%%)\n", unwritten, - grid_height * grid_width, - 100.0f * unwritten / (grid_height * grid_width)); - - // Validate full coverage - EXPECT_EQ(unwritten, 0) << "Not all positions were written"; + std::vector q(qo_len * num_qo_heads * head_dim); + utils::vec_normal_(q); } -// Original tests with NUM_MMA_Q = 1 -TEST(MI300OffsetTest, HeadDim64_NumMmaQ1) { RunOffsetTest<64, 1>(); } +// Validates the original Q matrix on CPU with the copied over data from GPU. +// Ensures that the copied over data matches both the CDNA3 A-matrix layout and +// also validates with the original Q matrix. -TEST(MI300OffsetTest, HeadDim128_NumMmaQ1) { RunOffsetTest<128, 1>(); } +// GPU kernel that launches exactly one warp and calls prefill.cuh's +// load_q_global_smem to populate a LDS array from a global array. Then copies +// back the shared memory array to another output global array. 
-// New tests with NUM_MMA_Q = 2 -TEST(MI300OffsetTest, HeadDim64_NumMmaQ2) { RunOffsetTest<64, 2>(); } - -TEST(MI300OffsetTest, HeadDim128_NumMmaQ2) { RunOffsetTest<128, 2>(); } - -int main(int argc, char **argv) -{ - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} +// Laucher of GPU kernel. +// Copies the Q array from the CPU reference to GPU and then calls the kernel +// to copy from global to shared memory. diff --git a/libflashinfer/tests/hip/test_load_q_global_smem_v1.cpp b/libflashinfer/tests/hip/test_load_q_global_smem_v1.cpp new file mode 100644 index 0000000000..c9d8c3840a --- /dev/null +++ b/libflashinfer/tests/hip/test_load_q_global_smem_v1.cpp @@ -0,0 +1,193 @@ +#include +#include +#include +#include +#include + +// Constants for MI300 +constexpr uint32_t WARP_STEP_SIZE = 16; // 16 threads per warp row +constexpr uint32_t QUERY_ELEMS_PER_THREAD = + 4; // Each thread loads 4 fp16 elements +constexpr uint32_t WARP_THREAD_ROWS = 4; // 4 rows of threads in a warp + +// Simplified linear shared memory operations (CPU implementation) +template +uint32_t get_permuted_offset_linear(uint32_t row, uint32_t col) +{ + return row * stride + col; +} + +template +uint32_t advance_offset_by_column_linear(uint32_t offset, uint32_t step_idx) +{ + return offset + step_size; +} + +template +uint32_t advance_offset_by_row_linear(uint32_t offset) +{ + return offset + step_size * row_stride; +} + +// CPU-based offset pattern verification with configurable NUM_MMA_Q +template +void SimulateOffsetPattern(std::vector &thread_ids_at_offsets) +{ + // Constants derived from HEAD_DIM + constexpr uint32_t UPCAST_STRIDE_Q = HEAD_DIM / QUERY_ELEMS_PER_THREAD; + constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM / 16; + constexpr uint32_t COLUMN_RESET_OFFSET = + (NUM_MMA_D_QK / 4) * WARP_STEP_SIZE; + constexpr uint32_t grid_width = + (HEAD_DIM / QUERY_ELEMS_PER_THREAD); // 16 for 64, 32 for 128 + constexpr uint32_t grid_height = + 16 * NUM_MMA_Q; // 16 for NUM_MMA_Q=1, 32 
for NUM_MMA_Q=2 + + // Initialize with -1 (unwritten) + thread_ids_at_offsets.assign(grid_height * grid_width, -1); + + // Simulate each thread + for (uint32_t tid = 0; tid < 64; tid++) { + uint32_t row = tid / WARP_STEP_SIZE; // 0-3 for 64 threads + uint32_t col = tid % WARP_STEP_SIZE; // 0-15 + + // Calculate initial offset using linear addressing + uint32_t q_smem_offset_w = + get_permuted_offset_linear(row, col); + + // Main loop structure from load_q_global_smem + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { + for (uint32_t j = 0; j < 4; ++j) { + // Calculate sequence index + const uint32_t seq_idx = row + mma_q * 16 + j; + + for (uint32_t mma_do = 0; mma_do < NUM_MMA_D_QK / 4; ++mma_do) { + // Record which thread wrote to this offset + if (q_smem_offset_w < grid_height * grid_width) + { // Safety check + thread_ids_at_offsets[q_smem_offset_w] = tid; + } + else { + printf("ERROR by tid: %d, offset: %d\n", tid, + q_smem_offset_w); + } + + // Advance to next column within same row + q_smem_offset_w = + advance_offset_by_column_linear( + q_smem_offset_w, mma_do); + } + + // Advance to next sequence (row) with adjustment back to first + // column + q_smem_offset_w = advance_offset_by_row_linear( + q_smem_offset_w) - + COLUMN_RESET_OFFSET; + } + } + } +} + +// Helper function to run the test with configurable NUM_MMA_Q +template void RunOffsetTest() +{ + constexpr uint32_t grid_width = + (HEAD_DIM / QUERY_ELEMS_PER_THREAD); // 16 for 64, 32 for 128 + constexpr uint32_t grid_height = + 16 * NUM_MMA_Q; // 16 for NUM_MMA_Q=1, 32 for NUM_MMA_Q=2 + + printf("\n=== Testing offset calculations with HEAD_DIM = %u, NUM_MMA_Q = " + "%u ===\n", + HEAD_DIM, NUM_MMA_Q); + + // Host array to store thread IDs at each offset + std::vector thread_ids(grid_height * grid_width, -1); + + // Run CPU simulation of offset pattern + SimulateOffsetPattern(thread_ids); + + // Print the grid of thread IDs (potentially truncated for readability) + printf("Thread IDs writing to each 
offset (%dx%d grid):\n", grid_height, + grid_width); + + // Column headers + printf(" "); + for (int c = 0; c < grid_width; c++) { + printf("%3d ", c); + if (c == 15 && grid_width > 16) + printf("| "); // Divider between first and second half + } + printf("\n +"); + for (int c = 0; c < grid_width; c++) { + printf("----"); + if (c == 15 && grid_width > 16) + printf("+"); // Divider between first and second half + } + printf("\n"); + + // Print quadrants with clear separation + for (int r = 0; r < grid_height; r++) { + printf("%2d | ", r); + for (int c = 0; c < grid_width; c++) { + int thread_id = thread_ids[r * grid_width + c]; + if (thread_id >= 0) { + printf("%3d ", thread_id); + } + else { + printf(" . "); // Dot for unwritten positions + } + if (c == 15 && grid_width > 16) + printf("| "); // Divider between first and second half + } + printf("\n"); + + // Add horizontal divider between first and second block of sequences + if (r == 15 && NUM_MMA_Q > 1) { + printf(" +"); + for (int c = 0; c < grid_width; c++) { + printf("----"); + if (c == 15 && grid_width > 16) + printf("+"); // Intersection divider + } + printf("\n"); + } + } + + // Check for unwritten positions + int unwritten = 0; + for (int i = 0; i < grid_height * grid_width; i++) { + if (thread_ids[i] == -1) { + unwritten++; + } + } + + // Print statistics + printf("\nStatistics:\n"); + printf("- Positions written: %d/%d (%.1f%%)\n", + grid_height * grid_width - unwritten, grid_height * grid_width, + 100.0f * (grid_height * grid_width - unwritten) / + (grid_height * grid_width)); + printf("- Unwritten positions: %d/%d (%.1f%%)\n", unwritten, + grid_height * grid_width, + 100.0f * unwritten / (grid_height * grid_width)); + + // Validate full coverage + EXPECT_EQ(unwritten, 0) << "Not all positions were written"; +} + +// Original tests with NUM_MMA_Q = 1 +TEST(MI300OffsetTest, HeadDim64_NumMmaQ1) { RunOffsetTest<64, 1>(); } + +TEST(MI300OffsetTest, HeadDim128_NumMmaQ1) { RunOffsetTest<128, 1>(); } + +// New 
tests with NUM_MMA_Q = 2 +TEST(MI300OffsetTest, HeadDim64_NumMmaQ2) { RunOffsetTest<64, 2>(); } + +TEST(MI300OffsetTest, HeadDim128_NumMmaQ2) { RunOffsetTest<128, 2>(); } + +int main(int argc, char **argv) +{ + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/libflashinfer/tests/hip/test_load_q_global_smem_kernel.cpp b/libflashinfer/tests/hip/test_load_q_global_smem_v2.cpp similarity index 100% rename from libflashinfer/tests/hip/test_load_q_global_smem_kernel.cpp rename to libflashinfer/tests/hip/test_load_q_global_smem_v2.cpp diff --git a/libflashinfer/tests/hip/test_single_prefill.cpp b/libflashinfer/tests/hip/test_single_prefill.cpp index 44a3b1825e..23749dad3e 100644 --- a/libflashinfer/tests/hip/test_single_prefill.cpp +++ b/libflashinfer/tests/hip/test_single_prefill.cpp @@ -556,7 +556,7 @@ int main(int argc, char **argv) size_t num_heads = 1; size_t head_dim = 64; bool causal = false; - size_t pos_encoding_mode = 0; + size_t pos_encoding_mode = 1; // 1 == kRopeLLama size_t kv_layout = 0; _TestSinglePrefillKernelCorrectness( diff --git a/libflashinfer/utils/cpu_reference_hip.h b/libflashinfer/utils/cpu_reference_hip.h index 5a473408ca..bd32a04a77 100644 --- a/libflashinfer/utils/cpu_reference_hip.h +++ b/libflashinfer/utils/cpu_reference_hip.h @@ -57,6 +57,57 @@ exclusive_prefix_sum(const T *input, size_t batch_size, size_t d) return std::move(output); } +template +inline std::vector apply_llama_rope_debug(const T *input, + size_t D, + size_t offset, + float rope_scale, + float rope_theta) +{ + std::vector rst(D); + std::vector permuted_input(D); + // Print the input parameters + // Only print for first position to avoid flood + if (offset == 134) { // First position in your log + std::cout << "=== CPU ROPE DEBUG ===\n"; + std::cout << "D: " << D << ", offset: " << offset + << ", rope_scale: " << rope_scale + << ", rope_theta: " << rope_theta << std::endl; + + std::cout << "CPU Frequencies vs GPU comparison:\n"; + for (size_t 
k = 0; k < min(4ul, D); ++k) { + float freq_base = float(2 * (k % (D / 2))) / float(D); + float frequency = + 1.0f / std::pow(rope_theta, freq_base); // This should match GPU + float angle = + (offset / rope_scale) / std::pow(rope_theta, freq_base); + + std::cout << "CPU: feature[" << k << "] freq_base=" << freq_base + << " frequency=" << frequency << " angle=" << angle + << std::endl; + } + } + + for (size_t k = 0; k < D; ++k) { + permuted_input[k] = + (k < D / 2) ? -fi::con::explicit_casting(input[k + D / 2]) + : fi::con::explicit_casting(input[k - D / 2]); + } + + for (size_t k = 0; k < D; ++k) { + float inv_freq = + (offset / rope_scale) / + (std::pow(rope_theta, float(2 * (k % (D / 2))) / float(D))); + float cos = std::cos(inv_freq); + float sin = std::sin(inv_freq); + + if (std::is_same_v) + rst[k] = cos * fi::con::explicit_casting(input[k]) + + sin * permuted_input[k]; + } + return rst; +} + template inline std::vector apply_llama_rope(const T *input, size_t D, @@ -164,22 +215,45 @@ single_mha(const std::vector &q, tensor_info_t info(qo_len, kv_len, num_qo_heads, num_kv_heads, kv_layout, HEAD_DIM); #if Debug - // std::cout << "DEBUG Q (CPU): " << '\n'; - // for (auto i = 0ul; i < 64; ++i) { - // // q[info.get_q_elem_offset(q_idx, qo_head_idx, feat_idx) - // std::cout << (float)q[info.get_q_elem_offset(0, 0, i)] << " "; + std::cout << "DEBUG: Original Q (CPU): " << '\n'; + for (auto i = 0ul; i < 4; ++i) { + // q[info.get_q_elem_offset(q_idx, qo_head_idx, feat_idx) + std::cout << (float)q[info.get_q_elem_offset(0, 0, i)] << " "; + } + std::cout << std::endl; + + // std::cout << "DEBUG K (CPU): " << '\n'; + // for (auto j = 0ul; j < 16; ++j) { + // for (auto i = 0ul; i < 64; ++i) { + // // k[info.get_kv_elem_offset(kv_idx, kv_head_idx, feat_idx) + // // std::cout << (float)k[info.get_kv_elem_offset(15, 0, j * 4 + // + + // // i)] + // std::cout << (float)k[info.get_kv_elem_offset(j, 0, i)] << " + // "; + // } + // std::cout << '\n'; // } // std::cout << 
std::endl; - - std::cout << "DEBUG K (CPU): " << '\n'; - for (auto j = 0ul; j < 16; ++j) { - for (auto i = 0ul; i < 64; ++i) { - // k[info.get_kv_elem_offset(kv_idx, kv_head_idx, feat_idx) - // std::cout << (float)k[info.get_kv_elem_offset(15, 0, j * 4 + - // i)] - std::cout << (float)k[info.get_kv_elem_offset(j, 0, i)] << " "; + std::cout << "num_qo_heads " << num_qo_heads << '\n'; + std::cout << "qo_len " << qo_len << '\n'; + for (size_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) + { + for (size_t q_idx = 0; q_idx < qo_len; ++q_idx) { + q_rotary_local = + std::move(cpu_reference::apply_llama_rope_debug( + q.data() + + info.get_q_elem_offset(q_idx, qo_head_idx, 0), + head_dim, q_idx + kv_len - qo_len, rope_scale, + rope_theta)); } - std::cout << '\n'; + } + + std::cout << "DEBUG: LLAMA Rope Transformed Q (CPU): " << '\n'; + for (auto i = 0ul; i < 4; ++i) { + // q[info.get_q_elem_offset(q_idx, qo_head_idx, feat_idx) + std::cout << (float)q_rotary_local[info.get_q_elem_offset(0, 0, i)] + << " "; } std::cout << std::endl; #endif From 07a7e6487ba426cbe9dfdb044faf8878fb6d2145 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Wed, 27 Aug 2025 08:17:31 -0400 Subject: [PATCH 054/109] Debug llama --- .../flashinfer/attention/generic/prefill.cuh | 2 +- .../tests/hip/test_apply_llama_rope.cpp | 177 ++++++++++++++++++ 2 files changed, 178 insertions(+), 1 deletion(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index a8226d9da2..64f24ab475 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -709,7 +709,7 @@ __device__ __forceinline__ uint32_t get_feature_index(uint32_t mma_d, // row 0 --> T0, T1, T2, T3 // row 1 --> T4, T5, T6, T7 // ... - // row 7 --> T28, T29, T30, T31. 
+ // row 7 --> T28, T29, T30, T31 // The full data to thread mapping repeats again for the next set of 16 // rows. Thereby, forming a 16x16 MMA tile dubdivided into four 8x8 // quadrants. diff --git a/libflashinfer/tests/hip/test_apply_llama_rope.cpp b/libflashinfer/tests/hip/test_apply_llama_rope.cpp index f6cc629c39..f6a800766d 100644 --- a/libflashinfer/tests/hip/test_apply_llama_rope.cpp +++ b/libflashinfer/tests/hip/test_apply_llama_rope.cpp @@ -5,7 +5,9 @@ #include "../../utils/cpu_reference_hip.h" #include "../../utils/utils_hip.h" #include "flashinfer/attention/generic/prefill.cuh" +#include "gpu_iface/fastdiv.cuh" #include "gpu_iface/gpu_runtime_compat.hpp" + #include #include #include @@ -57,6 +59,86 @@ __global__ void test_init_rope_freq_kernel(float *output_freq, } } +template +__global__ void +test_q_frag_apply_llama_rope_kernel(__half *q_input, + __half *q_output, + uint32_t qo_len, + uint32_t num_qo_heads, + uint32_t kv_len, + float rope_rcp_scale, + float rope_rcp_theta, + flashinfer::uint_fastdiv group_size_fastdiv) +{ + using KTraits = TestKernelTraits; + constexpr uint32_t HALF_ELEMS_PER_THREAD = 4; + constexpr uint32_t INT32_ELEMS_PER_THREAD = 2; + constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM / 16; + + float rope_freq[KTraits::NUM_MMA_D_VO / 2][4]; + flashinfer::init_rope_freq(rope_freq, rope_rcp_scale, + rope_rcp_theta, threadIdx.x); + + const uint32_t lane_idx = threadIdx.x; + const uint32_t warp_idx = blockIdx.x; + + // TODO: Need to check that qo_len is evenly divisible by 16. 
+ for (uint32_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) { + for (uint32_t seq_chunk = 0; seq_chunk < qo_len; seq_chunk += 16) { + + uint32_t seq_idx = seq_chunk + (lane_idx % 16); + if (seq_idx >= qo_len) + continue; + + uint32_t abs_position = seq_idx + kv_len - qo_len; + // Each iteration processes 16*2=32 features (first_half + + // second_half) + for (uint32_t feat_chunk = 0; feat_chunk < NUM_MMA_D_QK / 2; + ++feat_chunk) + { + uint32_t feat_offset_first = feat_chunk * 32; + uint32_t feat_offset_second = feat_offset_first + HEAD_DIM / 2; + + // Load fragments from global memory + __half q_frag_first[HALF_ELEMS_PER_THREAD]; + __half q_frag_second[HALF_ELEMS_PER_THREAD]; + + // Calculate base address for this sequence and head + uint32_t base_offset = qo_head_idx * HEAD_DIM + + seq_idx * (num_qo_heads * HEAD_DIM); + + // Load first half (4 consecutive features per thread) + for (uint32_t i = 0; i < HALF_ELEMS_PER_THREAD; ++i) { + uint32_t feat_idx1 = + flashinfer::get_feature_index(feat_chunk, + lane_idx, i); + uint32_t feat_idx2 = feat_idx1 + HEAD_DIM / 2; + q_frag_first[i] = *(q_input + base_offset + feat_idx1); + q_frag_second[i] = *(q_input + base_offset + feat_idx2); + } + + // Apply RoPE using the validated function + uint32_t mma_di = feat_chunk; + flashinfer::q_frag_apply_llama_rope<__half, + HALF_ELEMS_PER_THREAD>( + q_frag_first, q_frag_second, + rope_freq[mma_di % (KTraits::NUM_MMA_D_VO / 2)], + abs_position, group_size_fastdiv); + + // Store results back to global memory + for (uint32_t i = 0; i < HALF_ELEMS_PER_THREAD; ++i) { + uint32_t feat_idx1 = + flashinfer::get_feature_index(feat_chunk, + lane_idx, i); + uint32_t feat_idx2 = feat_idx1 + HEAD_DIM / 2; + *(q_output + base_offset + feat_idx1) = q_frag_first[i]; + *(q_output + base_offset + feat_idx2) = q_frag_second[i]; + } + } + } + } +} + template class LLamaRopeTestFixture : public ::testing::TestWithParam { @@ -166,6 +248,56 @@ class LLamaRopeTestFixture : public 
::testing::TestWithParam return results; } + + std::vector test_gpu_q_frag_apply_rope(size_t kv_len = 1000, + float rope_scale = 1.0f, + float rope_theta = 10000.0f) + { + // Convert to reciprocal values + float rope_rcp_scale = 1.0f / rope_scale; + float rope_rcp_theta = 1.0f / rope_theta; + uint32_t group_size = 1; // Simple case for now + + // Allocate GPU memory for input and output + __half *d_q_input, *d_q_output; + size_t q_size = q.size() * sizeof(__half); + + FI_GPU_CALL(hipMalloc(&d_q_input, q_size)); + FI_GPU_CALL(hipMalloc(&d_q_output, q_size)); + + // Copy input Q to GPU + FI_GPU_CALL( + hipMemcpy(d_q_input, q.data(), q_size, hipMemcpyHostToDevice)); + FI_GPU_CALL(hipMemset(d_q_output, 0, q_size)); + + // Launch kernel - one block with 64 threads + dim3 grid(1); // Single block for simplicity + dim3 block(64); // CDNA3 wavefront size + + if (head_dim == 64) { + test_q_frag_apply_llama_rope_kernel<64><<>>( + d_q_input, d_q_output, qo_len, num_qo_heads, kv_len, + rope_rcp_scale, rope_rcp_theta, group_size); + } + + FI_GPU_CALL(hipDeviceSynchronize()); + + // Copy results back to CPU + std::vector<__half> gpu_output(q.size()); + FI_GPU_CALL(hipMemcpy(gpu_output.data(), d_q_output, q_size, + hipMemcpyDeviceToHost)); + + // Convert to float for comparison + std::vector result(head_dim); + for (size_t i = 0; i < head_dim; ++i) { + result[i] = float(gpu_output[i]); // First sequence, first head + } + + FI_GPU_CALL(hipFree(d_q_input)); + FI_GPU_CALL(hipFree(d_q_output)); + + return result; + } }; using LLamaRopeTestWithFP16 = LLamaRopeTestFixture<__half>; @@ -222,6 +354,51 @@ TEST_P(LLamaRopeTestWithFP16, VectorSizeIsCorrect) ASSERT_EQ(this->q.size(), expected_size); } +TEST_P(LLamaRopeTestWithFP16, TestQFragApplyRopeComparison) +{ + constexpr float RELATIVE_EPSILON = 1e-3f; + + auto cpu_result = this->apply_cpu_rope(744); + auto gpu_result = this->test_gpu_q_frag_apply_rope(); + + std::cout << "\n=== CPU vs GPU RoPE Application Comparison ===\n"; + std::cout 
<< "CPU result (offset=1000, first 8 features): "; + for (size_t i = 0; i < std::min(8u, this->head_dim); ++i) { + std::cout << cpu_result[i] << " "; + } + std::cout << std::endl; + + std::cout << "GPU result (offset=1000, first 8 features): "; + for (size_t i = 0; i < std::min(8u, this->head_dim); ++i) { + std::cout << gpu_result[i] << " "; + } + std::cout << std::endl; + + // Compare element by element + size_t num_mismatches = 0; + for (size_t i = 0; i < std::min(cpu_result.size(), gpu_result.size()); ++i) + { + float diff = std::abs(cpu_result[i] - gpu_result[i]); + float rel_diff = (std::abs(cpu_result[i]) > 1e-6f) + ? diff / std::abs(cpu_result[i]) + : diff; + + if (rel_diff > RELATIVE_EPSILON) { + std::cout << "Mismatch at feature " << i + << ": CPU=" << cpu_result[i] << " GPU=" << gpu_result[i] + << " diff=" << diff << " rel_diff=" << rel_diff + << std::endl; + ++num_mismatches; + } + } + + std::cout << "Total mismatches: " << num_mismatches << " out of " + << head_dim << std::endl; + + EXPECT_EQ(num_mismatches, 0) + << "Found mismatches between CPU and GPU RoPE application"; +} + INSTANTIATE_TEST_SUITE_P( LLamaRopeTestWithFP16, LLamaRopeTestWithFP16, From e337e288bf35ad27afccb1a56fdaebaf1efa47e6 Mon Sep 17 00:00:00 2001 From: rtmadduri Date: Thu, 28 Aug 2025 05:31:56 +0000 Subject: [PATCH 055/109] utils --- libflashinfer/utils/utils_hip.h | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/libflashinfer/utils/utils_hip.h b/libflashinfer/utils/utils_hip.h index fee6b92baf..6ca5e7b85f 100644 --- a/libflashinfer/utils/utils_hip.h +++ b/libflashinfer/utils/utils_hip.h @@ -60,6 +60,31 @@ namespace utils { +enum Predicate +{ + Linear, + Ones, + Zeros, +}; + +template void generate_data(std::vector &vec) +{ + if constexpr (Pred == Predicate::Linear) { + assert(vec.size() <= 0); + for (int i = 0; i < vec.size(); i++) { + vec[i] = fi::con::explicit_casting(static_cast(i)); + } + } + + else if constexpr (Pred == Predicate::Ones) { + 
vec_fill_(vec, fi::con::explicit_casting(1.0f)); + } + + else if constexpr (Pred == Predicate::Zeros) { + vec_zero_(vec); + } +} + template void vec_normal_(std::vector &vec, float mean = 0.f, float std = 1.f) { From 5b9daa56505d7883a61ca91def9a21fd08b722ce Mon Sep 17 00:00:00 2001 From: rtmadduri Date: Thu, 28 Aug 2025 05:33:48 +0000 Subject: [PATCH 056/109] testing harness for compute_sfm --- .../attention/generic/prefill_tester.cuh | 2234 +++++++++++++++++ libflashinfer/tests/hip/test_compute_sfm.cpp | 234 ++ .../utils/flashinfer_prefill_ops.hip.h | 124 +- 3 files changed, 2532 insertions(+), 60 deletions(-) create mode 100644 libflashinfer/include/flashinfer/attention/generic/prefill_tester.cuh create mode 100644 libflashinfer/tests/hip/test_compute_sfm.cpp diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill_tester.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill_tester.cuh new file mode 100644 index 0000000000..93903d0d4a --- /dev/null +++ b/libflashinfer/include/flashinfer/attention/generic/prefill_tester.cuh @@ -0,0 +1,2234 @@ +// SPDX - FileCopyrightText : 2023-2025 FlashInfer team. +// SPDX - FileCopyrightText : 2025 Advanced Micro Devices, Inc. 
+// +// SPDX - License - Identifier : Apache - 2.0 +#ifndef FLASHINFER_PREFILL_CUH_ +#define FLASHINFER_PREFILL_CUH_ + +#include "gpu_iface/cooperative_groups.h" +#include "gpu_iface/fastdiv.cuh" +#include "gpu_iface/math_ops.hpp" +#include "gpu_iface/memory_ops.hpp" +#include "gpu_iface/mma_ops.hpp" +#include "gpu_iface/platform.hpp" +#include "gpu_iface/utils.cuh" + +#ifdef FP16_QK_REDUCTION_SUPPORTED +#include "../../fp16.h" +#endif +#include "frag_layout_swizzle.cuh" + +#include "cascade.cuh" +#include "dispatch.cuh" +#include "page.cuh" +#include "permuted_smem.cuh" +#include "pos_enc.cuh" +#include "variants.cuh" + +namespace flashinfer +{ + +DEFINE_HAS_MEMBER(maybe_q_rope_offset) +DEFINE_HAS_MEMBER(maybe_k_rope_offset) + +namespace cg = flashinfer::gpu_iface::cg; +namespace memory = flashinfer::gpu_iface::memory; +namespace mma = gpu_iface::mma; + +using gpu_iface::vec_dtypes::vec_cast; +using mma::MMAMode; + +constexpr uint32_t WARP_SIZE = gpu_iface::kWarpSize; + +constexpr uint32_t get_num_warps_q(const uint32_t cta_tile_q) +{ + if (cta_tile_q > 16) { + return 4; + } + else { + return 1; + } +} + +constexpr uint32_t get_num_warps_kv(const uint32_t cta_tile_kv) +{ + return 4 / get_num_warps_q(cta_tile_kv); +} + +constexpr uint32_t get_num_mma_q(const uint32_t cta_tile_q) +{ + if (cta_tile_q > 64) { + return 2; + } + else { + return 1; + } +} + +template +struct SharedStorageQKVO +{ + union + { + struct + { + alignas(16) DTypeQ q_smem[CTA_TILE_Q * HEAD_DIM_QK]; + alignas(16) DTypeKV k_smem[CTA_TILE_KV * HEAD_DIM_QK]; + alignas(16) DTypeKV v_smem[CTA_TILE_KV * HEAD_DIM_VO]; + }; + struct + { // NOTE(Zihao): synchronize attention states across warps + alignas(16) std::conditional_t< + NUM_WARPS_KV == 1, + float[1], + float[NUM_WARPS_KV * CTA_TILE_Q * HEAD_DIM_VO]> cta_sync_o_smem; + alignas(16) std::conditional_t< + NUM_WARPS_KV == 1, + float2[1], + float2[NUM_WARPS_KV * CTA_TILE_Q]> cta_sync_md_smem; + }; + alignas(16) DTypeO smem_o[CTA_TILE_Q * HEAD_DIM_VO]; 
+ }; +}; + +template +struct KernelTraits +{ + static constexpr MaskMode MASK_MODE = MASK_MODE_; + static constexpr uint32_t NUM_MMA_Q = NUM_MMA_Q_; + static constexpr uint32_t NUM_MMA_KV = NUM_MMA_KV_; + static constexpr uint32_t NUM_MMA_D_QK = NUM_MMA_D_QK_; + static constexpr uint32_t NUM_MMA_D_VO = NUM_MMA_D_VO_; + static constexpr uint32_t NUM_WARPS_Q = NUM_WARPS_Q_; + static constexpr uint32_t NUM_WARPS_KV = NUM_WARPS_KV_; + static constexpr uint32_t NUM_WARPS = NUM_WARPS_Q * NUM_WARPS_KV; + static constexpr uint32_t HEAD_DIM_QK = NUM_MMA_D_QK * 16; + static constexpr uint32_t HEAD_DIM_VO = NUM_MMA_D_VO * 16; + static constexpr uint32_t CTA_TILE_Q = CTA_TILE_Q_; + static constexpr uint32_t CTA_TILE_KV = NUM_MMA_KV * NUM_WARPS_KV * 16; + static constexpr PosEncodingMode POS_ENCODING_MODE = POS_ENCODING_MODE_; + + using DTypeQ = DTypeQ_; + using DTypeKV = DTypeKV_; + using DTypeO = DTypeO_; + using DTypeQKAccum = DTypeQKAccum_; + using IdType = IdType_; + using AttentionVariant = AttentionVariant_; + + static_assert(sizeof(DTypeKV_) != 1, "8-bit types not supported for CDNA3"); + + using SmemBasePtrTy = uint2; + static constexpr uint32_t NUM_THREADS = NUM_WARPS_Q * NUM_WARPS_KV * 64; + static constexpr uint32_t WARP_THREAD_ROWS = 4; + static constexpr uint32_t WARP_THREAD_COLS = 16; + static constexpr uint32_t HALF_ELEMS_PER_THREAD = 4; + static constexpr uint32_t INT32_ELEMS_PER_THREAD = 2; + static constexpr uint32_t VECTOR_BIT_WIDTH = HALF_ELEMS_PER_THREAD * 16; + // FIXME: Update with a proper swizzle pattern. Linear is used primarily + // for intial testing. + static constexpr SwizzleMode SWIZZLE_MODE_Q = SwizzleMode::kLinear; + static constexpr SwizzleMode SWIZZLE_MODE_KV = SwizzleMode::kLinear; + + // Presently we use 16x4 thread layout for all cases. 
+ static constexpr uint32_t KV_THR_LAYOUT_ROW = WARP_THREAD_ROWS; + static constexpr uint32_t KV_THR_LAYOUT_COL = WARP_THREAD_COLS; + // The constant is defined based on the matrix layout of the "D/C" + // accumulator matrix in a D = A*B+C computation. On CDNA3 the D/C matrices + // are distributed as four 4x16 bands across the 64 threads. Each thread + // owns one element from four different rows. + static constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD = 4; + // Number of threads that collaboratively handle the same set of matrix rows + // in attention score computation and cross-warp synchronization. + // CUDA: 4 threads (each thread handles 2 elements from same row group) + // CDNA3: 16 threads (each thread handles 1 element from same row group) + static constexpr uint32_t THREADS_PER_MATRIX_ROW_SET = 16; + // controls the indexing stride used in logits-related functions + // (logits_transform, logits_mask, and LSE writing). + static constexpr uint32_t LOGITS_INDEX_STRIDE = 4; + + static constexpr uint32_t UPCAST_STRIDE_Q = + HEAD_DIM_QK / upcast_size(); + static constexpr uint32_t UPCAST_STRIDE_K = + HEAD_DIM_QK / upcast_size(); + static constexpr uint32_t UPCAST_STRIDE_V = + HEAD_DIM_VO / upcast_size(); + static constexpr uint32_t UPCAST_STRIDE_O = + HEAD_DIM_VO / upcast_size(); + + static constexpr bool IsInvalid() + { + return ((NUM_MMA_D_VO < 4) || + (NUM_MMA_D_VO == 4 && NUM_MMA_KV % 2 == 1) || + (POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && + NUM_MMA_D_VO > 4 && NUM_MMA_D_VO % (2 * NUM_WARPS_Q) != 0) || + (NUM_MMA_Q * (8 * NUM_MMA_D_VO + + 2 * sizeof(DTypeQKAccum) * NUM_MMA_KV) >= + 256) || + (sizeof(DTypeKV) == 1 && NUM_MMA_KV * 2 % NUM_WARPS_Q != 0) || + (sizeof(DTypeKV) == 1 && + POS_ENCODING_MODE == PosEncodingMode::kRoPELlama)); + } + + using SharedStorage = SharedStorageQKVO; +#ifdef FP16_QK_REDUCTION_SUPPORTED + template static constexpr DT getNegInf() + { + if constexpr (std::is_same::value) { + return std::bit_cast( + 
fp16_ieee_from_fp32_value(-gpu_iface::math::inf)); + } + else { + return static_cast(-gpu_iface::math::inf); + } + } + + static constexpr DTypeQKAccum MaskFillValue = + AttentionVariant::use_softmax ? getNegInf() + : DTypeQKAccum(0.f); +#else + static_assert(!std::is_same::value, + "Set -DFP16_QK_REDUCTION_SUPPORTED and install boost_math " + "then recompile to support fp16 reduction"); + static constexpr DTypeQKAccum MaskFillValue = + AttentionVariant::use_softmax ? DTypeQKAccum(-gpu_iface::math::inf) + : DTypeQKAccum(0.f); +#endif +}; + +namespace +{ + +template +__device__ __forceinline__ uint32_t +get_warp_idx_q(const uint32_t tid_y = threadIdx.y) +{ + if constexpr (KTraits::NUM_WARPS_Q == 1) { + return 0; + } + else { + return tid_y; + } +} + +template +__device__ __forceinline__ uint32_t +get_warp_idx_kv(const uint32_t tid_z = threadIdx.z) +{ + if constexpr (KTraits::NUM_WARPS_KV == 1) { + return 0; + } + else { + return tid_z; + } +} + +template +__device__ __forceinline__ uint32_t +get_warp_idx(const uint32_t tid_y = threadIdx.y, + const uint32_t tid_z = threadIdx.z) +{ + return get_warp_idx_kv(tid_z) * KTraits::NUM_WARPS_Q + + get_warp_idx_q(tid_y); +} + +/*! + * \brief Apply Llama style rotary embedding to two 16x16 fragments. + * \tparam T The data type of the input fragments. + * \param x_first_half First fragment x[offset:offset+16, j*16:(j+1)*16] + * \param x_second_half Second fragment x[offset:offset*16, + * j*16+d/2:(j+1)*16+d/2] + * \param rope_freq Rope frequency + * \param offset The offset of the first row in both fragments. + * \note The sin/cos computation is slow, especially for A100 GPUs which has low + * non tensor-ops flops, will optimize in the future. 
+ */ +template +__device__ __forceinline__ void +k_frag_apply_llama_rope(T *x_first_half, + T *x_second_half, + const float *rope_freq, + const uint32_t kv_offset) +{ + static_assert(sizeof(T) == 2); +#pragma unroll + for (uint32_t reg_id = 0; reg_id < HALF_ELEMS_PER_THREAD; ++reg_id) { + float cos, sin, tmp; + // 0 1 | 2 3 + // --------- + // 4 5 | 6 7 + + uint32_t i = reg_id / 2, j = reg_id % 2; + __sincosf(float(kv_offset + 8 * i) * rope_freq[2 * j + reg_id % 2], + &sin, &cos); + tmp = x_first_half[reg_id]; + x_first_half[reg_id] = (tmp * cos - (float)x_second_half[reg_id] * sin); + x_second_half[reg_id] = + ((float)x_second_half[reg_id] * cos + tmp * sin); + } +} + +template +__device__ __forceinline__ void +q_frag_apply_llama_rope(T *x_first_half, + T *x_second_half, + const float *rope_freq, + const uint32_t qo_packed_offset, + const uint_fastdiv group_size) +{ +#pragma unroll + for (uint32_t reg_id = 0; reg_id < HALF_ELEMS_PER_THREAD; ++reg_id) { + float cos, sin, tmp; + // 0 1 | 4 5 + // --------- + // 2 3 | 6 7 + + // // Same sequence for all 4 features + // uint32_t i = 0; + // Direct mapping to frequency array + uint32_t freq_idx = reg_id; + // Same position for this thread's sequence + uint32_t position = qo_packed_offset; + + __sincosf(float(position / group_size) * rope_freq[freq_idx], &sin, + &cos); + tmp = x_first_half[reg_id]; + x_first_half[reg_id] = (tmp * cos - (float)x_second_half[reg_id] * sin); + x_second_half[reg_id] = + ((float)x_second_half[reg_id] * cos + tmp * sin); + } +} + +template +__device__ __forceinline__ void +q_frag_apply_llama_rope_with_pos(T *x_first_half, + T *x_second_half, + const float *rope_freq, + const uint32_t qo_packed_offset, + const uint_fastdiv group_size, + const IdType *q_rope_offset) +{ + float pos[2] = { + static_cast(q_rope_offset[qo_packed_offset / group_size]), + static_cast(q_rope_offset[(qo_packed_offset + 8) / group_size])}; +#pragma unroll + for (uint32_t reg_id = 0; reg_id < HALF_ELEMS_PER_THREAD; 
++reg_id) { + float cos, sin, tmp; + // 0 1 | 4 5 + // --------- + // 2 3 | 6 7 + + const uint32_t i = reg_id / 2; + const uint32_t j = reg_id % 2; + + __sincosf(pos[i] * rope_freq[2 * j + reg_id % 2], &sin, &cos); + tmp = x_first_half[reg_id]; + x_first_half[reg_id] = (tmp * cos - (float)x_second_half[reg_id] * sin); + x_second_half[reg_id] = + ((float)x_second_half[reg_id] * cos + tmp * sin); + } +} + +template +__device__ __forceinline__ void produce_kv_impl_cuda_( + uint32_t warp_idx, + uint32_t lane_idx, + smem_t smem, + uint32_t *smem_offset, + typename KTraits::DTypeKV **gptr, + const uint32_t stride_n, + const uint32_t kv_idx_base, + const uint32_t kv_len) +{ + using DTypeKV = typename KTraits::DTypeKV; + constexpr uint32_t NUM_WARPS = KTraits::NUM_WARPS; + constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; + constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q; + constexpr uint32_t NUM_MMA_D = + produce_v ? KTraits::NUM_MMA_D_VO : KTraits::NUM_MMA_D_QK; + constexpr uint32_t UPCAST_STRIDE = + produce_v ? 
KTraits::UPCAST_STRIDE_V : KTraits::UPCAST_STRIDE_K; + constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH; + + if constexpr (KTraits::SWIZZLE_MODE_KV == SwizzleMode::k128B) { + uint32_t kv_idx = kv_idx_base + warp_idx * 4 + lane_idx / 8; + // NOTE: NUM_MMA_KV * 4 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 4 / + // num_warps + static_assert(NUM_MMA_KV * 4 % NUM_WARPS_Q == 0); +#pragma unroll + for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) { +#pragma unroll + for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(DTypeKV)); ++j) { + smem.template load_128b_async(*smem_offset, *gptr, + kv_idx < kv_len); + *smem_offset = + smem.template advance_offset_by_column<8>(*smem_offset, j); + *gptr += 8 * upcast_size(); + } + kv_idx += NUM_WARPS * 4; + *smem_offset = smem.template advance_offset_by_row( + *smem_offset) - + sizeof(DTypeKV) * NUM_MMA_D; + *gptr += NUM_WARPS * 4 * stride_n - + sizeof(DTypeKV) * NUM_MMA_D * + upcast_size(); + } + *smem_offset -= KTraits::CTA_TILE_KV * UPCAST_STRIDE; + } + else { + uint32_t kv_idx = kv_idx_base + warp_idx * 8 + lane_idx / 4; + // NOTE: NUM_MMA_KV * 2 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 2 / + // num_warps + static_assert(NUM_MMA_KV * 2 % NUM_WARPS_Q == 0); +#pragma unroll + for (uint32_t i = 0; i < NUM_MMA_KV * 2 / NUM_WARPS_Q; ++i) { + smem.template load_128b_async(*smem_offset, *gptr, + kv_idx < kv_len); + *smem_offset = smem.template advance_offset_by_row( + *smem_offset); + kv_idx += NUM_WARPS * 8; + *gptr += NUM_WARPS * 8 * stride_n; + } + *smem_offset -= KTraits::CTA_TILE_KV * UPCAST_STRIDE; + } +} + +template +__device__ __forceinline__ void produce_kv_impl_cdna3_( + uint32_t warp_idx, + uint32_t lane_idx, + smem_t smem, + uint32_t *smem_offset, + typename KTraits::DTypeKV **gptr, + const uint32_t stride_n, + const uint32_t kv_idx_base, + const uint32_t kv_len) +{ + static_assert(KTraits::SWIZZLE_MODE_KV == SwizzleMode::kLinear); + using DTypeKV = typename KTraits::DTypeKV; + constexpr uint32_t 
KV_THR_LAYOUT_COL = KTraits::KV_THR_LAYOUT_COL; // 16 + constexpr uint32_t KV_THR_LAYOUT_ROW = KTraits::KV_THR_LAYOUT_ROW; // 4 + constexpr uint32_t NUM_WARPS = KTraits::NUM_WARPS; + constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; + constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q; + constexpr uint32_t NUM_MMA_D = + produce_v ? KTraits::NUM_MMA_D_VO : KTraits::NUM_MMA_D_QK; + constexpr uint32_t UPCAST_STRIDE = + produce_v ? KTraits::UPCAST_STRIDE_V : KTraits::UPCAST_STRIDE_K; + constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH; + constexpr uint32_t HALF_ELEMS_PER_THREAD = + KTraits::HALF_ELEMS_PER_THREAD; // 4 + + // CDNA3-specific constants + constexpr uint32_t SEQUENCES_PER_MMA_TILE = 16; + constexpr uint32_t SEQUENCES_PER_THREAD_GROUP = KV_THR_LAYOUT_ROW; // 4 + constexpr uint32_t THREAD_GROUPS_PER_MMA_TILE = + SEQUENCES_PER_MMA_TILE / SEQUENCES_PER_THREAD_GROUP; // 4 + constexpr uint32_t FEATURE_CHUNKS_PER_THREAD_GROUP = + NUM_MMA_D / HALF_ELEMS_PER_THREAD; // NUM_MMA_D/4 + constexpr uint32_t COLUMN_RESET_OFFSET = + FEATURE_CHUNKS_PER_THREAD_GROUP * KV_THR_LAYOUT_COL; + + uint32_t row = lane_idx / KV_THR_LAYOUT_COL; + uint32_t kv_idx = kv_idx_base + warp_idx * KV_THR_LAYOUT_ROW + row; + + // NOTE: NUM_MMA_KV*4/NUM_WARPS_Q = NUM_WARPS_KV*NUM_MMA_KV*4/num_warps + static_assert(NUM_MMA_KV * 4 % NUM_WARPS_Q == 0); + +#pragma unroll + for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) + { // MMA tile iterations + + // CDNA3: Load complete 16×HEAD_DIM tile per i iteration +#pragma unroll + for (uint32_t k = 0; k < THREAD_GROUPS_PER_MMA_TILE; ++k) + { // 4 sequence groups +#pragma unroll + for (uint32_t j = 0; j < FEATURE_CHUNKS_PER_THREAD_GROUP; ++j) + { // Feature chunks + smem.template load_vector_async(*smem_offset, *gptr, + kv_idx < kv_len); + + // Advance to next feature chunk (same sequence group) + *smem_offset = + smem.template advance_offset_by_column( + *smem_offset, j); + *gptr += KV_THR_LAYOUT_COL * + upcast_size(); + } + 
+ // Advance to next sequence group within same MMA tile + if (k < THREAD_GROUPS_PER_MMA_TILE - 1) + { // Don't advance after last group + kv_idx += NUM_WARPS * KV_THR_LAYOUT_ROW; + *smem_offset = + smem.template advance_offset_by_row< + NUM_WARPS * KV_THR_LAYOUT_ROW, UPCAST_STRIDE>( + *smem_offset) - + COLUMN_RESET_OFFSET; + *gptr += NUM_WARPS * KV_THR_LAYOUT_ROW * stride_n - + FEATURE_CHUNKS_PER_THREAD_GROUP * KV_THR_LAYOUT_COL * + upcast_size(); + } + } + + // Final advance to next MMA tile + kv_idx += NUM_WARPS * KV_THR_LAYOUT_ROW; + *smem_offset = + smem.template advance_offset_by_row(*smem_offset) - + COLUMN_RESET_OFFSET; + *gptr += NUM_WARPS * KV_THR_LAYOUT_ROW * stride_n - + FEATURE_CHUNKS_PER_THREAD_GROUP * KV_THR_LAYOUT_COL * + upcast_size(); + } + *smem_offset -= KTraits::CTA_TILE_KV * UPCAST_STRIDE; +} + +/*! + * \brief Produce k/v fragments from global memory to shared memory. + * \tparam fill_mode The fill mode of the shared memory. + * \tparam NUM_MMA_D_VO The number of fragments in y dimension. + * \tparam NUM_MMA_KV The number of fragments in z dimension. + * \tparam num_warps The number of warps in the threadblock. + * \tparam T The data type of the input tensor. + * \param smem The shared memory to store kv fragments. + * \param gptr The global memory pointer. + * \param kv_idx_base The base kv index. + * \param kv_len The length of kv tensor. 
+ */ +template +__device__ __forceinline__ void produce_kv( + smem_t smem, + uint32_t *smem_offset, + typename KTraits::DTypeKV **gptr, + const uint32_t stride_n, + const uint32_t kv_idx_base, + const uint32_t kv_len, + const dim3 tid = threadIdx) +{ + // NOTE: for fp8, this function doesn't work for head_dim = 64 at the moment + const uint32_t warp_idx = get_warp_idx(tid.y, tid.z), + lane_idx = tid.x; + + produce_kv_impl_cdna3_( + warp_idx, lane_idx, smem, smem_offset, gptr, stride_n, kv_idx_base, + kv_len); +} + +template +__device__ __forceinline__ void page_produce_kv( + smem_t smem, + uint32_t *smem_offset, + const paged_kv_t + &paged_kv, + const uint32_t kv_idx_base, + const size_t *thr_local_kv_offset, + const uint32_t kv_len, + const dim3 tid = threadIdx) +{ + // NOTE: for fp8, this function doesn't work for head_dim = 64 at the moment + using DType = typename KTraits::DTypeKV; + constexpr SharedMemFillMode fill_mode = + produce_v ? SharedMemFillMode::kFillZero : SharedMemFillMode::kNoFill; + constexpr uint32_t NUM_WARPS = KTraits::NUM_WARPS; + constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q; + constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; + constexpr uint32_t NUM_MMA_D = + produce_v ? KTraits::NUM_MMA_D_VO : KTraits::NUM_MMA_D_QK; + constexpr uint32_t UPCAST_STRIDE = + produce_v ? KTraits::UPCAST_STRIDE_V : KTraits::UPCAST_STRIDE_K; + constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH; + + const uint32_t warp_idx = get_warp_idx(tid.y, tid.z), + lane_idx = tid.x; + if constexpr (KTraits::SWIZZLE_MODE_KV == SwizzleMode::k128B) { + uint32_t kv_idx = kv_idx_base + warp_idx * 4 + lane_idx / 8; + // NOTE: NUM_MMA_KV * 4 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 4 / + // num_warps + static_assert(NUM_MMA_KV * 4 % NUM_WARPS_Q == 0); +#pragma unroll + for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) { + DType *gptr = produce_v ? 
paged_kv.v_data + thr_local_kv_offset[i] + : paged_kv.k_data + thr_local_kv_offset[i]; +#pragma unroll + for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(DType)); ++j) { + smem.template load_vector_async(*smem_offset, gptr, + kv_idx < kv_len); + *smem_offset = + smem.template advance_offset_by_column<8>(*smem_offset, j); + gptr += 8 * upcast_size(); + } + kv_idx += NUM_WARPS * 4; + *smem_offset = smem.template advance_offset_by_row( + *smem_offset) - + sizeof(DType) * NUM_MMA_D; + } + *smem_offset -= KTraits::CTA_TILE_KV * UPCAST_STRIDE; + } + else { + uint32_t kv_idx = kv_idx_base + warp_idx * 8 + lane_idx / 4; + // NOTE: NUM_MMA_KV * 2 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 2 / + // num_warps + static_assert(NUM_MMA_KV * 2 % NUM_WARPS_Q == 0); +#pragma unroll + for (uint32_t i = 0; i < NUM_MMA_KV * 2 / NUM_WARPS_Q; ++i) { + DType *gptr = produce_v ? paged_kv.v_data + thr_local_kv_offset[i] + : paged_kv.k_data + thr_local_kv_offset[i]; + smem.template load_vector_async(*smem_offset, gptr, + kv_idx < kv_len); + kv_idx += NUM_WARPS * 8; + *smem_offset = smem.template advance_offset_by_row( + *smem_offset); + } + *smem_offset -= KTraits::CTA_TILE_KV * UPCAST_STRIDE; + } +} + +__device__ __forceinline__ uint32_t get_feature_index(uint32_t j) +{ + + // CDNA3 A-matrix MMA tile to thread mapping for a 64-thread wavefront: + // Each group of 16 threads handles the same four consecutive features for + // different sequences: + // T0-T15: Features [0,1,2,3] for sequences 0-15 respectively + // T16-T31: Features [4,5,6,7] for sequences 0-15 respectively + // T32-T47: Features [8,9,10,11] for sequences 0-15 respectively + // T48-T63: Features [12,13,14,15] for sequences 0-15 respectively + // + uint32_t feature_index = (mma_d * 16 + (lane_idx / 4) + j) % (HEAD_DIM / 2); + + return feature_index; +} + +template +__device__ __forceinline__ void +init_rope_freq(float (*rope_freq)[4], + const float rope_rcp_scale, + const float rope_rcp_theta, + const uint32_t tid_x = 
threadIdx.x) +{ + constexpr uint32_t HEAD_DIM = KTraits::NUM_MMA_D_QK * 16; + const uint32_t lane_idx = tid_x; + +#pragma unroll + for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO / 2; ++mma_d) { +#pragma unroll + for (uint32_t j = 0; j < 4; ++j) { + rope_freq[mma_d][j] = + rope_rcp_scale * + __powf(rope_rcp_theta, + float(2 * get_feature_index(j)) / float(HEAD_DIM)); + } + } +} + +template +__device__ __forceinline__ void init_states( + typename KTraits::AttentionVariant variant, + float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD], + typename KTraits::DTypeQKAccum (*m)[KTraits::NUM_ACCUM_ROWS_PER_THREAD], + float (*d)[KTraits::NUM_ACCUM_ROWS_PER_THREAD]) +{ + constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD = + KTraits::NUM_ACCUM_ROWS_PER_THREAD; +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { +#pragma unroll + for (uint32_t reg_id = 0; reg_id < KTraits::HALF_ELEMS_PER_THREAD; + ++reg_id) + { + o_frag[mma_q][mma_d][reg_id] = 0.f; + } + } + } + + if constexpr (variant.use_softmax) { +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t j = 0; j < NUM_ACCUM_ROWS_PER_THREAD; ++j) { + m[mma_q][j] = + typename KTraits::DTypeQKAccum(-gpu_iface::math::inf); + d[mma_q][j] = 1.f; + } + } + } +} + +template +__device__ __forceinline__ void load_q_global_smem( + uint32_t packed_offset, + const uint32_t qo_upper_bound, + typename KTraits::DTypeQ *q_ptr_base, + const uint32_t q_stride_n, + const uint32_t q_stride_h, + const uint_fastdiv group_size, + smem_t *q_smem, + const dim3 tid = threadIdx) +{ + using DTypeQ = typename KTraits::DTypeQ; + constexpr uint32_t WARP_THREAD_COLS = KTraits::WARP_THREAD_COLS; + constexpr uint32_t WARP_THREAD_ROWS = KTraits::WARP_THREAD_ROWS; + constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; + constexpr uint32_t 
NUM_MMA_D_QK = KTraits::NUM_MMA_D_QK; + constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q; + constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH; + + constexpr uint32_t COLUMN_RESET_OFFSET = + (NUM_MMA_D_QK / 4) * WARP_THREAD_COLS; + + const uint32_t lane_idx = tid.x, + warp_idx_x = get_warp_idx_q(tid.y); + uint32_t row = lane_idx / WARP_THREAD_COLS; + uint32_t col = lane_idx % WARP_THREAD_COLS; + + if (get_warp_idx_kv(tid.z) == 0) { + uint32_t q_smem_offset_w = + q_smem->template get_permuted_offset( + warp_idx_x * KTraits::NUM_MMA_Q * 16 + row, col); + +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t j = 0; j < 2 * 2; ++j) { + uint32_t q, r; + group_size.divmod(packed_offset + row + mma_q * 16 + j * 4, q, + r); + const uint32_t q_idx = q; + DTypeQ *q_ptr = q_ptr_base + q * q_stride_n + r * q_stride_h + + col * upcast_size(); +#pragma unroll + for (uint32_t mma_do = 0; mma_do < KTraits::NUM_MMA_D_QK / 4; + ++mma_do) + { + // load q fragment from gmem to smem + q_smem->template load_vector_async< + SharedMemFillMode::kNoFill>(q_smem_offset_w, q_ptr, + q_idx < qo_upper_bound); + q_smem_offset_w = q_smem->template advance_offset_by_column< + WARP_THREAD_COLS>(q_smem_offset_w, mma_do); + q_ptr += HALF_ELEMS_PER_THREAD * + upcast_size(); + } + q_smem_offset_w = + q_smem->template advance_offset_by_row( + q_smem_offset_w) - + COLUMN_RESET_OFFSET; + } + } + } +} + +template +__device__ __forceinline__ void q_smem_inplace_apply_rotary( + const uint32_t q_packed_idx, + const uint32_t qo_len, + const uint32_t kv_len, + const uint_fastdiv group_size, + smem_t *q_smem, + uint32_t *q_smem_offset_r, + float (*rope_freq)[4], + const dim3 tid = threadIdx) +{ + if (get_warp_idx_kv(tid.z) == 0) { + constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q; + const uint32_t lane_idx = tid.x; + uint32_t q_frag_local[2][KTraits::INT32_ELEMS_PER_THREAD]; + static_assert(KTraits::NUM_MMA_D_QK % 
4 == 0, + "NUM_MMA_D_QK must be a multiple of 4"); + constexpr uint32_t LAST_HALF_OFFSET = KTraits::NUM_MMA_D_QK * 2; + constexpr uint32_t FIRST_HALF_OFFSET = KTraits::NUM_MMA_D_QK; + const uint32_t SEQ_ID = lane_idx % 16; + +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { + uint32_t q_smem_offset_r_first_half = *q_smem_offset_r; +#pragma unroll + for (uint32_t mma_di = 0; mma_di < FIRST_HALF_OFFSET; ++mma_di) { + q_smem->template load_fragment(q_smem_offset_r_first_half, + q_frag_local[0]); + uint32_t q_smem_offset_r_last_half = + q_smem->template advance_offset_by_column( + q_smem_offset_r_first_half, 0); + q_smem->template load_fragment(q_smem_offset_r_last_half, + q_frag_local[1]); + q_frag_apply_llama_rope( + (typename KTraits::DTypeQ *)q_frag_local[0], + (typename KTraits::DTypeQ *)q_frag_local[1], + rope_freq[mma_di], + q_packed_idx + kv_len * group_size - qo_len * group_size + + mma_q * 16 + SEQ_ID, + group_size); + q_smem->template store_fragment(q_smem_offset_r_last_half, + q_frag_local[1]); + q_smem->template store_fragment(q_smem_offset_r_first_half, + q_frag_local[0]); + q_smem_offset_r_first_half = + q_smem + ->template advance_offset_by_column( + q_smem_offset_r_first_half, mma_di); + } + *q_smem_offset_r += 16 * UPCAST_STRIDE_Q; + } + *q_smem_offset_r -= KTraits::NUM_MMA_Q * 16 * UPCAST_STRIDE_Q; + } +} + +template +__device__ __forceinline__ void compute_qk( + smem_t *q_smem, + uint32_t *q_smem_offset_r, + smem_t *k_smem, + uint32_t *k_smem_offset_r, + typename KTraits::DTypeQKAccum ( + *s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD]) +{ + constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q; + constexpr uint32_t UPCAST_STRIDE_K = KTraits::UPCAST_STRIDE_K; + constexpr uint32_t QK_SMEM_COLUMN_ADVANCE = + 16 / KTraits::HALF_ELEMS_PER_THREAD; + + uint32_t a_frag[KTraits::NUM_MMA_Q][KTraits::INT32_ELEMS_PER_THREAD], + b_frag[KTraits::INT32_ELEMS_PER_THREAD]; + // compute q*k^T +#pragma unroll + 
for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_QK; ++mma_d) { +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { + q_smem->load_fragment(*q_smem_offset_r, a_frag[mma_q]); + *q_smem_offset_r = + q_smem->template advance_offset_by_row<16, UPCAST_STRIDE_Q>( + *q_smem_offset_r); + } + + *q_smem_offset_r = + q_smem->template advance_offset_by_column( + *q_smem_offset_r, mma_d) - + KTraits::NUM_MMA_Q * 16 * UPCAST_STRIDE_Q; + +#pragma unroll + for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { + + k_smem->load_fragment(*k_smem_offset_r, b_frag); + *k_smem_offset_r = + k_smem->template advance_offset_by_row<16, UPCAST_STRIDE_K>( + *k_smem_offset_r); + +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { + if constexpr (std::is_same_v) + { + if (mma_d == 0) { + mma::mma_sync_m16n16k16_row_col_f16f16f32< + typename KTraits::DTypeQ, MMAMode::kInit>( + s_frag[mma_q][mma_kv], a_frag[mma_q], b_frag); + } + else { + mma::mma_sync_m16n16k16_row_col_f16f16f32< + typename KTraits::DTypeQ>(s_frag[mma_q][mma_kv], + a_frag[mma_q], b_frag); + } + } + else if (std::is_same_v) { + static_assert( + false, + "FP16 DTypeQKAccum not yet implemented for CDNA3"); + } + } + } + if constexpr (sizeof(typename KTraits::DTypeKV) == 1) { + if (mma_d % 2 == 1) { + *k_smem_offset_r = k_smem->template advance_offset_by_column< + QK_SMEM_COLUMN_ADVANCE>(*k_smem_offset_r, mma_d / 2); + } + *k_smem_offset_r -= KTraits::NUM_MMA_KV * 16 * UPCAST_STRIDE_K; + } + else { + *k_smem_offset_r = + k_smem + ->template advance_offset_by_column( + *k_smem_offset_r, mma_d) - + KTraits::NUM_MMA_KV * 16 * UPCAST_STRIDE_K; + } + } + *q_smem_offset_r -= KTraits::NUM_MMA_D_QK * QK_SMEM_COLUMN_ADVANCE; + *k_smem_offset_r -= + KTraits::NUM_MMA_D_QK * sizeof(typename KTraits::DTypeKV); +} + +template +__device__ __forceinline__ void logits_transform( + const Params ¶ms, + typename KTraits::AttentionVariant variant, + const uint32_t batch_idx, + 
const uint32_t qo_packed_idx_base, + const uint32_t kv_idx_base, + const uint32_t qo_len, + const uint32_t kv_len, + const uint_fastdiv group_size, + DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD], + const dim3 tid = threadIdx, + const uint32_t kv_head_idx = blockIdx.z) +{ + constexpr uint32_t TPR = KTraits::THREADS_PER_MATRIX_ROW_SET; + constexpr uint32_t NAPTR = KTraits::NUM_ACCUM_ROWS_PER_THREAD; + constexpr uint32_t LIS = KTraits::LOGITS_INDEX_STRIDE; + + const uint32_t lane_idx = tid.x; + uint32_t q[KTraits::NUM_MMA_Q][NAPTR], r[KTraits::NUM_MMA_Q][NAPTR]; + float logits = 0., logitsTransformed = 0.; + +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t j = 0; j < NAPTR; ++j) { + group_size.divmod(qo_packed_idx_base + mma_q * 16 + lane_idx / TPR + + LIS * j, + q[mma_q][j], r[mma_q][j]); + } + } + +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { +#pragma unroll + for (uint32_t reg_id = 0; reg_id < KTraits::HALF_ELEMS_PER_THREAD; + ++reg_id) + { + const uint32_t q_idx = q[mma_q][reg_id % NAPTR]; + const uint32_t qo_head_idx = + kv_head_idx * group_size + r[mma_q][reg_id % NAPTR]; + const uint32_t kv_idx = kv_idx_base + mma_kv * 16 + + 2 * (lane_idx % TPR) + + 8 * (reg_id / 2) + reg_id % 2; + +#ifdef FP16_QK_REDUCTION_SUPPORTED + if constexpr (std::is_same::value) { + logits = std::bit_cast( + fp16_ieee_to_fp32_value(s_frag[mma_q][mma_kv][reg_id])); + } + else if constexpr (!std::is_same::value) { + logits = s_frag[mma_q][mma_kv][reg_id]; + } +#else + static_assert( + !std::is_same::value, + "Set -DFP16_QK_REDUCTION_SUPPORTED and install boost_math " + "then recompile to support fp16 reduction"); + logits = s_frag[mma_q][mma_kv][reg_id]; +#endif + logitsTransformed = + variant.LogitsTransform(params, logits, batch_idx, q_idx, + kv_idx, qo_head_idx, 
kv_head_idx); +#ifdef FP16_QK_REDUCTION_SUPPORTED + if constexpr (std::is_same::value) { + s_frag[mma_q][mma_kv][reg_id] = std::bit_cast( + fp16_ieee_from_fp32_value(logitsTransformed)); + } + else if constexpr (!std::is_same::value) { + s_frag[mma_q][mma_kv][reg_id] = logitsTransformed; + } +#else + s_frag[mma_q][mma_kv][reg_id] = logitsTransformed; +#endif + } + } + } +} + +template +__device__ __forceinline__ void +logits_mask(const Params ¶ms, + typename KTraits::AttentionVariant variant, + const uint32_t batch_idx, + const uint32_t qo_packed_idx_base, + const uint32_t kv_idx_base, + const uint32_t qo_len, + const uint32_t kv_len, + const uint32_t chunk_end, + const uint_fastdiv group_size, + typename KTraits::DTypeQKAccum ( + *s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD], + const dim3 tid = threadIdx, + const uint32_t kv_head_idx = blockIdx.z) +{ + const uint32_t lane_idx = tid.x; + constexpr uint32_t NUM_MMA_Q = KTraits::NUM_MMA_Q; + constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; + constexpr MaskMode MASK_MODE = KTraits::MASK_MODE; + constexpr uint32_t TPR = KTraits::THREADS_PER_MATRIX_ROW_SET; + constexpr uint32_t NAPTR = KTraits::NUM_ACCUM_ROWS_PER_THREAD; + constexpr uint32_t LIS = KTraits::LOGITS_INDEX_STRIDE; + + uint32_t q[NUM_MMA_Q][NAPTR], r[NUM_MMA_Q][NAPTR]; +#pragma unroll + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t j = 0; j < NAPTR; ++j) { + group_size.divmod(qo_packed_idx_base + mma_q * 16 + lane_idx / TPR + + LIS * j, + q[mma_q][j], r[mma_q][j]); + } + } + +#pragma unroll + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) { +#pragma unroll + for (uint32_t reg_id = 0; reg_id < KTraits::HALF_ELEMS_PER_THREAD; + ++reg_id) + { + + const uint32_t q_idx = q[mma_q][(reg_id % NAPTR)], + kv_idx = kv_idx_base + mma_kv * 16 + + 2 * (lane_idx % TPR) + + 8 * (reg_id / 2) + reg_id % 2; + const uint32_t qo_head_idx 
= + kv_head_idx * group_size + r[mma_q][(reg_id % NAPTR)]; + const bool mask = + (!(MASK_MODE == MaskMode::kCausal + ? (kv_idx + qo_len > kv_len + q_idx || + (kv_idx >= chunk_end)) + : kv_idx >= chunk_end)) && + variant.LogitsMask(params, batch_idx, q_idx, kv_idx, + qo_head_idx, kv_head_idx); + s_frag[mma_q][mma_kv][reg_id] = + (mask) ? s_frag[mma_q][mma_kv][reg_id] + : (KTraits::MaskFillValue); + } + } + } +} + +template +__device__ __forceinline__ void update_mdo_states( + typename KTraits::AttentionVariant variant, + typename KTraits::DTypeQKAccum ( + *s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD], + float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD], + typename KTraits::DTypeQKAccum (*m)[KTraits::NUM_ACCUM_ROWS_PER_THREAD], + float (*d)[KTraits::NUM_ACCUM_ROWS_PER_THREAD]) +{ + using DTypeQKAccum = typename KTraits::DTypeQKAccum; + using AttentionVariant = typename KTraits::AttentionVariant; + constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD = + KTraits::NUM_ACCUM_ROWS_PER_THREAD; + constexpr bool use_softmax = AttentionVariant::use_softmax; + + if constexpr (use_softmax) { + const float sm_scale = variant.sm_scale_log2; + if constexpr (std::is_same_v) { +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t j = 0; j < NUM_ACCUM_ROWS_PER_THREAD; ++j) { + float m_prev = m[mma_q][j]; +#pragma unroll + for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; + ++mma_kv) + { + m[mma_q][j] = + max(m[mma_q][j], s_frag[mma_q][mma_kv][j]); + } + // Butterfly reduction across all threads in the band (16 + // threads) for CDNA3's 64-thread wavefront + m[mma_q][j] = + max(m[mma_q][j], gpu_iface::math::shfl_xor_sync( + m[mma_q][j], 0x8)); // 16 apart + m[mma_q][j] = + max(m[mma_q][j], gpu_iface::math::shfl_xor_sync( + m[mma_q][j], 0x4)); // 8 apart + m[mma_q][j] = + max(m[mma_q][j], gpu_iface::math::shfl_xor_sync( + m[mma_q][j], 0x2)); // 4 apart + m[mma_q][j] = + max(m[mma_q][j], 
gpu_iface::math::shfl_xor_sync( + m[mma_q][j], 0x1)); // 2 apart + + float o_scale = gpu_iface::math::ptx_exp2( + m_prev * sm_scale - m[mma_q][j] * sm_scale); + d[mma_q][j] *= o_scale; + + // Scale output fragments for this specific row +#pragma unroll + for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; + ++mma_d) + { + o_frag[mma_q][mma_d][j] *= o_scale; // Direct indexing + } + + // Convert logits to probabilities for this row +#pragma unroll + for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; + ++mma_kv) + { + s_frag[mma_q][mma_kv][j] = gpu_iface::math::ptx_exp2( + s_frag[mma_q][mma_kv][j] * sm_scale - + m[mma_q][j] * sm_scale); + } + } + } + } + } +} + +template +__device__ __forceinline__ void compute_sfm_v( + smem_t *v_smem, + uint32_t *v_smem_offset_r, + typename KTraits::DTypeQKAccum ( + *s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD], + float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD], + float (*d)[KTraits::NUM_ACCUM_ROWS_PER_THREAD]) +{ + constexpr uint32_t UPCAST_STRIDE_V = KTraits::UPCAST_STRIDE_V; + constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; + constexpr uint32_t INT32_ELEMS_PER_THREAD = KTraits::INT32_ELEMS_PER_THREAD; + + constexpr uint32_t V_SMEM_COLUMN_ADVANCE = + 16 / KTraits::HALF_ELEMS_PER_THREAD; + + typename KTraits::DTypeQ s_frag_f16[KTraits::NUM_MMA_Q][KTraits::NUM_MMA_KV] + [HALF_ELEMS_PER_THREAD]; + if constexpr (std::is_same_v) { +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { + vec_cast::template cast< + HALF_ELEMS_PER_THREAD>(s_frag_f16[mma_q][mma_kv], + s_frag[mma_q][mma_kv]); + } + } + } + + if constexpr (KTraits::AttentionVariant::use_softmax) { +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { + if constexpr 
(std::is_same_v) + { + mma::m16k16_rowsum_f16f16f32(d[mma_q], + s_frag_f16[mma_q][mma_kv]); + } + else { + static_assert( + !std::is_same_v, + "FP16 reduction path not implemented for CDNA3"); + } + } + } + } + +#pragma unroll + for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { +#pragma unroll + for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { + uint32_t b_frag[INT32_ELEMS_PER_THREAD]; + + v_smem->load_fragment_4x4_transposed(*v_smem_offset_r, b_frag); + +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { + if constexpr (std::is_same_v) + { + mma::mma_sync_m16n16k16_row_col_f16f16f32< + typename KTraits::DTypeQ>( + o_frag[mma_q][mma_d], + (uint32_t *)s_frag_f16[mma_q][mma_kv], b_frag); + } + else { + mma::mma_sync_m16n16k16_row_col_f16f16f32< + typename KTraits::DTypeQ>( + o_frag[mma_q][mma_d], (uint32_t *)s_frag[mma_q][mma_kv], + b_frag); + } + } + if constexpr (sizeof(typename KTraits::DTypeKV) == 1) { + if (mma_d % 2 == 1) { + *v_smem_offset_r = + v_smem->template advance_offset_by_column< + V_SMEM_COLUMN_ADVANCE>(*v_smem_offset_r, mma_d / 2); + } + } + else { + *v_smem_offset_r = v_smem->template advance_offset_by_column< + V_SMEM_COLUMN_ADVANCE>(*v_smem_offset_r, mma_d); + } + } + *v_smem_offset_r = + v_smem->template advance_offset_by_row<16, UPCAST_STRIDE_V>( + *v_smem_offset_r) - + sizeof(typename KTraits::DTypeKV) * KTraits::NUM_MMA_D_VO; + } + *v_smem_offset_r -= 16 * KTraits::NUM_MMA_KV * UPCAST_STRIDE_V; +} + +template +__device__ __forceinline__ void normalize_d( + float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD], + typename KTraits::DTypeQKAccum (*m)[KTraits::NUM_ACCUM_ROWS_PER_THREAD], + float (*d)[KTraits::NUM_ACCUM_ROWS_PER_THREAD]) +{ + using AttentionVariant = typename KTraits::AttentionVariant; + constexpr uint32_t NAPTR = KTraits::NUM_ACCUM_ROWS_PER_THREAD; + + if constexpr (AttentionVariant::use_softmax) { + float 
d_rcp[KTraits::NUM_MMA_Q][KTraits::NUM_ACCUM_ROWS_PER_THREAD]; + // compute reciprocal of d +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t j = 0; j < KTraits::NUM_ACCUM_ROWS_PER_THREAD; ++j) { + d_rcp[mma_q][j] = + (m[mma_q][j] != + typename KTraits::DTypeQKAccum(-gpu_iface::math::inf)) + ? gpu_iface::math::ptx_rcp(d[mma_q][j]) + : 0.f; + } + } + +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { +#pragma unroll + for (uint32_t reg_id = 0; + reg_id < KTraits::HALF_ELEMS_PER_THREAD; ++reg_id) + { + o_frag[mma_q][mma_d][reg_id] = + o_frag[mma_q][mma_d][reg_id] * + d_rcp[mma_q][reg_id % NAPTR]; + } + } + } + } +} + +template +__device__ __forceinline__ void finalize_m( + typename KTraits::AttentionVariant variant, + typename KTraits::DTypeQKAccum (*m)[KTraits::NUM_ACCUM_ROWS_PER_THREAD]) +{ + if constexpr (variant.use_softmax) { +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t j = 0; j < KTraits::NUM_ACCUM_ROWS_PER_THREAD; ++j) { + if (m[mma_q][j] != + typename KTraits::DTypeQKAccum(-gpu_iface::math::inf)) + { + m[mma_q][j] *= variant.sm_scale_log2; + } + } + } + } +} + +/*! + * \brief Synchronize the states of the MDO kernel across the threadblock along + * threadIdx.z. 
+ */ +template +__device__ __forceinline__ void threadblock_sync_mdo_states( + float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD], + typename KTraits::SharedStorage *smem_storage, + typename KTraits::DTypeQKAccum (*m)[KTraits::NUM_ACCUM_ROWS_PER_THREAD], + float (*d)[KTraits::NUM_ACCUM_ROWS_PER_THREAD], + const uint32_t warp_idx, + const uint32_t lane_idx, + const dim3 tid = threadIdx) +{ + constexpr uint32_t TPR = KTraits::THREADS_PER_MATRIX_ROW_SET; + constexpr uint32_t NARPT = KTraits::NUM_ACCUM_ROWS_PER_THREAD; + + static_assert(WARP_SIZE % TPR == 0, + "THREADS_PER_MATRIX_ROW_SET must divide WARP_SIZE"); + constexpr uint32_t GROUPS_PER_WARP = WARP_SIZE / TPR; + const uint32_t lane_group_idx = lane_idx / TPR; + + // only necessary when blockDim.z > 1 + if constexpr (KTraits::NUM_WARPS_KV > 1) { + float *smem_o = smem_storage->cta_sync_o_smem; + float2 *smem_md = smem_storage->cta_sync_md_smem; + // o: [num_warps, + // NUM_MMA_Q, + // NUM_MMA_D_VO, + // WARP_SIZE, + // HALF_ELEMS_PER_THREAD] + // md: [num_warps, NUM_MMA_Q, 16, 2 (m/d)] +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { + vec_t::memcpy( + smem_o + (((warp_idx * KTraits::NUM_MMA_Q + mma_q) * + KTraits::NUM_MMA_D_VO + + mma_d) * + WARP_SIZE + + lane_idx) * + KTraits::HALF_ELEMS_PER_THREAD, + o_frag[mma_q][mma_d]); + } + } + + if constexpr (KTraits::AttentionVariant::use_softmax) { +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t j = 0; j < NARPT; ++j) { + smem_md[((warp_idx * KTraits::NUM_MMA_Q + mma_q) * NARPT + + j) * + GROUPS_PER_WARP + + lane_group_idx] = + make_float2(float(m[mma_q][j]), d[mma_q][j]); + } + } + + // synchronize m,d first + __syncthreads(); +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { + float o_scale[NARPT][KTraits::NUM_WARPS_KV]; +#pragma 
unroll + for (uint32_t j = 0; j < NARPT; ++j) { + float m_new = -gpu_iface::math::inf, d_new = 1.f; +#pragma unroll + for (uint32_t i = 0; i < KTraits::NUM_WARPS_KV; ++i) { + float2 md = smem_md[(((i * KTraits::NUM_WARPS_Q + + get_warp_idx_q(tid.y)) * + KTraits::NUM_MMA_Q + + mma_q) * + NARPT + + j) * + GROUPS_PER_WARP + + lane_group_idx]; + float m_prev = m_new, d_prev = d_new; + m_new = max(m_new, md.x); + d_new = + d_prev * gpu_iface::math::ptx_exp2(m_prev - m_new) + + md.y * gpu_iface::math::ptx_exp2(md.x - m_new); + } + +#pragma unroll + for (uint32_t i = 0; i < KTraits::NUM_WARPS_KV; ++i) { + float2 md = smem_md[(((i * KTraits::NUM_WARPS_Q + + get_warp_idx_q(tid.y)) * + KTraits::NUM_MMA_Q + + mma_q) * + NARPT + + j) * + GROUPS_PER_WARP + + lane_group_idx]; + float mi = md.x; + o_scale[j][i] = + gpu_iface::math::ptx_exp2(float(mi - m_new)); + } + m[mma_q][j] = typename KTraits::DTypeQKAccum(m_new); + d[mma_q][j] = d_new; + } + +#pragma unroll + for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) + { + vec_t o_new; + o_new.fill(0.f); +#pragma unroll + for (uint32_t i = 0; i < KTraits::NUM_WARPS_KV; ++i) { + vec_t oi; + oi.load(smem_o + ((((i * KTraits::NUM_WARPS_Q + + get_warp_idx_q(tid.y)) * + KTraits::NUM_MMA_Q + + mma_q) * + KTraits::NUM_MMA_D_VO + + mma_d) * + WARP_SIZE + + lane_idx) * + KTraits::HALF_ELEMS_PER_THREAD); + +#pragma unroll + for (uint32_t reg_id = 0; + reg_id < KTraits::HALF_ELEMS_PER_THREAD; ++reg_id) + { + // CDNA3: Direct mapping - each reg_id corresponds + // to one accumulator row + o_new[reg_id] += oi[reg_id] * o_scale[reg_id][i]; + } + } + o_new.store(o_frag[mma_q][mma_d]); + } + } + } + else { + // synchronize m,d first + __syncthreads(); +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) + { + vec_t o_new; + o_new.fill(0.f); +#pragma unroll + for (uint32_t i = 0; i < KTraits::NUM_WARPS_KV; ++i) { + vec_t oi; 
+ oi.load(smem_o + ((((i * KTraits::NUM_WARPS_Q + + get_warp_idx_q(tid.y)) * + KTraits::NUM_MMA_Q + + mma_q) * + KTraits::NUM_MMA_D_VO + + mma_d) * + WARP_SIZE + + lane_idx) * + KTraits::HALF_ELEMS_PER_THREAD); +#pragma unroll + for (uint32_t reg_id = 0; + reg_id < KTraits::HALF_ELEMS_PER_THREAD; ++reg_id) + { + o_new[reg_id] += oi[reg_id]; + } + } + o_new.store(o_frag[mma_q][mma_d]); + } + } + } + } +} + +template +__device__ __forceinline__ void write_o_reg_gmem( + float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD], + smem_t *o_smem, + typename KTraits::DTypeO *o_ptr_base, + const uint32_t o_packed_idx_base, + const uint32_t qo_upper_bound, + const uint32_t o_stride_n, + const uint32_t o_stride_h, + const uint_fastdiv group_size, + const dim3 tid = threadIdx) +{ + using DTypeO = typename KTraits::DTypeO; + constexpr uint32_t UPCAST_STRIDE_O = KTraits::UPCAST_STRIDE_O; + constexpr uint32_t TPR = KTraits::THREADS_PER_MATRIX_ROW_SET; + constexpr uint32_t NAPTR = KTraits::NUM_ACCUM_ROWS_PER_THREAD; + constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; + constexpr uint32_t WARP_THREAD_COLS = KTraits::WARP_THREAD_COLS; + constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH; + + const uint32_t warp_idx_x = get_warp_idx_q(tid.y); + const uint32_t lane_idx = tid.x; + + if constexpr (sizeof(DTypeO) == 4) { +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t j = 0; j < NAPTR; ++j) { + uint32_t q, r; + group_size.divmod(o_packed_idx_base + lane_idx / TPR + + mma_q * 16 + j * 8, + q, r); + const uint32_t o_idx = q; +#pragma unroll + for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) + { + if (o_idx < qo_upper_bound) { + auto base_addr = o_ptr_base + q * o_stride_n + + r * o_stride_h + mma_d * 16; + auto col_offset = lane_idx % 16; + *(base_addr + col_offset) = o_frag[mma_q][mma_d][j]; + } + } + } + } + } + else { + if (get_warp_idx_kv(tid.z) == 0) 
{ +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) + { + uint32_t o_frag_f16[HALF_ELEMS_PER_THREAD / 2]; + vec_cast::template cast< + HALF_ELEMS_PER_THREAD>((DTypeO *)o_frag_f16, + o_frag[mma_q][mma_d]); + +#ifdef FLASHINFER_STMATRIX_M8N8X4_ENABLED + uint32_t o_smem_offset_w = + o_smem->template get_permuted_offset( + (warp_idx_x * KTraits::NUM_MMA_Q + mma_q) * 16 + + lane_idx % 16, + mma_d * 2 + lane_idx / 16); + o_smem->stmatrix_m8n8x4(o_smem_offset_w, o_frag_f16); +#else + uint32_t o_smem_offset_w = + o_smem->template get_permuted_offset( + (warp_idx_x * KTraits::NUM_MMA_Q + mma_q) * 16 + + lane_idx / TPR, + mma_d * 2); + ((uint32_t *)(o_smem->base + + o_smem_offset_w))[lane_idx % TPR] = + o_frag_f16[0]; + // Move 2 elements forward in the same row + uint32_t offset_2 = o_smem_offset_w + 2; + ((uint32_t *)(o_smem->base + offset_2))[lane_idx % 16] = + o_frag_f16[1]; + +#endif + } + } + + uint32_t o_smem_offset_w = + o_smem->template get_permuted_offset( + warp_idx_x * KTraits::NUM_MMA_Q * 16 + + lane_idx / WARP_THREAD_COLS, + lane_idx % WARP_THREAD_COLS); + +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t j = 0; j < 2 * 2; ++j) { + uint32_t q, r; + group_size.divmod(o_packed_idx_base + + lane_idx / WARP_THREAD_COLS + + mma_q * 16 + j * 4, + q, r); + const uint32_t o_idx = q; + DTypeO *o_ptr = o_ptr_base + q * o_stride_n + + r * o_stride_h + + (lane_idx % WARP_THREAD_COLS) * + upcast_size(); +#pragma unroll + for (uint32_t mma_do = 0; + mma_do < KTraits::NUM_MMA_D_VO / 4; ++mma_do) + { + if (o_idx < qo_upper_bound) { + o_smem->store_vector(o_smem_offset_w, o_ptr); + } + o_ptr += WARP_THREAD_COLS * + upcast_size(); + o_smem_offset_w = + o_smem->template advance_offset_by_column< + WARP_THREAD_COLS>(o_smem_offset_w, mma_do); + } + o_smem_offset_w = o_smem->template advance_offset_by_row< 
+ 4, UPCAST_STRIDE_O>(o_smem_offset_w) - + 2 * KTraits::NUM_MMA_D_VO; + } + } + } + } +} + +} // namespace + +/*! + * \brief FlashAttention prefill CUDA kernel for a single request. + * \tparam partition_kv Whether to split kv_len into chunks. + * \tparam mask_mode The mask mode used in the attention operation. + * \tparam POS_ENCODING_MODE The positional encoding mode. + * \tparam NUM_MMA_Q The number of fragments in x dimension. + * \tparam NUM_MMA_D_VO The number of fragments in y dimension. + * \tparam NUM_MMA_KV The number of fragments in z dimension. + * \tparam num_warps The number of warps in the threadblock. + * \tparam DTypeQ The data type of the query tensor. + * \tparam DTypeKV The data type of the key/value tensor. + * \tparam DTypeO The data type of the output tensor. + * \param q The query tensor. + * \param k The key tensor. + * \param v The value tensor. + * \param o The output tensor. + * \param tmp The temporary buffer (used when partition_kv is true). + * \param lse The logsumexp value. + * \param rope_rcp_scale 1/(rope_scale), where rope_scale is the scaling + * factor used in RoPE interpolation. + * \param rope_rcp_theta 1/(rope_theta), where rope_theta is the theta + * used in RoPE. 
+ */ +template +__device__ __forceinline__ void +SinglePrefillWithKVCacheDevice(const Params params, + typename KTraits::SharedStorage &smem_storage, + const dim3 tid = threadIdx, + const uint32_t bx = blockIdx.x, + const uint32_t chunk_idx = blockIdx.y, + const uint32_t kv_head_idx = blockIdx.z, + const uint32_t num_chunks = gridDim.y, + const uint32_t num_kv_heads = gridDim.z) +{ + using DTypeQ = typename Params::DTypeQ; + using DTypeKV = typename Params::DTypeKV; + using DTypeO = typename Params::DTypeO; + using DTypeQKAccum = typename KTraits::DTypeQKAccum; + using AttentionVariant = typename KTraits::AttentionVariant; + [[maybe_unused]] constexpr uint32_t NUM_MMA_Q = KTraits::NUM_MMA_Q; + [[maybe_unused]] constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; + [[maybe_unused]] constexpr uint32_t NUM_MMA_D_QK = KTraits::NUM_MMA_D_QK; + [[maybe_unused]] constexpr uint32_t NUM_MMA_D_VO = KTraits::NUM_MMA_D_VO; + [[maybe_unused]] constexpr uint32_t HEAD_DIM_QK = KTraits::HEAD_DIM_QK; + [[maybe_unused]] constexpr uint32_t HEAD_DIM_VO = KTraits::HEAD_DIM_VO; + [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_Q = + KTraits::UPCAST_STRIDE_Q; + [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_K = + KTraits::UPCAST_STRIDE_K; + [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_V = + KTraits::UPCAST_STRIDE_V; + [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_O = + KTraits::UPCAST_STRIDE_O; + [[maybe_unused]] constexpr uint32_t CTA_TILE_Q = KTraits::CTA_TILE_Q; + [[maybe_unused]] constexpr uint32_t CTA_TILE_KV = KTraits::CTA_TILE_KV; + [[maybe_unused]] constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q; + [[maybe_unused]] constexpr uint32_t NUM_WARPS_KV = KTraits::NUM_WARPS_KV; + [[maybe_unused]] constexpr SwizzleMode SWIZZLE_MODE_Q = + KTraits::SWIZZLE_MODE_Q; + [[maybe_unused]] constexpr SwizzleMode SWIZZLE_MODE_KV = + KTraits::SWIZZLE_MODE_KV; + [[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_ROW = + KTraits::KV_THR_LAYOUT_ROW; + [[maybe_unused]] constexpr 
uint32_t KV_THR_LAYOUT_COL = + KTraits::KV_THR_LAYOUT_COL; + [[maybe_unused]] constexpr MaskMode MASK_MODE = KTraits::MASK_MODE; + [[maybe_unused]] constexpr uint32_t HALF_ELEMS_PER_THREAD = + KTraits::HALF_ELEMS_PER_THREAD; + [[maybe_unused]] constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD = + KTraits::NUM_ACCUM_ROWS_PER_THREAD; + [[maybe_unused]] constexpr uint32_t LOGITS_INDEX_STRIDE = + KTraits::LOGITS_INDEX_STRIDE; + [[maybe_unused]] constexpr uint32_t THREADS_PER_MATRIX_ROW_SET = + KTraits::THREADS_PER_MATRIX_ROW_SET; + [[maybe_unused]] constexpr uint32_t VECTOR_BIT_WIDTH = + KTraits::VECTOR_BIT_WIDTH; + + DTypeQ *q = params.q; + DTypeKV *k = params.k; + DTypeKV *v = params.v; + DTypeO *o = params.o; + float *lse = params.lse; + const uint32_t qo_len = params.qo_len; + const uint32_t kv_len = params.kv_len; + const bool partition_kv = params.partition_kv; + const uint32_t q_stride_n = params.q_stride_n; + const uint32_t q_stride_h = params.q_stride_h; + const uint32_t k_stride_n = params.k_stride_n; + const uint32_t k_stride_h = params.k_stride_h; + const uint32_t v_stride_n = params.v_stride_n; + const uint32_t v_stride_h = params.v_stride_h; + const uint_fastdiv &group_size = params.group_size; + + static_assert(sizeof(DTypeQ) == 2); + const uint32_t lane_idx = tid.x, + warp_idx = get_warp_idx(tid.y, tid.z); + const uint32_t num_qo_heads = num_kv_heads * group_size; + + const uint32_t max_chunk_size = + partition_kv ? ceil_div(kv_len, num_chunks) : kv_len; + const uint32_t chunk_start = partition_kv ? chunk_idx * max_chunk_size : 0; + const uint32_t chunk_end = + partition_kv ? 
min((chunk_idx + 1) * max_chunk_size, kv_len) : kv_len; + const uint32_t chunk_size = chunk_end - chunk_start; + + auto block = cg::this_thread_block(); + auto smem = reinterpret_cast(&smem_storage); + AttentionVariant variant(params, /*batch_idx=*/0, smem); + const uint32_t window_left = variant.window_left; + + DTypeQKAccum s_frag[NUM_MMA_Q][NUM_MMA_KV][HALF_ELEMS_PER_THREAD]; + alignas(16) float o_frag[NUM_MMA_Q][NUM_MMA_D_VO][HALF_ELEMS_PER_THREAD]; + DTypeQKAccum m[NUM_MMA_Q][NUM_ACCUM_ROWS_PER_THREAD]; + float d[NUM_MMA_Q][NUM_ACCUM_ROWS_PER_THREAD]; + float rope_freq[NUM_MMA_D_QK / 2][4]; + + init_states(variant, o_frag, m, d); + + // cooperative fetch q fragment from gmem to reg + const uint32_t qo_packed_idx_base = + (bx * NUM_WARPS_Q + get_warp_idx_q(tid.y)) * NUM_MMA_Q * 16; + smem_t qo_smem( + smem_storage.q_smem); + const uint32_t o_stride_n = num_qo_heads * HEAD_DIM_VO, + o_stride_h = HEAD_DIM_VO; + DTypeQ *q_ptr_base = q + (kv_head_idx * group_size) * q_stride_h; + DTypeO *o_ptr_base = partition_kv + ? o + chunk_idx * o_stride_n + + (kv_head_idx * group_size) * o_stride_h + : o + (kv_head_idx * group_size) * o_stride_h; + + uint32_t q_smem_offset_r = + qo_smem.template get_permuted_offset( + get_warp_idx_q(tid.y) * NUM_MMA_Q * 16 + lane_idx % 16, + lane_idx / 16); + + load_q_global_smem(qo_packed_idx_base, qo_len, q_ptr_base, + q_stride_n, q_stride_h, group_size, &qo_smem, + tid); + + memory::commit_group(); + + smem_t k_smem( + smem_storage.k_smem); + smem_t v_smem( + smem_storage.v_smem); + + const uint32_t num_iterations = ceil_div( + MASK_MODE == MaskMode::kCausal + ? 
min(chunk_size, + sub_if_greater_or_zero( + kv_len - qo_len + ((bx + 1) * CTA_TILE_Q) / group_size, + chunk_start)) + : chunk_size, + CTA_TILE_KV); + + const uint32_t window_iteration = ceil_div( + sub_if_greater_or_zero(kv_len + (bx + 1) * CTA_TILE_Q / group_size, + qo_len + window_left + chunk_start), + CTA_TILE_KV); + + const uint32_t mask_iteration = + (MASK_MODE == MaskMode::kCausal + ? min(chunk_size, + sub_if_greater_or_zero( + kv_len + (bx * CTA_TILE_Q) / group_size - qo_len, + chunk_start)) + : chunk_size) / + CTA_TILE_KV; + + DTypeKV *k_ptr = k + + (chunk_start + warp_idx * KV_THR_LAYOUT_ROW + + lane_idx / KV_THR_LAYOUT_COL) * + k_stride_n + + kv_head_idx * k_stride_h + + (lane_idx % KV_THR_LAYOUT_COL) * + upcast_size(); + DTypeKV *v_ptr = v + + (chunk_start + warp_idx * KV_THR_LAYOUT_ROW + + lane_idx / KV_THR_LAYOUT_COL) * + v_stride_n + + kv_head_idx * v_stride_h + + (lane_idx % KV_THR_LAYOUT_COL) * + upcast_size(); + uint32_t k_smem_offset_r = + k_smem.template get_permuted_offset( + get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + lane_idx % 16, + (lane_idx / 16)); + + uint32_t + v_smem_offset_r = v_smem.template get_permuted_offset( + get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + lane_idx % 16, + lane_idx / 16), + k_smem_offset_w = k_smem.template get_permuted_offset( + warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL, + lane_idx % KV_THR_LAYOUT_COL), + v_smem_offset_w = v_smem.template get_permuted_offset( + warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL, + lane_idx % KV_THR_LAYOUT_COL); + produce_kv( + k_smem, &k_smem_offset_w, &k_ptr, k_stride_n, 0, chunk_size, tid); + memory::commit_group(); + produce_kv( + v_smem, &v_smem_offset_w, &v_ptr, v_stride_n, 0, chunk_size, tid); + memory::commit_group(); + +#if Debug + int global_idx = (blockIdx.z * gridDim.y * gridDim.x + + blockIdx.y * gridDim.x + blockIdx.x) * + (blockDim.z * blockDim.y * blockDim.x) + + (threadIdx.z * blockDim.y * blockDim.x + + threadIdx.y * blockDim.x + 
threadIdx.x); + + if (global_idx == 0) { + printf("partition_kv : %d\n", partition_kv); + printf("kv_len : %d\n", kv_len); + printf("max_chunk_size : %d\n", max_chunk_size); + printf("chunk_end : %d\n", chunk_end); + printf("chunk_start : %d\n", chunk_start); + } + // Test Q + // if (global_idx == 0) { + // uint32_t q_smem_offset_r_debug; + // //for (auto i = 0; i < 4; ++i) { + // for (auto j = 0; j < 16; ++j) { + // uint32_t q_smem_offset_r_debug = + // qo_smem.template + // get_permuted_offset( + // get_warp_idx_q(tid.y) * NUM_MMA_Q * 16 + // + (j) % 16, (j) / 16); + // uint32_t a_frag[KTraits::INT32_ELEMS_PER_THREAD]; + // k_smem.load_fragment(q_smem_offset_r_debug, a_frag); + // auto frag_T = reinterpret_cast<__half *>(a_frag); + // for (auto i = 0ul; i < 4; ++i) { + // printf("%f ", (float)(*(frag_T + i))); + // } + // printf("\n"); + // } + // // q_smem_offset_r_debug = qo_smem.template + // advance_offset_by_column<4>( + // // q_smem_offset_r_debug, 0); + // // } + // } + + // for (auto mma_q = 0ul; mma_q < 4; ++mma_q) { + // uint32_t a_frag[KTraits::INT32_ELEMS_PER_THREAD]; + // qo_smem.load_fragment(q_smem_offset_r, a_frag); + // if (global_idx == 0) { + // auto frag_T = reinterpret_cast<__half *>(a_frag); + // printf("DEBUG: Q Frag in permuted_smem for mma_q %lu \n", + // mma_q); for (auto i = 0ul; i < 4; ++i) { + // printf("%f ", (float)(*(frag_T + i))); + // } + // printf("\n"); + // } + + // q_smem_offset_r = qo_smem.template advance_offset_by_column<4>( + // q_smem_offset_r, 0); + // } + + uint32_t a_frag[KTraits::INT32_ELEMS_PER_THREAD]; + qo_smem.load_fragment(q_smem_offset_r, a_frag); + if (global_idx == 0) { + auto frag_T = reinterpret_cast<__half *>(a_frag); + printf("DEBUG: Q Frag \n"); + for (auto i = 0ul; i < 4; ++i) { + printf("%f ", (float)(*(frag_T + i))); + } + printf("\n"); + } + + memory::wait_group<0>(); + block.sync(); + q_smem_inplace_apply_rotary(qo_packed_idx_base, qo_len, kv_len, + group_size, &qo_smem, &q_smem_offset_r, + 
rope_freq, tid); + block.sync(); + + qo_smem.load_fragment(q_smem_offset_r, a_frag); + if (global_idx == 0) { + auto frag_T = reinterpret_cast<__half *>(a_frag); + printf("DEBUG: LLAMA Rope transformed Q Frag \n"); + for (auto i = 0ul; i < 4; ++i) { + printf("%f ", (float)(*(frag_T + i))); + } + printf("\n"); + } + + // // Test K loads + // if (global_idx == 0) { + + // for (auto j = 0; j < 64; ++j) { + // uint32_t k_smem_offset_r_test = + // k_smem.template get_permuted_offset( + // get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + + // j % 16, + // (j / 16)); + // uint32_t b_frag[KTraits::INT32_ELEMS_PER_THREAD]; + // k_smem.load_fragment(k_smem_offset_r_test, b_frag); + // auto frag_T = reinterpret_cast<__half *>(b_frag); + // // printf("DEBUG: K Frag in permuted_smem for mma_kv %lu \n", + // // mma_kv); + // for (auto i = 0ul; i < 4; ++i) { + // printf("%f ", (float)(*(frag_T + i))); + // } + // printf("\n"); + // } + // } + + // if (global_idx == 0) { + // printf("DEBUG Q ORIGINAL (HIP):\n"); + + // for (uint32_t seq_idx = 0; seq_idx < 16; ++seq_idx) { + // printf("Q[%u] original: ", seq_idx); + + // // Load all feature groups for this sequence + // for (uint32_t feat_group = 0; feat_group < NUM_MMA_D_QK; + // ++feat_group) { + // uint32_t feat_offset = qo_smem.template + // get_permuted_offset( + // seq_idx, feat_group * HALF_ELEMS_PER_THREAD); + + // uint32_t q_frag[KTraits::INT32_ELEMS_PER_THREAD]; + // qo_smem.load_fragment(feat_offset, q_frag); + // auto frag_T = reinterpret_cast<__half *>(q_frag); + + // // Print 4 features from this group + // for (auto feat = 0ul; feat < HALF_ELEMS_PER_THREAD; + // ++feat) { + // printf("%f ", (float)(*(frag_T + feat))); + // } + // } + // printf("\n"); + // } + // } + + // memory::wait_group<0>(); + // block.sync(); + // q_smem_inplace_apply_rotary( + // qo_packed_idx_base, qo_len, kv_len, group_size, &qo_smem, + // &q_smem_offset_r, rope_freq, tid); + // block.sync(); + + // // Debug: Print Q fragments after RoPE + // if 
(global_idx == 0) { + // printf("DEBUG Q LLAMA ROPE (HIP):\n"); + + // // Reset q_smem_offset_r to start + // uint32_t q_smem_offset_r_debug = + // qo_smem.template get_permuted_offset( + // get_warp_idx_q(tid.y) * NUM_MMA_Q * 16 + + // lane_idx % 16, lane_idx / 16); + + // for (uint32_t seq_idx = 0; seq_idx < 16; ++seq_idx) { + // // Calculate offset for this sequence + // uint32_t seq_offset = qo_smem.template + // get_permuted_offset( + // seq_idx, 0); + + // printf("Q[%u] after RoPE: ", seq_idx); + + // // Load all feature groups for this sequence + // for (uint32_t feat_group = 0; feat_group < NUM_MMA_D_QK; + // ++feat_group) { + // uint32_t feat_offset = qo_smem.template + // get_permuted_offset( + // seq_idx, feat_group * HALF_ELEMS_PER_THREAD); + + // uint32_t q_frag[KTraits::INT32_ELEMS_PER_THREAD]; + // qo_smem.load_fragment(feat_offset, q_frag); + // auto frag_T = reinterpret_cast<__half *>(q_frag); + + // // Print 4 features from this group + // for (auto feat = 0ul; feat < HALF_ELEMS_PER_THREAD; + // ++feat) { + // printf("%f ", (float)(*(frag_T + feat))); + // } + // } + // printf("\n"); + // } + // } +#endif + +#pragma unroll 1 + for (uint32_t iter = 0; iter < num_iterations; ++iter) { + memory::wait_group<1>(); + block.sync(); + + // compute attention score + compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, + &k_smem_offset_r, s_frag); + + logits_transform( + params, variant, /*batch_idx=*/0, qo_packed_idx_base, + chunk_start + + (iter * NUM_WARPS_KV + get_warp_idx_kv(tid.z)) * + NUM_MMA_KV * 16, + qo_len, kv_len, group_size, s_frag, tid, kv_head_idx); + + // // apply mask + // if (MASK_MODE == MaskMode::kCustom || + // (iter >= mask_iteration || iter < window_iteration)) + // { + // logits_mask( + // params, variant, /*batch_idx=*/0, qo_packed_idx_base, + // chunk_start + (iter * NUM_WARPS_KV + + // get_warp_idx_kv(tid.z)) * + // NUM_MMA_KV * 16, + // qo_len, kv_len, chunk_end, group_size, s_frag, tid, + // kv_head_idx); + // } + + // compute m,d 
states in online softmax + update_mdo_states(variant, s_frag, o_frag, m, d); + + block.sync(); + + produce_kv( + k_smem, &k_smem_offset_w, &k_ptr, k_stride_n, + (iter + 1) * CTA_TILE_KV, chunk_size, tid); + memory::commit_group(); + memory::wait_group<1>(); + block.sync(); + + // compute sfm*v + compute_sfm_v(&v_smem, &v_smem_offset_r, s_frag, o_frag, d); + + block.sync(); + produce_kv( + v_smem, &v_smem_offset_w, &v_ptr, v_stride_n, + (iter + 1) * CTA_TILE_KV, chunk_size, tid); + memory::commit_group(); + } + memory::wait_group<0>(); + block.sync(); + + finalize_m(variant, m); + + // threadblock synchronization + threadblock_sync_mdo_states(o_frag, &smem_storage, m, d, warp_idx, + lane_idx, tid); + + // normalize d + normalize_d(o_frag, m, d); + + // write back + write_o_reg_gmem( + o_frag, &qo_smem, o_ptr_base, qo_packed_idx_base, qo_len, + /*o_stride_n=*/ + partition_kv ? num_chunks * o_stride_n : o_stride_n, + /*o_stride_h=*/o_stride_h, group_size, tid); + + // write lse + if constexpr (variant.use_softmax) { + if (lse != nullptr || partition_kv) { + if (get_warp_idx_kv(tid.z) == 0) { +#pragma unroll + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t j = 0; j < NUM_ACCUM_ROWS_PER_THREAD; ++j) { + uint32_t q, r; + group_size.divmod( + qo_packed_idx_base + + lane_idx / THREADS_PER_MATRIX_ROW_SET + + j * LOGITS_INDEX_STRIDE + mma_q * 16, + q, r); + const uint32_t qo_head_idx = + kv_head_idx * group_size + r; + const uint32_t qo_idx = q; + if (qo_idx < qo_len) { + if (partition_kv) { + lse[(qo_idx * num_chunks + chunk_idx) * + num_qo_heads + + qo_head_idx] = + gpu_iface::math::ptx_log2(d[mma_q][j]) + + float(m[mma_q][j]); + } + else { + lse[qo_idx * num_qo_heads + qo_head_idx] = + gpu_iface::math::ptx_log2(d[mma_q][j]) + + float(m[mma_q][j]); + } + } + } + } + } + } + } +} + +template +__global__ +__launch_bounds__(KTraits::NUM_THREADS) void SinglePrefillWithKVCacheKernel( + const __grid_constant__ Params params) +{ + extern 
__shared__ uint8_t smem[]; + auto &smem_storage = + reinterpret_cast(smem); + SinglePrefillWithKVCacheDevice(params, smem_storage); +} + +template +gpuError_t SinglePrefillWithKVCacheDispatched(Params params, + typename Params::DTypeO *tmp, + gpuStream_t stream) +{ + using DTypeQ = typename Params::DTypeQ; + using DTypeKV = typename Params::DTypeKV; + using DTypeO = typename Params::DTypeO; + const uint32_t num_qo_heads = params.num_qo_heads; + const uint32_t num_kv_heads = params.num_kv_heads; + const uint32_t qo_len = params.qo_len; + const uint32_t kv_len = params.kv_len; + if (kv_len < qo_len && MASK_MODE == MaskMode::kCausal) { + std::ostringstream err_msg; + err_msg << "When mask_mode is set to MaskMode::kCausal, kv_len must be " + "greater than or equal to qo_len, got kv_len" + << kv_len << " and qo_len " << qo_len; + FLASHINFER_ERROR(err_msg.str()); + } + + const uint32_t group_size = num_qo_heads / num_kv_heads; + constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM_QK / 16; + constexpr uint32_t NUM_MMA_D_VO = HEAD_DIM_VO / 16; + int64_t packed_qo_len = qo_len * group_size; + uint32_t cta_tile_q = FA2DetermineCtaTileQ(packed_qo_len, HEAD_DIM_VO); + + DISPATCH_CTA_TILE_Q(cta_tile_q, CTA_TILE_Q, { + constexpr uint32_t NUM_WARPS_Q = get_num_warps_q(CTA_TILE_Q); + constexpr uint32_t NUM_WARPS_KV = get_num_warps_kv(CTA_TILE_Q); + constexpr uint32_t NUM_MMA_Q = get_num_mma_q(CTA_TILE_Q); + + using DTypeQKAccum = + typename std::conditional, + half, float>::type; + + int dev_id = 0; + FI_GPU_CALL(gpuGetDevice(&dev_id)); + int max_smem_per_sm = getMaxSharedMemPerMultiprocessor(dev_id); + // we expect each sm execute two threadblocks + const int num_ctas_per_sm = + max_smem_per_sm >= 2 * (CTA_TILE_Q * HEAD_DIM_QK * sizeof(DTypeQ) + + (HEAD_DIM_QK + HEAD_DIM_VO) * 16 * + NUM_WARPS_KV * sizeof(DTypeKV)) + ? 
2 + : 1; + const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm; + + const uint32_t max_num_mma_kv_reg = + (HEAD_DIM_VO >= 128 && NUM_MMA_Q == 2 && + POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && + !USE_FP16_QK_REDUCTION) + ? 2 + : (8 / NUM_MMA_Q); + const uint32_t max_num_mma_kv_smem = + (max_smem_per_threadblock - + CTA_TILE_Q * HEAD_DIM_QK * sizeof(DTypeQ)) / + ((HEAD_DIM_QK + HEAD_DIM_VO) * 16 * NUM_WARPS_KV * sizeof(DTypeKV)); + + // control NUM_MMA_KV for maximum warp occupancy + DISPATCH_NUM_MMA_KV( + min(max_num_mma_kv_smem, max_num_mma_kv_reg), NUM_MMA_KV, { + using KTraits = + KernelTraits; + if constexpr (KTraits::IsInvalid()) { + // Invalid configuration, skip + std::ostringstream err_msg; + err_msg << "FlashInfer Internal Error: Invalid " + "configuration : NUM_MMA_Q=" + << NUM_MMA_Q << " NUM_MMA_D_QK=" << NUM_MMA_D_QK + << " NUM_MMA_D_VO=" << NUM_MMA_D_VO + << " NUM_MMA_KV=" << NUM_MMA_KV + << " NUM_WARPS_Q=" << NUM_WARPS_Q + << " NUM_WARPS_KV=" << NUM_WARPS_KV + << " please create an issue " + "(https://github.com/flashinfer-ai/flashinfer/" + "issues)" + " and report the issue to the developers."; + FLASHINFER_ERROR(err_msg.str()); + } + else { + constexpr uint32_t num_threads = + (NUM_WARPS_Q * NUM_WARPS_KV) * WARP_SIZE; + auto kernel = + SinglePrefillWithKVCacheKernel; + size_t smem_size = sizeof(typename KTraits::SharedStorage); + FI_GPU_CALL(gpuFuncSetAttribute( + kernel, gpuFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); + int num_blocks_per_sm = 0; + int num_sm = 0; + FI_GPU_CALL(gpuDeviceGetAttribute( + &num_sm, gpuDevAttrMultiProcessorCount, dev_id)); + FI_GPU_CALL(gpuOccupancyMaxActiveBlocksPerMultiprocessor( + &num_blocks_per_sm, kernel, num_threads, smem_size)); + uint32_t max_num_kv_chunks = + (num_blocks_per_sm * num_sm) / + (num_kv_heads * + ceil_div(qo_len * group_size, CTA_TILE_Q)); + uint32_t num_chunks; + if (max_num_kv_chunks > 0) { + uint32_t chunk_size = + max(ceil_div(kv_len, max_num_kv_chunks), 
256);
+                        num_chunks = ceil_div(kv_len, chunk_size);
+                    }
+                    else {
+                        num_chunks = 0;
+                    }
+
+                    if (num_chunks <= 1 || tmp == nullptr) {
+                        // Enough parallelism, do not split-kv
+                        params.partition_kv = false;
+                        void *args[] = {(void *)&params};
+                        dim3 nblks(ceil_div(qo_len * group_size, CTA_TILE_Q), 1,
+                                   num_kv_heads);
+                        dim3 nthrs(WARP_SIZE, NUM_WARPS_Q, NUM_WARPS_KV);
+                        FI_GPU_CALL(gpuLaunchKernel((void *)kernel, nblks,
+                                                    nthrs, args, smem_size,
+                                                    stream));
+                    }
+                    else {
+                        // Use cooperative groups to increase occupancy
+                        params.partition_kv = true;
+                        float *tmp_lse =
+                            (float *)(tmp + num_chunks * qo_len * num_qo_heads *
+                                                HEAD_DIM_VO);
+                        auto o = params.o;
+                        auto lse = params.lse;
+                        params.o = tmp;
+                        params.lse = tmp_lse;
+                        void *args[] = {(void *)&params};
+                        dim3 nblks(ceil_div(qo_len * group_size, CTA_TILE_Q),
+                                   num_chunks, num_kv_heads);
+                        dim3 nthrs(WARP_SIZE, NUM_WARPS_Q, NUM_WARPS_KV);
+                        FI_GPU_CALL(gpuLaunchKernel((void *)kernel, nblks,
+                                                    nthrs, args, smem_size,
+                                                    stream));
+                        if constexpr (AttentionVariant::use_softmax) {
+                            FI_GPU_CALL(MergeStates(
+                                tmp, tmp_lse, o, lse, num_chunks, qo_len,
+                                num_qo_heads, HEAD_DIM_VO, stream));
+                        }
+                        else {
+                            FI_GPU_CALL(AttentionSum(tmp, o, num_chunks, qo_len,
+                                                     num_qo_heads, HEAD_DIM_VO,
+                                                     stream));
+                        }
+                    }
+                }
+            })
+    });
+    return gpuSuccess;
+}
+
+} // namespace flashinfer
+
+#endif // FLASHINFER_PREFILL_CUH_
diff --git a/libflashinfer/tests/hip/test_compute_sfm.cpp b/libflashinfer/tests/hip/test_compute_sfm.cpp
new file mode 100644
index 0000000000..e0cda09e50
--- /dev/null
+++ b/libflashinfer/tests/hip/test_compute_sfm.cpp
@@ -0,0 +1,234 @@
+// SPDX-FileCopyrightText: 2023-2025 Flashinfer team
+// SPDX-FileCopyrightText: 2025 Advanced Micro Devices, Inc.
+// +// SPDX - License - Identifier : Apache 2.0 + +#include "../../utils/flashinfer_prefill_ops.hip.h" +#include "../../utils/utils_hip.h" +#include "flashinfer/attention/generic/prefill.cuh" +#include "gpu_iface/gpu_runtime_compat.hpp" + +#include + +#include + +#define HIP_ENABLE_WARP_SYNC_BUILTINS 1 + +using namespace flashinfer; + +namespace +{ +template +std::vector test_compute_qk_and_softmax_cpu( + const std::vector &q, + const std::vector &k, + const std::vector &v, + size_t qo_len, + size_t kv_len, + size_t num_qo_heads, + size_t num_kv_heads, + size_t head_dim, + bool causal = true, + QKVLayout kv_layout = QKVLayout::kHND, + PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, + float rope_scale = 1.f, + float rope_theta = 1e4) +{ + assert(qo_len <= kv_len); + assert(num_qo_heads % num_kv_heads == 0); + float sm_scale = 1.f / std::sqrt(float(head_dim)); + std::vector o(qo_len * num_qo_heads * head_dim); + std::vector att(kv_len); + std::vector q_rotary_local(head_dim); + std::vector k_rotary_local(head_dim); + DISPATCH_head_dim(head_dim, HEAD_DIM, { + tensor_info_t info(qo_len, kv_len, num_qo_heads, num_kv_heads, + kv_layout, HEAD_DIM); + + for (size_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) + { + const size_t kv_head_idx = qo_head_idx / info.get_group_size(); + for (size_t q_idx = 0; q_idx < qo_len; ++q_idx) { + float max_val = -5e4; + + for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { + att[kv_idx] = 0.; + switch (pos_encoding_mode) { + case PosEncodingMode::kNone: + { + for (size_t feat_idx = 0; feat_idx < head_dim; + ++feat_idx) + { + att[kv_idx] += + fi::con::explicit_casting( + q[info.get_q_elem_offset(q_idx, qo_head_idx, + feat_idx)]) * + fi::con::explicit_casting( + k[info.get_kv_elem_offset( + kv_idx, kv_head_idx, feat_idx)]) * + sm_scale; + } + break; + } + default: + { + std::ostringstream err_msg; + err_msg << "Unsupported rotary mode."; + FLASHINFER_ERROR(err_msg.str()); + } + } + max_val = std::max(max_val, 
att[kv_idx]); + } + // exp minus max + float denom = 0; + for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { + att[kv_idx] = std::exp(att[kv_idx] - max_val); + denom += att[kv_idx]; + } + + // divide by denom + for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { + att[kv_idx] /= denom; + } + } + } + }); + return std::move(att); +} +} // namespace + +template +void _TestComputeSFMCorrectness(size_t qo_len, + size_t kv_len, + size_t num_qo_heads, + size_t num_kv_heads, + size_t head_dim, + bool causal, + QKVLayout kv_layout, + PosEncodingMode pos_encoding_mode, + bool use_fp16_qk_reduction, + float rtol = 1e-3, + float atol = 1e-3) +{ + std::vector q(qo_len * num_qo_heads * head_dim); + std::vector k(kv_len * num_kv_heads * head_dim); + std::vector v(kv_len * num_kv_heads * head_dim); + std::vector o(qo_len * num_qo_heads * head_dim); + + utils::generate_data(q); + utils::generate_data(k); + utils::generate_data(v); + utils::generate_data(o); + + DTypeQ *q_d; + FI_GPU_CALL(hipMalloc(&q_d, q.size() * sizeof(DTypeQ))); + FI_GPU_CALL(hipMemcpy(q_d, q.data(), q.size() * sizeof(DTypeQ), + hipMemcpyHostToDevice)); + + DTypeKV *k_d; + FI_GPU_CALL(hipMalloc(&k_d, k.size() * sizeof(DTypeKV))); + FI_GPU_CALL(hipMemcpy(k_d, k.data(), k.size() * sizeof(DTypeKV), + hipMemcpyHostToDevice)); + + DTypeKV *v_d; + FI_GPU_CALL(hipMalloc(&v_d, v.size() * sizeof(DTypeKV))); + FI_GPU_CALL(hipMemcpy(v_d, v.data(), v.size() * sizeof(DTypeKV), + hipMemcpyHostToDevice)); + + DTypeO *o_d; + FI_GPU_CALL(hipMalloc(&o_d, o.size() * sizeof(DTypeO))); + FI_GPU_CALL(hipMemcpy(o_d, o.data(), o.size() * sizeof(DTypeO), + hipMemcpyHostToDevice)); + + DTypeO *tmp_d; + FI_GPU_CALL(hipMalloc(&tmp_d, 16 * 1024 * 1024 * sizeof(DTypeO))); + + hipError_t status = + flashinfer::SinglePrefillWithKVCache( + q_d, k_d, v_d, o_d, tmp_d, + /*lse=*/nullptr, num_qo_heads, num_kv_heads, qo_len, kv_len, + head_dim, causal, kv_layout, pos_encoding_mode, + use_fp16_qk_reduction); + + EXPECT_EQ(status, hipSuccess) + << 
"SinglePrefillWithKVCache kernel launch failed, error message: " + << hipGetErrorString(status); + + std::vector o_h(o.size()); + FI_GPU_CALL(hipMemcpy(o_h.data(), o_d, o_h.size() * sizeof(DTypeO), + hipMemcpyDeviceToHost)); + + // Print the first 10 elements of the output vector for debugging + // std::cout << "Output vector (first 10 elements):"; + // std::cout << "[" << std::endl; + // for (int i = 0; i < 10; ++i) { + // std::cout << fi::con::explicit_casting(o_h[i]) << " "; + // } + // std::cout << "]" << std::endl; + + bool isEmpty = o_h.empty(); + EXPECT_EQ(isEmpty, false) << "Output vector is empty"; + + std::vector o_ref = + test_compute_qk_and_softmax_cpu( + q, k, v, qo_len, kv_len, num_qo_heads, num_kv_heads, head_dim, + causal, kv_layout, pos_encoding_mode); + size_t num_results_error_atol = 0; + bool nan_detected = false; + + for (size_t i = 0; i < o_ref.size(); ++i) { + float o_h_val = fi::con::explicit_casting(o_h[i]); + float o_ref_val = fi::con::explicit_casting(o_ref[i]); + + if (isnan(o_h_val)) { + nan_detected = true; + } + + num_results_error_atol += + (!utils::isclose(o_ref_val, o_h_val, rtol, atol)); + if (!utils::isclose(o_ref_val, o_h_val, rtol, atol)) { + std::cout << "i=" << i << ", o_ref[i]=" << o_ref_val + << ", o_h[i]=" << o_h_val << std::endl; + } + } + + float result_accuracy = + 1. 
- float(num_results_error_atol) / float(o_ref.size()); + std::cout << "num_qo_heads=" << num_qo_heads + << ", num_kv_heads=" << num_kv_heads << ", qo_len=" << qo_len + << ", kv_len=" << kv_len << ", head_dim=" << head_dim + << ", causal=" << causal + << ", kv_layout=" << QKVLayoutToString(kv_layout) + << ", pos_encoding_mode=" + << PosEncodingModeToString(pos_encoding_mode) + << ", result_accuracy=" << result_accuracy << std::endl; + + EXPECT_GT(result_accuracy, 0.90) << "Result correctness test failed."; + EXPECT_FALSE(nan_detected) << "Nan detected in the result."; + + FI_GPU_CALL(hipFree(q_d)); + FI_GPU_CALL(hipFree(k_d)); + FI_GPU_CALL(hipFree(v_d)); + FI_GPU_CALL(hipFree(o_d)); + FI_GPU_CALL(hipFree(tmp_d)); +} + +int main(int argc, char **argv) +{ + + using DTypeIn = __half; + using DTypeO = __half; + bool use_fp16_qk_reduction = false; + size_t qo_len = 399; + size_t kv_len = 533; + size_t num_heads = 1; + size_t head_dim = 64; + bool causal = false; + size_t pos_encoding_mode = 0; + size_t kv_layout = 0; + + _TestComputeSFMCorrectness( + qo_len, kv_len, num_heads, num_heads, head_dim, causal, + QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode), + use_fp16_qk_reduction); +} diff --git a/libflashinfer/utils/flashinfer_prefill_ops.hip.h b/libflashinfer/utils/flashinfer_prefill_ops.hip.h index 38b58d4f80..bbbb2bb98f 100644 --- a/libflashinfer/utils/flashinfer_prefill_ops.hip.h +++ b/libflashinfer/utils/flashinfer_prefill_ops.hip.h @@ -7,11 +7,12 @@ #include "utils_hip.h" -#include "compute_qk_stub.cuh" +// #include "compute_qk_stub.cuh" #include "flashinfer/attention/generic/allocator.h" #include "flashinfer/attention/generic/default_prefill_params.cuh" #include "flashinfer/attention/generic/exception.h" -#include "flashinfer/attention/generic/prefill.cuh" +// #include "flashinfer/attention/generic/prefill.cuh" +#include "flashinfer/attention/generic/prefill_tester.cuh" #include "flashinfer/attention/generic/scheduler.cuh" #include 
"flashinfer/attention/generic/variants.cuh" @@ -165,63 +166,66 @@ hipError_t SinglePrefillWithKVCache( return hipSuccess; } -template -hipError_t -ComputeQKStubCaller(DTypeQ *q, - DTypeKV *k, - DTypeKV *v, - DTypeO *o, - DTypeO *tmp, - float *lse, - float *qk_scores_output, - uint32_t num_qo_heads, - uint32_t num_kv_heads, - uint32_t qo_len, - uint32_t kv_len, - uint32_t head_dim, - bool causal = true, - QKVLayout kv_layout = QKVLayout::kNHD, - PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, - bool use_fp16_qk_reduction = false, - std::optional maybe_sm_scale = std::nullopt, - float rope_scale = 1.f, - float rope_theta = 1e4, - hipStream_t stream = nullptr) -{ - const float sm_scale = - maybe_sm_scale.value_or(1.f / std::sqrt(float(head_dim))); - const MaskMode mask_mode = causal ? MaskMode::kCausal : MaskMode::kNone; - auto [qo_stride_n, qo_stride_h, kv_stride_n, kv_stride_h] = get_qkv_strides( - kv_layout, kv_len, num_qo_heads, num_kv_heads, head_dim); - DISPATCH_use_fp16_qk_reduction( - static_cast(use_fp16_qk_reduction), USE_FP16_QK_REDUCTION, - {DISPATCH_mask_mode( - mask_mode, MASK_MODE, - {DISPATCH_head_dim( - head_dim, HEAD_DIM, - {DISPATCH_pos_encoding_mode( - pos_encoding_mode, POS_ENCODING_MODE, { - using Params = - SinglePrefillParams; - using AttentionVariant = DefaultAttention< - /*use_custom_mask=*/(MASK_MODE == - MaskMode::kCustom), - /*use_sliding_window=*/false, - /*use_logits_soft_cap=*/false, /*use_alibi=*/false>; - Params params(q, k, v, /*custom_mask=*/nullptr, o, lse, - /*alibi_slopes=*/nullptr, num_qo_heads, - num_kv_heads, qo_len, kv_len, qo_stride_n, - qo_stride_h, kv_stride_n, kv_stride_h, - head_dim, - /*window_left=*/-1, - /*logits_soft_cap=*/0.f, sm_scale, - rope_scale, rope_theta); - return ComputeQKStubDispatched< - HEAD_DIM, HEAD_DIM, POS_ENCODING_MODE, - USE_FP16_QK_REDUCTION, MASK_MODE, AttentionVariant, - Params>(params, tmp, qk_scores_output, stream); - })})})}); - return hipSuccess; -} +// template +// hipError_t 
+// ComputeQKStubCaller(DTypeQ *q, +// DTypeKV *k, +// DTypeKV *v, +// DTypeO *o, +// DTypeO *tmp, +// float *lse, +// float *qk_scores_output, +// uint32_t num_qo_heads, +// uint32_t num_kv_heads, +// uint32_t qo_len, +// uint32_t kv_len, +// uint32_t head_dim, +// bool causal = true, +// QKVLayout kv_layout = QKVLayout::kNHD, +// PosEncodingMode pos_encoding_mode = +// PosEncodingMode::kNone, bool use_fp16_qk_reduction = +// false, std::optional maybe_sm_scale = +// std::nullopt, float rope_scale = 1.f, float rope_theta = +// 1e4, hipStream_t stream = nullptr) +// { +// const float sm_scale = +// maybe_sm_scale.value_or(1.f / std::sqrt(float(head_dim))); +// const MaskMode mask_mode = causal ? MaskMode::kCausal : MaskMode::kNone; +// auto [qo_stride_n, qo_stride_h, kv_stride_n, kv_stride_h] = +// get_qkv_strides( +// kv_layout, kv_len, num_qo_heads, num_kv_heads, head_dim); +// DISPATCH_use_fp16_qk_reduction( +// static_cast(use_fp16_qk_reduction), USE_FP16_QK_REDUCTION, +// {DISPATCH_mask_mode( +// mask_mode, MASK_MODE, +// {DISPATCH_head_dim( +// head_dim, HEAD_DIM, +// {DISPATCH_pos_encoding_mode( +// pos_encoding_mode, POS_ENCODING_MODE, { +// using Params = +// SinglePrefillParams; +// using AttentionVariant = DefaultAttention< +// /*use_custom_mask=*/(MASK_MODE == +// MaskMode::kCustom), +// /*use_sliding_window=*/false, +// /*use_logits_soft_cap=*/false, +// /*use_alibi=*/false>; +// Params params(q, k, v, /*custom_mask=*/nullptr, o, +// lse, +// /*alibi_slopes=*/nullptr, num_qo_heads, +// num_kv_heads, qo_len, kv_len, +// qo_stride_n, qo_stride_h, kv_stride_n, +// kv_stride_h, head_dim, +// /*window_left=*/-1, +// /*logits_soft_cap=*/0.f, sm_scale, +// rope_scale, rope_theta); +// return ComputeQKStubDispatched< +// HEAD_DIM, HEAD_DIM, POS_ENCODING_MODE, +// USE_FP16_QK_REDUCTION, MASK_MODE, +// AttentionVariant, Params>(params, tmp, +// qk_scores_output, stream); +// })})})}); +// return hipSuccess; +// } } // namespace flashinfer From 
e3b770c47d924fe98be85c60e01982108aaff059 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Thu, 28 Aug 2025 07:31:52 -0400 Subject: [PATCH 057/109] Debug --- .../flashinfer/attention/generic/prefill.cuh | 55 +++++-------------- .../tests/hip/test_apply_llama_rope.cpp | 2 +- 2 files changed, 16 insertions(+), 41 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index 64f24ab475..a346f431ab 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -163,7 +163,7 @@ struct KernelTraits // in attention score computation and cross-warp synchronization. // CUDA: 4 threads (each thread handles 2 elements from same row group) // CDNA3: 16 threads (each thread handles 1 element from same row group) - static constexpr uint32_t THREADS_PER_MATRIX_ROW_SET = 16; + static constexpr uint32_t THREADS_PER_BMATRIX_ROW_SET = 16; // controls the indexing stride used in logits-related functions // (logits_transform, logits_mask, and LSE writing). 
static constexpr uint32_t LOGITS_INDEX_STRIDE = 4; @@ -193,7 +193,7 @@ struct KernelTraits // Refer: // https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-i8-f8 static constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD = 2; - static constexpr uint32_t THREADS_PER_MATRIX_ROW_SET = 4; + static constexpr uint32_t THREADS_PER_BMATRIX_ROW_SET = 4; static constexpr uint32_t LOGITS_INDEX_STRIDE = 8; #endif static constexpr uint32_t UPCAST_STRIDE_Q = @@ -336,23 +336,6 @@ q_frag_apply_llama_rope(T *x_first_half, const uint32_t qo_packed_offset, const uint_fastdiv group_size) { - -#if Debug - int global_idx = (blockIdx.z * gridDim.y * gridDim.x + - blockIdx.y * gridDim.x + blockIdx.x) * - (blockDim.z * blockDim.y * blockDim.x) + - (threadIdx.z * blockDim.y * blockDim.x + - threadIdx.y * blockDim.x + threadIdx.x); - - if (global_idx == 0) { - printf("=== Q_FRAG_APPLY_LLAMA_ROPE DEBUG ===\n"); - printf("qo_packed_offset=%u, group_size=%u, HALF_ELEMS_PER_THREAD=%u\n", - qo_packed_offset, (uint32_t)group_size, HALF_ELEMS_PER_THREAD); - printf("Input frequencies: %f %f %f %f\n", rope_freq[0], rope_freq[1], - rope_freq[2], rope_freq[3]); - } -#endif - #pragma unroll for (uint32_t reg_id = 0; reg_id < HALF_ELEMS_PER_THREAD; ++reg_id) { float cos, sin, tmp; @@ -369,14 +352,6 @@ q_frag_apply_llama_rope(T *x_first_half, #endif __sincosf(float(position / group_size) * rope_freq[freq_idx], &sin, &cos); -#if Debug - if (global_idx == 0) { - printf("reg_id=%u: freq_idx=%u, position=%u, angle=%f\n", reg_id, - freq_idx, position, - float(position / group_size) * rope_freq[freq_idx]); - } -#endif - tmp = x_first_half[reg_id]; x_first_half[reg_id] = (tmp * cos - (float)x_second_half[reg_id] * sin); x_second_half[reg_id] = @@ -951,7 +926,7 @@ __device__ __forceinline__ void q_smem_inplace_apply_rotary_with_pos( (typename KTraits::DTypeQ *)q_frag_local[1], rope_freq[mma_di], q_packed_idx_base + mma_q * 16 + - lane_idx / KTraits::THREADS_PER_MATRIX_ROW_SET, + 
lane_idx / KTraits::THREADS_PER_BMATRIX_ROW_SET, group_size, q_rope_offset); q_smem->store_fragment(q_smem_offset_r_last_half, q_frag_local[1]); @@ -978,8 +953,8 @@ __device__ __forceinline__ void k_smem_inplace_apply_rotary( using DTypeKV = typename KTraits::DTypeKV; static_assert(sizeof(DTypeKV) == 2); constexpr uint32_t UPCAST_STRIDE_K = KTraits::UPCAST_STRIDE_K; - constexpr uint32_t THREADS_PER_MATRIX_ROW_SET = - KTraits::THREADS_PER_MATRIX_ROW_SET; + constexpr uint32_t THREADS_PER_BMATRIX_ROW_SET = + KTraits::THREADS_PER_BMATRIX_ROW_SET; constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; uint32_t k_frag_local[2][KTraits::INT32_ELEMS_PER_THREAD]; const uint32_t lane_idx = tid.x; @@ -995,7 +970,7 @@ __device__ __forceinline__ void k_smem_inplace_apply_rotary( KTraits::NUM_MMA_KV % 2 == 0, "when NUM_MMA_D_QK == 4, NUM_MMA_KV must be a multiple of 2"); uint32_t kv_idx = kv_idx_base + (warp_idx / 2) * 16 + - lane_idx / THREADS_PER_MATRIX_ROW_SET; + lane_idx / THREADS_PER_BMATRIX_ROW_SET; *k_smem_offset_r = (*k_smem_offset_r ^ (0x2 * (warp_idx % 2))) + (warp_idx / 2) * 16 * UPCAST_STRIDE_K; #pragma unroll @@ -1032,7 +1007,7 @@ __device__ __forceinline__ void k_smem_inplace_apply_rotary( // ... 
uint32_t kv_idx = kv_idx_base + (warp_idx_z * KTraits::NUM_MMA_KV * 16) + - lane_idx / THREADS_PER_MATRIX_ROW_SET; + lane_idx / THREADS_PER_BMATRIX_ROW_SET; *k_smem_offset_r = *k_smem_offset_r ^ (0x2 * warp_idx_x); #pragma unroll for (uint32_t i = 0; i < KTraits::NUM_MMA_KV; ++i) { @@ -1205,7 +1180,7 @@ __device__ __forceinline__ void logits_transform( const dim3 tid = threadIdx, const uint32_t kv_head_idx = blockIdx.z) { - constexpr uint32_t TPR = KTraits::THREADS_PER_MATRIX_ROW_SET; + constexpr uint32_t TPR = KTraits::THREADS_PER_BMATRIX_ROW_SET; constexpr uint32_t NAPTR = KTraits::NUM_ACCUM_ROWS_PER_THREAD; constexpr uint32_t LIS = KTraits::LOGITS_INDEX_STRIDE; @@ -1301,7 +1276,7 @@ logits_mask(const Params ¶ms, constexpr uint32_t NUM_MMA_Q = KTraits::NUM_MMA_Q; constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; constexpr MaskMode MASK_MODE = KTraits::MASK_MODE; - constexpr uint32_t TPR = KTraits::THREADS_PER_MATRIX_ROW_SET; + constexpr uint32_t TPR = KTraits::THREADS_PER_BMATRIX_ROW_SET; constexpr uint32_t NAPTR = KTraits::NUM_ACCUM_ROWS_PER_THREAD; constexpr uint32_t LIS = KTraits::LOGITS_INDEX_STRIDE; @@ -1751,11 +1726,11 @@ __device__ __forceinline__ void threadblock_sync_mdo_states( const uint32_t lane_idx, const dim3 tid = threadIdx) { - constexpr uint32_t TPR = KTraits::THREADS_PER_MATRIX_ROW_SET; + constexpr uint32_t TPR = KTraits::THREADS_PER_BMATRIX_ROW_SET; constexpr uint32_t NARPT = KTraits::NUM_ACCUM_ROWS_PER_THREAD; static_assert(WARP_SIZE % TPR == 0, - "THREADS_PER_MATRIX_ROW_SET must divide WARP_SIZE"); + "THREADS_PER_BMATRIX_ROW_SET must divide WARP_SIZE"); constexpr uint32_t GROUPS_PER_WARP = WARP_SIZE / TPR; const uint32_t lane_group_idx = lane_idx / TPR; @@ -1928,7 +1903,7 @@ __device__ __forceinline__ void write_o_reg_gmem( { using DTypeO = typename KTraits::DTypeO; constexpr uint32_t UPCAST_STRIDE_O = KTraits::UPCAST_STRIDE_O; - constexpr uint32_t TPR = KTraits::THREADS_PER_MATRIX_ROW_SET; + constexpr uint32_t TPR = 
KTraits::THREADS_PER_BMATRIX_ROW_SET; constexpr uint32_t NAPTR = KTraits::NUM_ACCUM_ROWS_PER_THREAD; constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; constexpr uint32_t WARP_THREAD_COLS = KTraits::WARP_THREAD_COLS; @@ -2149,8 +2124,8 @@ SinglePrefillWithKVCacheDevice(const Params params, KTraits::NUM_ACCUM_ROWS_PER_THREAD; [[maybe_unused]] constexpr uint32_t LOGITS_INDEX_STRIDE = KTraits::LOGITS_INDEX_STRIDE; - [[maybe_unused]] constexpr uint32_t THREADS_PER_MATRIX_ROW_SET = - KTraits::THREADS_PER_MATRIX_ROW_SET; + [[maybe_unused]] constexpr uint32_t THREADS_PER_BMATRIX_ROW_SET = + KTraits::THREADS_PER_BMATRIX_ROW_SET; [[maybe_unused]] constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH; @@ -2584,7 +2559,7 @@ SinglePrefillWithKVCacheDevice(const Params params, uint32_t q, r; group_size.divmod( qo_packed_idx_base + - lane_idx / THREADS_PER_MATRIX_ROW_SET + + lane_idx / THREADS_PER_BMATRIX_ROW_SET + j * LOGITS_INDEX_STRIDE + mma_q * 16, q, r); const uint32_t qo_head_idx = diff --git a/libflashinfer/tests/hip/test_apply_llama_rope.cpp b/libflashinfer/tests/hip/test_apply_llama_rope.cpp index f6a800766d..390f4d1648 100644 --- a/libflashinfer/tests/hip/test_apply_llama_rope.cpp +++ b/libflashinfer/tests/hip/test_apply_llama_rope.cpp @@ -356,7 +356,7 @@ TEST_P(LLamaRopeTestWithFP16, VectorSizeIsCorrect) TEST_P(LLamaRopeTestWithFP16, TestQFragApplyRopeComparison) { - constexpr float RELATIVE_EPSILON = 1e-3f; + constexpr float RELATIVE_EPSILON = 1e-2f; auto cpu_result = this->apply_cpu_rope(744); auto gpu_result = this->test_gpu_q_frag_apply_rope(); From ee6797060bc13d498fdc69569068b4b4a813c1ca Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Thu, 28 Aug 2025 08:28:49 -0400 Subject: [PATCH 058/109] Debugging changes --- .../include/flashinfer/attention/generic/prefill.cuh | 3 ++- libflashinfer/tests/hip/test_single_prefill.cpp | 12 +++++++++--- libflashinfer/utils/flashinfer_prefill_ops.hip.h | 4 ++-- libflashinfer/utils/utils_hip.h | 
7 +++++++ 4 files changed, 20 insertions(+), 6 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index a346f431ab..b995245377 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -2470,7 +2470,8 @@ SinglePrefillWithKVCacheDevice(const Params params, #endif #pragma unroll 1 - for (uint32_t iter = 0; iter < num_iterations; ++iter) { + // for (uint32_t iter = 0; iter < num_iterations; ++iter) { + for (uint32_t iter = 0; iter < 1; ++iter) { memory::wait_group<1>(); block.sync(); diff --git a/libflashinfer/tests/hip/test_single_prefill.cpp b/libflashinfer/tests/hip/test_single_prefill.cpp index 23749dad3e..89b06949bc 100644 --- a/libflashinfer/tests/hip/test_single_prefill.cpp +++ b/libflashinfer/tests/hip/test_single_prefill.cpp @@ -17,6 +17,7 @@ using namespace flashinfer; +#if 0 template void _TestComputeQKCorrectness(size_t qo_len, size_t kv_len, @@ -174,6 +175,8 @@ void _TestComputeQKCorrectness(size_t qo_len, FI_GPU_CALL(hipFree(qk_scores_d)); } +#endif + template void _TestSinglePrefillKernelCorrectness(size_t qo_len, size_t kv_len, @@ -192,9 +195,12 @@ void _TestSinglePrefillKernelCorrectness(size_t qo_len, std::vector v(kv_len * num_kv_heads * head_dim); std::vector o(qo_len * num_qo_heads * head_dim); - utils::vec_normal_(q); - utils::vec_normal_(k); - utils::vec_normal_(v); + // utils::vec_normal_(q); + // utils::vec_normal_(k); + // utils::vec_normal_(v); + utils::vec_lexicographic_(q); + utils::vec_fill_(k, __float2half(1.0f)); + utils::vec_fill_(v, __float2half(1.0f)); utils::vec_zero_(o); DTypeQ *q_d; diff --git a/libflashinfer/utils/flashinfer_prefill_ops.hip.h b/libflashinfer/utils/flashinfer_prefill_ops.hip.h index bbbb2bb98f..2866368dab 100644 --- a/libflashinfer/utils/flashinfer_prefill_ops.hip.h +++ b/libflashinfer/utils/flashinfer_prefill_ops.hip.h @@ 
-11,8 +11,8 @@ #include "flashinfer/attention/generic/allocator.h" #include "flashinfer/attention/generic/default_prefill_params.cuh" #include "flashinfer/attention/generic/exception.h" -// #include "flashinfer/attention/generic/prefill.cuh" -#include "flashinfer/attention/generic/prefill_tester.cuh" +#include "flashinfer/attention/generic/prefill.cuh" +// #include "flashinfer/attention/generic/prefill_tester.cuh" #include "flashinfer/attention/generic/scheduler.cuh" #include "flashinfer/attention/generic/variants.cuh" diff --git a/libflashinfer/utils/utils_hip.h b/libflashinfer/utils/utils_hip.h index 6ca5e7b85f..0c1068b8ff 100644 --- a/libflashinfer/utils/utils_hip.h +++ b/libflashinfer/utils/utils_hip.h @@ -85,6 +85,13 @@ template void generate_data(std::vector &vec) } } +template void vec_lexicographic_(std::vector &vec) +{ + for (int i = 0; i < vec.size(); i++) { + vec[i] = fi::con::explicit_casting(static_cast(i)); + } +} + template void vec_normal_(std::vector &vec, float mean = 0.f, float std = 1.f) { From 7137d9ce8e3ecc72506d42bd70202a0efe1ef905 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Thu, 28 Aug 2025 12:18:18 -0400 Subject: [PATCH 059/109] Debugging --- .../flashinfer/attention/generic/prefill.cuh | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index b995245377..620dcee24c 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -25,6 +25,8 @@ #include "pos_enc.cuh" #include "variants.cuh" +#include + namespace flashinfer { @@ -255,6 +257,29 @@ struct KernelTraits namespace { +template +__device__ __forceinline__ void +debug_printer(uint32_t threadid, const char *var_name, T val) +{ + int global_idx = (blockIdx.z * gridDim.y * gridDim.x + + blockIdx.y * gridDim.x + blockIdx.x) * + (blockDim.z * blockDim.y * 
blockDim.x) + + (threadIdx.z * blockDim.y * blockDim.x + + threadIdx.y * blockDim.x + threadIdx.x); + + if (global_idx == threadid) { + if constexpr (std::is_integral_v) { + printf("%s : %d\n", var_name, (int)val); + } + else if constexpr (std::is_floating_point_v) { + printf("%s : %f\n", var_name, (float)val); + } + else { + printf("%s : (unsupported type)\n", var_name); + } + } +} + template __device__ __forceinline__ uint32_t get_warp_idx_q(const uint32_t tid_y = threadIdx.y) @@ -1069,6 +1094,8 @@ __device__ __forceinline__ void compute_qk( *q_smem_offset_r = q_smem->template advance_offset_by_row<16, UPCAST_STRIDE_Q>( *q_smem_offset_r); + printf("---------------------->\n"); + debug_printer(0, "a_frag: ", float(a_frag[mma_q][0])); } *q_smem_offset_r = From a0a57abb122c212393e78194b12b016ffcbccbd6 Mon Sep 17 00:00:00 2001 From: rtmadduri Date: Thu, 28 Aug 2025 20:06:38 +0000 Subject: [PATCH 060/109] verified q, k logic --- .../flashinfer/attention/generic/prefill.cuh | 214 +++++++++++------- .../tests/hip/test_single_prefill.cpp | 26 +-- libflashinfer/utils/cpu_reference_hip.h | 73 ++++-- libflashinfer/utils/utils_hip.h | 6 +- 4 files changed, 202 insertions(+), 117 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index 620dcee24c..f59cb1945d 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -267,7 +267,9 @@ debug_printer(uint32_t threadid, const char *var_name, T val) (threadIdx.z * blockDim.y * blockDim.x + threadIdx.y * blockDim.x + threadIdx.x); - if (global_idx == threadid) { + if (global_idx == 0 || global_idx == 16 || global_idx == 32 || + global_idx == 48) + { if constexpr (std::is_integral_v) { printf("%s : %d\n", var_name, (int)val); } @@ -1094,8 +1096,14 @@ __device__ __forceinline__ void compute_qk( *q_smem_offset_r = q_smem->template 
advance_offset_by_row<16, UPCAST_STRIDE_Q>( *q_smem_offset_r); - printf("---------------------->\n"); - debug_printer(0, "a_frag: ", float(a_frag[mma_q][0])); + + // __half* a_frag_half = reinterpret_cast<__half*>(a_frag[mma_q]); + // debug_printer(0, "a_frag_half_0: ", + // float(a_frag_half[0])); debug_printer(0, "a_frag_half_1: + // ", float(a_frag_half[1])); debug_printer(0, + // "a_frag_half_2: ", float(a_frag_half[2])); + // debug_printer(0, "a_frag_half_3: ", + // float(a_frag_half[3])); } *q_smem_offset_r = @@ -1133,6 +1141,15 @@ __device__ __forceinline__ void compute_qk( k_smem->load_fragment(*k_smem_offset_r, b_frag); #endif } + + // __half* b_frag_half = reinterpret_cast<__half*>(b_frag); + // debug_printer(0, "b_frag_half_0: ", + // float(b_frag_half[0])); debug_printer(0, "b_frag_half_1: + // ", float(b_frag_half[1])); debug_printer(0, + // "b_frag_half_2: ", float(b_frag_half[2])); + // debug_printer(0, "b_frag_half_3: ", + // float(b_frag_half[3])); + *k_smem_offset_r = k_smem->template advance_offset_by_row<16, UPCAST_STRIDE_K>( *k_smem_offset_r); @@ -1152,7 +1169,49 @@ __device__ __forceinline__ void compute_qk( typename KTraits::DTypeQ>(s_frag[mma_q][mma_kv], a_frag[mma_q], b_frag); } + +#if Debug + if (mma_q == 0) { + __half *a_frag_half = + reinterpret_cast<__half *>(a_frag[mma_q]); + debug_printer( + 0, "a_frag_half_0: ", float(a_frag_half[0])); + debug_printer( + 0, "a_frag_half_1: ", float(a_frag_half[1])); + debug_printer( + 0, "a_frag_half_2: ", float(a_frag_half[2])); + debug_printer( + 0, "a_frag_half_3: ", float(a_frag_half[3])); + + __syncthreads(); + + // __half* b_frag_half = + // reinterpret_cast<__half*>(b_frag); + // debug_printer(0, "b_frag_half_0: ", + // float(b_frag_half[0])); debug_printer(0, + // "b_frag_half_1: ", float(b_frag_half[1])); + // debug_printer(0, "b_frag_half_2: ", + // float(b_frag_half[2])); debug_printer(0, + // "b_frag_half_3: ", float(b_frag_half[3])); + + // __syncthreads(); + + __half *s_frag_half = + 
reinterpret_cast<__half *>(s_frag[mma_q][mma_kv]); + debug_printer( + 0, "s_frag_half: ", float(s_frag_half[0])); + debug_printer( + 0, "s_frag_half: ", float(s_frag_half[1])); + debug_printer( + 0, "s_frag_half: ", float(s_frag_half[2])); + debug_printer( + 0, "s_frag_half: ", float(s_frag_half[3])); + + __syncthreads(); + } +#endif } + else if (std::is_same_v) { #if defined(PLATFORM_HIP_DEVICE) static_assert( @@ -2326,36 +2385,35 @@ SinglePrefillWithKVCacheDevice(const Params params, (threadIdx.z * blockDim.y * blockDim.x + threadIdx.y * blockDim.x + threadIdx.x); + // if (global_idx == 0) { + // printf("partition_kv : %d\n", partition_kv); + // printf("kv_len : %d\n", kv_len); + // printf("max_chunk_size : %d\n", max_chunk_size); + // printf("chunk_end : %d\n", chunk_end); + // printf("chunk_start : %d\n", chunk_start); + // } + // Test Q + if (global_idx == 0) { - printf("partition_kv : %d\n", partition_kv); - printf("kv_len : %d\n", kv_len); - printf("max_chunk_size : %d\n", max_chunk_size); - printf("chunk_end : %d\n", chunk_end); - printf("chunk_start : %d\n", chunk_start); + printf("\n DEBUG Q ORIGINAL (HIP):\n"); + uint32_t q_smem_offset_r_debug; + for (auto i = 0; i < 16; ++i) { + for (auto j = 0; j < 4; ++j) { + uint32_t q_smem_offset_r_debug = + qo_smem.template get_permuted_offset( + i, j); + uint32_t a_frag[KTraits::INT32_ELEMS_PER_THREAD]; + qo_smem.load_fragment(q_smem_offset_r_debug, a_frag); + auto frag_T = reinterpret_cast<__half *>(a_frag); + for (auto i = 0ul; i < 4; ++i) { + printf("%f ", (float)(*(frag_T + i))); + } + } + printf("\n"); + qo_smem.template advance_offset_by_row< + 16, KTraits::UPCAST_STRIDE_Q>(q_smem_offset_r); + } } - // Test Q - // if (global_idx == 0) { - // uint32_t q_smem_offset_r_debug; - // //for (auto i = 0; i < 4; ++i) { - // for (auto j = 0; j < 16; ++j) { - // uint32_t q_smem_offset_r_debug = - // qo_smem.template - // get_permuted_offset( - // get_warp_idx_q(tid.y) * NUM_MMA_Q * 16 - // + (j) % 16, (j) / 16); - // 
uint32_t a_frag[KTraits::INT32_ELEMS_PER_THREAD]; - // k_smem.load_fragment(q_smem_offset_r_debug, a_frag); - // auto frag_T = reinterpret_cast<__half *>(a_frag); - // for (auto i = 0ul; i < 4; ++i) { - // printf("%f ", (float)(*(frag_T + i))); - // } - // printf("\n"); - // } - // // q_smem_offset_r_debug = qo_smem.template - // advance_offset_by_column<4>( - // // q_smem_offset_r_debug, 0); - // // } - // } // for (auto mma_q = 0ul; mma_q < 4; ++mma_q) { // uint32_t a_frag[KTraits::INT32_ELEMS_PER_THREAD]; @@ -2373,55 +2431,59 @@ SinglePrefillWithKVCacheDevice(const Params params, // q_smem_offset_r, 0); // } - uint32_t a_frag[KTraits::INT32_ELEMS_PER_THREAD]; - qo_smem.load_fragment(q_smem_offset_r, a_frag); - if (global_idx == 0) { - auto frag_T = reinterpret_cast<__half *>(a_frag); - printf("DEBUG: Q Frag \n"); - for (auto i = 0ul; i < 4; ++i) { - printf("%f ", (float)(*(frag_T + i))); - } - printf("\n"); - } - - memory::wait_group<0>(); - block.sync(); - q_smem_inplace_apply_rotary(qo_packed_idx_base, qo_len, kv_len, - group_size, &qo_smem, - &q_smem_offset_r, rope_freq, tid); - block.sync(); + // uint32_t a_frag[KTraits::INT32_ELEMS_PER_THREAD]; + // qo_smem.load_fragment(q_smem_offset_r, a_frag); + // if (global_idx == 0) { + // auto frag_T = reinterpret_cast<__half *>(a_frag); + // printf("DEBUG: Q Frag \n"); + // for (auto i = 0ul; i < 4; ++i) { + // printf("%f ", (float)(*(frag_T + i))); + // } + // printf("\n"); + // } - qo_smem.load_fragment(q_smem_offset_r, a_frag); - if (global_idx == 0) { - auto frag_T = reinterpret_cast<__half *>(a_frag); - printf("DEBUG: LLAMA Rope transformed Q Frag \n"); - for (auto i = 0ul; i < 4; ++i) { - printf("%f ", (float)(*(frag_T + i))); - } - printf("\n"); - } + // memory::wait_group<0>(); + // block.sync(); + // q_smem_inplace_apply_rotary(qo_packed_idx_base, qo_len, + // kv_len, + // group_size, &qo_smem, + // &q_smem_offset_r, rope_freq, + // tid); + // block.sync(); - // // Test K loads + // 
qo_smem.load_fragment(q_smem_offset_r, a_frag); // if (global_idx == 0) { - - // for (auto j = 0; j < 64; ++j) { - // uint32_t k_smem_offset_r_test = - // k_smem.template get_permuted_offset( - // get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + - // j % 16, - // (j / 16)); - // uint32_t b_frag[KTraits::INT32_ELEMS_PER_THREAD]; - // k_smem.load_fragment(k_smem_offset_r_test, b_frag); - // auto frag_T = reinterpret_cast<__half *>(b_frag); - // // printf("DEBUG: K Frag in permuted_smem for mma_kv %lu \n", - // // mma_kv); - // for (auto i = 0ul; i < 4; ++i) { - // printf("%f ", (float)(*(frag_T + i))); - // } - // printf("\n"); + // auto frag_T = reinterpret_cast<__half *>(a_frag); + // printf("DEBUG: LLAMA Rope transformed Q Frag \n"); + // for (auto i = 0ul; i < 4; ++i) { + // printf("%f ", (float)(*(frag_T + i))); // } + // printf("\n"); // } + // Test K loads + if (global_idx == 0) { + printf("\n DEBUG K ORIGINAL (HIP):\n"); + uint32_t k_smem_offset_r_debug; + for (auto i = 0; i < 16; ++i) { + for (auto j = 0; j < 4; ++j) { + uint32_t k_smem_offset_r_debug = + k_smem.template get_permuted_offset(i, + j); + uint32_t a_frag[KTraits::INT32_ELEMS_PER_THREAD]; + k_smem.load_fragment(k_smem_offset_r_debug, a_frag); + auto frag_T = reinterpret_cast<__half *>(a_frag); + for (auto i = 0ul; i < 4; ++i) { + printf("%f ", (float)(*(frag_T + i))); + } + } + printf("\n"); + k_smem.template advance_offset_by_row<16, + KTraits::UPCAST_STRIDE_K>( + k_smem_offset_r); + } + } + // if (global_idx == 0) { // printf("DEBUG Q ORIGINAL (HIP):\n"); @@ -2512,8 +2574,8 @@ SinglePrefillWithKVCacheDevice(const Params params, } // compute attention score - compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, - &k_smem_offset_r, s_frag); + // compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, + // &k_smem_offset_r, s_frag); logits_transform( params, variant, /*batch_idx=*/0, qo_packed_idx_base, diff --git a/libflashinfer/tests/hip/test_single_prefill.cpp b/libflashinfer/tests/hip/test_single_prefill.cpp 
index 89b06949bc..aa171dc21c 100644 --- a/libflashinfer/tests/hip/test_single_prefill.cpp +++ b/libflashinfer/tests/hip/test_single_prefill.cpp @@ -195,12 +195,12 @@ void _TestSinglePrefillKernelCorrectness(size_t qo_len, std::vector v(kv_len * num_kv_heads * head_dim); std::vector o(qo_len * num_qo_heads * head_dim); - // utils::vec_normal_(q); - // utils::vec_normal_(k); - // utils::vec_normal_(v); - utils::vec_lexicographic_(q); - utils::vec_fill_(k, __float2half(1.0f)); - utils::vec_fill_(v, __float2half(1.0f)); + utils::vec_normal_(q); + utils::vec_normal_(k); + utils::vec_normal_(v); + // utils::vec_lexicographic_(q); + // utils::vec_lexicographic_(k); + // utils::vec_fill_(v, __float2half(1.0f)); utils::vec_zero_(o); DTypeQ *q_d; @@ -270,10 +270,10 @@ void _TestSinglePrefillKernelCorrectness(size_t qo_len, num_results_error_atol += (!utils::isclose(o_ref_val, o_h_val, rtol, atol)); - if (!utils::isclose(o_ref_val, o_h_val, rtol, atol)) { - std::cout << "i=" << i << ", o_ref[i]=" << o_ref_val - << ", o_h[i]=" << o_h_val << std::endl; - } + // if (!utils::isclose(o_ref_val, o_h_val, rtol, atol)) { + // std::cout << "i=" << i << ", o_ref[i]=" << o_ref_val + // << ", o_h[i]=" << o_h_val << std::endl; + // } } // std::cout<<"Printing att_out vector:\n"; // for(auto i: att_out) { @@ -557,12 +557,12 @@ int main(int argc, char **argv) using DTypeIn = __half; using DTypeO = __half; bool use_fp16_qk_reduction = false; - size_t qo_len = 399; - size_t kv_len = 533; + size_t qo_len = 64; + size_t kv_len = 64; size_t num_heads = 1; size_t head_dim = 64; bool causal = false; - size_t pos_encoding_mode = 1; // 1 == kRopeLLama + size_t pos_encoding_mode = 0; // 1 == kRopeLLama size_t kv_layout = 0; _TestSinglePrefillKernelCorrectness( diff --git a/libflashinfer/utils/cpu_reference_hip.h b/libflashinfer/utils/cpu_reference_hip.h index bd32a04a77..448efb9927 100644 --- a/libflashinfer/utils/cpu_reference_hip.h +++ b/libflashinfer/utils/cpu_reference_hip.h @@ -216,9 +216,23 @@ 
single_mha(const std::vector &q, kv_layout, HEAD_DIM); #if Debug std::cout << "DEBUG: Original Q (CPU): " << '\n'; - for (auto i = 0ul; i < 4; ++i) { + for (auto i = 0ul; i < 16; ++i) { + for (int j = 0; j < 16; ++j) { + std::cout << (float)q[info.get_q_elem_offset(i, 0, j)] << " "; + } + std::cout << std::endl; + // q[info.get_q_elem_offset(q_idx, qo_head_idx, feat_idx) + // std::cout << (float)q[info.get_q_elem_offset(0, 0, i)] << " "; + } + std::cout << std::endl; + + std::cout << "DEBUG: Original K (CPU): " << '\n'; + for (auto i = 0ul; i < 16; ++i) { + for (int j = 0ul; j < 16; ++j) { + std::cout << (float)k[info.get_kv_elem_offset(i, 0, j)] << " "; + } + std::cout << std::endl; // q[info.get_q_elem_offset(q_idx, qo_head_idx, feat_idx) - std::cout << (float)q[info.get_q_elem_offset(0, 0, i)] << " "; } std::cout << std::endl; @@ -227,35 +241,38 @@ single_mha(const std::vector &q, // for (auto i = 0ul; i < 64; ++i) { // // k[info.get_kv_elem_offset(kv_idx, kv_head_idx, feat_idx) // // std::cout << (float)k[info.get_kv_elem_offset(15, 0, j * 4 - // + + // // + // // i)] - // std::cout << (float)k[info.get_kv_elem_offset(j, 0, i)] << " + // std::cout << (float)k[info.get_kv_elem_offset(j, 0, i)] <<" // "; // } // std::cout << '\n'; // } + // std::cout << std::endl; - std::cout << "num_qo_heads " << num_qo_heads << '\n'; - std::cout << "qo_len " << qo_len << '\n'; - for (size_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) - { - for (size_t q_idx = 0; q_idx < qo_len; ++q_idx) { - q_rotary_local = - std::move(cpu_reference::apply_llama_rope_debug( - q.data() + - info.get_q_elem_offset(q_idx, qo_head_idx, 0), - head_dim, q_idx + kv_len - qo_len, rope_scale, - rope_theta)); - } - } + // std::cout << "num_qo_heads " << num_qo_heads << '\n'; + // std::cout << "qo_len " << qo_len << '\n'; + // for (size_t qo_head_idx = 0; qo_head_idx < num_qo_heads; + // ++qo_head_idx) + // { + // for (size_t q_idx = 0; q_idx < qo_len; ++q_idx) { + // q_rotary_local = + // 
std::move(cpu_reference::apply_llama_rope_debug( + // q.data() + + // info.get_q_elem_offset(q_idx, qo_head_idx, 0), + // head_dim, q_idx + kv_len - qo_len, rope_scale, + // rope_theta)); + // } + // } - std::cout << "DEBUG: LLAMA Rope Transformed Q (CPU): " << '\n'; - for (auto i = 0ul; i < 4; ++i) { - // q[info.get_q_elem_offset(q_idx, qo_head_idx, feat_idx) - std::cout << (float)q_rotary_local[info.get_q_elem_offset(0, 0, i)] - << " "; - } - std::cout << std::endl; + // std::cout << "DEBUG: LLAMA Rope Transformed Q (CPU): " << '\n'; + // for (auto i = 0ul; i < 4; ++i) { + // // q[info.get_q_elem_offset(q_idx, qo_head_idx, feat_idx) + // std::cout << (float)q_rotary_local[info.get_q_elem_offset(0, 0, + // i)] + // << " "; + // } + // std::cout << std::endl; #endif for (size_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) { @@ -274,6 +291,9 @@ single_mha(const std::vector &q, switch (pos_encoding_mode) { case PosEncodingMode::kNone: { +#if Debug + sm_scale = 1.0f; +#endif for (size_t feat_idx = 0; feat_idx < head_dim; ++feat_idx) { @@ -310,12 +330,15 @@ single_mha(const std::vector &q, FLASHINFER_ERROR(err_msg.str()); } } - // apply mask +// apply mask +#if 0 if (causal && kv_idx > kv_len + q_idx - qo_len) { att[kv_idx] = -5e4; } +#endif max_val = std::max(max_val, att[kv_idx]); } + // exp minus max float denom = 0; for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { diff --git a/libflashinfer/utils/utils_hip.h b/libflashinfer/utils/utils_hip.h index 0c1068b8ff..8184bb59b4 100644 --- a/libflashinfer/utils/utils_hip.h +++ b/libflashinfer/utils/utils_hip.h @@ -96,7 +96,7 @@ template void vec_normal_(std::vector &vec, float mean = 0.f, float std = 1.f) { std::random_device rd{}; - std::mt19937 gen{rd()}; + std::mt19937 gen{1234}; std::normal_distribution d{mean, std}; for (size_t i = 0; i < vec.size(); ++i) { float value = static_cast(d(gen)); @@ -108,7 +108,7 @@ template void vec_uniform_(std::vector &vec, float a = 0.f, float b = 1.f) { 
std::random_device rd{}; - std::mt19937 gen{rd()}; + std::mt19937 gen{1234}; std::uniform_real_distribution d{a, b}; for (size_t i = 0; i < vec.size(); ++i) { float value = static_cast(d(gen)); @@ -130,7 +130,7 @@ template void vec_fill_(std::vector &vec, T val) template void vec_randint_(std::vector &vec, int low, int high) { std::random_device rd{}; - std::mt19937 gen{rd()}; + std::mt19937 gen{1234}; std::uniform_int_distribution d{low, high}; for (size_t i = 0; i < vec.size(); ++i) { float value = static_cast(d(gen)); From ca902d3bc79d96a32fe270c7c53b6d5d7847221b Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Mon, 1 Sep 2025 06:00:53 -0400 Subject: [PATCH 061/109] Debugging produce_kv --- .../flashinfer/attention/generic/prefill.cuh | 17 ++++++++++------- libflashinfer/utils/cpu_reference_hip.h | 6 +++--- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index f59cb1945d..5795965491 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -83,6 +83,9 @@ struct SharedStorageQKVO { alignas(16) DTypeQ q_smem[CTA_TILE_Q * HEAD_DIM_QK]; alignas(16) DTypeKV k_smem[CTA_TILE_KV * HEAD_DIM_QK]; +#if Debug + alignas(16) DTypeKV qk_scratch[CTA_TILE_KV * HEAD_DIM_QK]; +#endif alignas(16) DTypeKV v_smem[CTA_TILE_KV * HEAD_DIM_VO]; }; struct @@ -2398,8 +2401,8 @@ SinglePrefillWithKVCacheDevice(const Params params, printf("\n DEBUG Q ORIGINAL (HIP):\n"); uint32_t q_smem_offset_r_debug; for (auto i = 0; i < 16; ++i) { - for (auto j = 0; j < 4; ++j) { - uint32_t q_smem_offset_r_debug = + for (auto j = 0; j < 16; ++j) { + q_smem_offset_r_debug = qo_smem.template get_permuted_offset( i, j); uint32_t a_frag[KTraits::INT32_ELEMS_PER_THREAD]; @@ -2411,7 +2414,7 @@ SinglePrefillWithKVCacheDevice(const Params params, } printf("\n"); qo_smem.template 
advance_offset_by_row< - 16, KTraits::UPCAST_STRIDE_Q>(q_smem_offset_r); + 16, KTraits::UPCAST_STRIDE_Q>(q_smem_offset_r_debug); } } @@ -2465,9 +2468,9 @@ SinglePrefillWithKVCacheDevice(const Params params, if (global_idx == 0) { printf("\n DEBUG K ORIGINAL (HIP):\n"); uint32_t k_smem_offset_r_debug; - for (auto i = 0; i < 16; ++i) { - for (auto j = 0; j < 4; ++j) { - uint32_t k_smem_offset_r_debug = + for (auto i = 0; i < 128; ++i) { + for (auto j = 0; j < 16; ++j) { + k_smem_offset_r_debug = k_smem.template get_permuted_offset(i, j); uint32_t a_frag[KTraits::INT32_ELEMS_PER_THREAD]; @@ -2480,7 +2483,7 @@ SinglePrefillWithKVCacheDevice(const Params params, printf("\n"); k_smem.template advance_offset_by_row<16, KTraits::UPCAST_STRIDE_K>( - k_smem_offset_r); + k_smem_offset_r_debug); } } diff --git a/libflashinfer/utils/cpu_reference_hip.h b/libflashinfer/utils/cpu_reference_hip.h index 448efb9927..de8da43977 100644 --- a/libflashinfer/utils/cpu_reference_hip.h +++ b/libflashinfer/utils/cpu_reference_hip.h @@ -217,7 +217,7 @@ single_mha(const std::vector &q, #if Debug std::cout << "DEBUG: Original Q (CPU): " << '\n'; for (auto i = 0ul; i < 16; ++i) { - for (int j = 0; j < 16; ++j) { + for (int j = 0; j < 64; ++j) { std::cout << (float)q[info.get_q_elem_offset(i, 0, j)] << " "; } std::cout << std::endl; @@ -227,8 +227,8 @@ single_mha(const std::vector &q, std::cout << std::endl; std::cout << "DEBUG: Original K (CPU): " << '\n'; - for (auto i = 0ul; i < 16; ++i) { - for (int j = 0ul; j < 16; ++j) { + for (auto i = 0ul; i < 128; ++i) { + for (int j = 0ul; j < 64; ++j) { std::cout << (float)k[info.get_kv_elem_offset(i, 0, j)] << " "; } std::cout << std::endl; From ea744457f0151f80f45b989d5fa90315b4210e44 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Tue, 2 Sep 2025 11:17:58 -0400 Subject: [PATCH 062/109] Debug.... 
--- .../flashinfer/attention/generic/prefill.cuh | 163 +++++++++++------- 1 file changed, 96 insertions(+), 67 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index 5795965491..701cc34705 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -536,35 +536,34 @@ __device__ __forceinline__ void produce_kv_impl_cdna3_( // CDNA3: Load complete 16×HEAD_DIM tile per i iteration #pragma unroll - for (uint32_t k = 0; k < THREAD_GROUPS_PER_MMA_TILE; ++k) - { // 4 sequence groups -#pragma unroll - for (uint32_t j = 0; j < FEATURE_CHUNKS_PER_THREAD_GROUP; ++j) - { // Feature chunks - smem.template load_vector_async(*smem_offset, *gptr, - kv_idx < kv_len); + for (uint32_t k = 0; k < 4; ++k) { // 4 sequence groups + // #pragma unroll + // for (uint32_t j = 0; j < 4; ++j) + // { // Feature chunks + smem.template load_vector_async(*smem_offset, *gptr, + kv_idx < kv_len); - // Advance to next feature chunk (same sequence group) - *smem_offset = - smem.template advance_offset_by_column( - *smem_offset, j); - *gptr += KV_THR_LAYOUT_COL * - upcast_size(); - } + // Advance to next feature chunk (same sequence group) + *smem_offset = + smem.template advance_offset_by_row<4, UPCAST_STRIDE>( + *smem_offset); + *gptr += 4 * 16 * upcast_size(); + //} // Advance to next sequence group within same MMA tile - if (k < THREAD_GROUPS_PER_MMA_TILE - 1) - { // Don't advance after last group - kv_idx += NUM_WARPS * KV_THR_LAYOUT_ROW; - *smem_offset = - smem.template advance_offset_by_row< - NUM_WARPS * KV_THR_LAYOUT_ROW, UPCAST_STRIDE>( - *smem_offset) - - COLUMN_RESET_OFFSET; - *gptr += NUM_WARPS * KV_THR_LAYOUT_ROW * stride_n - - FEATURE_CHUNKS_PER_THREAD_GROUP * KV_THR_LAYOUT_COL * - upcast_size(); - } + // if (k < THREAD_GROUPS_PER_MMA_TILE - 1) + // { // Don't advance after last group + // kv_idx 
+= NUM_WARPS * KV_THR_LAYOUT_ROW; + // *smem_offset = + // smem.template advance_offset_by_row< + // NUM_WARPS * KV_THR_LAYOUT_ROW, UPCAST_STRIDE>( + // *smem_offset) - + // COLUMN_RESET_OFFSET; + // *gptr += NUM_WARPS * KV_THR_LAYOUT_ROW * stride_n - + // FEATURE_CHUNKS_PER_THREAD_GROUP * KV_THR_LAYOUT_COL + // * + // upcast_size(); + // } } // Final advance to next MMA tile @@ -576,6 +575,11 @@ __device__ __forceinline__ void produce_kv_impl_cdna3_( *gptr += NUM_WARPS * KV_THR_LAYOUT_ROW * stride_n - FEATURE_CHUNKS_PER_THREAD_GROUP * KV_THR_LAYOUT_COL * upcast_size(); + + // *smem_offset = + // smem.template advance_offset_by_row(*smem_offset); + // *gptr += 4*16 * upcast_size(); } *smem_offset -= KTraits::CTA_TILE_KV * UPCAST_STRIDE; } @@ -2282,15 +2286,15 @@ SinglePrefillWithKVCacheDevice(const Params params, (kv_head_idx * group_size) * o_stride_h : o + (kv_head_idx * group_size) * o_stride_h; + load_q_global_smem(qo_packed_idx_base, qo_len, q_ptr_base, + q_stride_n, q_stride_h, group_size, + &qo_smem, tid); + uint32_t q_smem_offset_r = qo_smem.template get_permuted_offset( get_warp_idx_q(tid.y) * NUM_MMA_Q * 16 + lane_idx % 16, lane_idx / 16); - load_q_global_smem(qo_packed_idx_base, qo_len, q_ptr_base, - q_stride_n, q_stride_h, group_size, - &qo_smem, tid); - memory::commit_group(); if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { @@ -2395,28 +2399,29 @@ SinglePrefillWithKVCacheDevice(const Params params, // printf("chunk_end : %d\n", chunk_end); // printf("chunk_start : %d\n", chunk_start); // } - // Test Q - if (global_idx == 0) { - printf("\n DEBUG Q ORIGINAL (HIP):\n"); - uint32_t q_smem_offset_r_debug; - for (auto i = 0; i < 16; ++i) { - for (auto j = 0; j < 16; ++j) { - q_smem_offset_r_debug = - qo_smem.template get_permuted_offset( - i, j); - uint32_t a_frag[KTraits::INT32_ELEMS_PER_THREAD]; - qo_smem.load_fragment(q_smem_offset_r_debug, a_frag); - auto frag_T = reinterpret_cast<__half *>(a_frag); - for (auto i = 0ul; i < 4; 
++i) { - printf("%f ", (float)(*(frag_T + i))); - } - } - printf("\n"); - qo_smem.template advance_offset_by_row< - 16, KTraits::UPCAST_STRIDE_Q>(q_smem_offset_r_debug); - } - } + // // Test Q + // if (global_idx == 0) { + // printf("\n DEBUG Q ORIGINAL (HIP):\n"); + // uint32_t q_smem_offset_r_debug; + // for (auto i = 0; i < 16; ++i) { + // for (auto j = 0; j < 16; ++j) { + // q_smem_offset_r_debug = + // qo_smem.template + // get_permuted_offset( + // i, j); + // uint32_t a_frag[KTraits::INT32_ELEMS_PER_THREAD]; + // qo_smem.load_fragment(q_smem_offset_r_debug, a_frag); + // auto frag_T = reinterpret_cast<__half *>(a_frag); + // for (auto i = 0ul; i < 4; ++i) { + // printf("%f ", (float)(*(frag_T + i))); + // } + // } + // printf("\n"); + // qo_smem.template advance_offset_by_row< + // 16, KTraits::UPCAST_STRIDE_Q>(q_smem_offset_r_debug); + // } + // } // for (auto mma_q = 0ul; mma_q < 4; ++mma_q) { // uint32_t a_frag[KTraits::INT32_ELEMS_PER_THREAD]; @@ -2464,29 +2469,53 @@ SinglePrefillWithKVCacheDevice(const Params params, // printf("\n"); // } - // Test K loads + // Test K Global values if (global_idx == 0) { - printf("\n DEBUG K ORIGINAL (HIP):\n"); - uint32_t k_smem_offset_r_debug; + printf("\n DEBUG K Global (HIP):\n"); + printf("k_stride_n : %d\n", k_stride_n); + printf("k_stride_h : %d\n", k_stride_h); + printf("kv_head_idx : %d\n", kv_head_idx); + DTypeKV *k_ptr_tmp = k + + (chunk_start + warp_idx * KV_THR_LAYOUT_ROW + + lane_idx / KV_THR_LAYOUT_COL) * + k_stride_n + + kv_head_idx * k_stride_h + + (lane_idx % KV_THR_LAYOUT_COL) * + upcast_size(); for (auto i = 0; i < 128; ++i) { - for (auto j = 0; j < 16; ++j) { - k_smem_offset_r_debug = - k_smem.template get_permuted_offset(i, - j); - uint32_t a_frag[KTraits::INT32_ELEMS_PER_THREAD]; - k_smem.load_fragment(k_smem_offset_r_debug, a_frag); - auto frag_T = reinterpret_cast<__half *>(a_frag); - for (auto i = 0ul; i < 4; ++i) { - printf("%f ", (float)(*(frag_T + i))); - } + for (auto j = 0; j < 64; ++j) { 
+ auto fKval = (float)*(k_ptr_tmp); + k_ptr_tmp += 1; + printf("%f ", fKval); } printf("\n"); - k_smem.template advance_offset_by_row<16, - KTraits::UPCAST_STRIDE_K>( - k_smem_offset_r_debug); } } + // Test K loads + // if (global_idx == 0) { + // printf("\n DEBUG K ORIGINAL (HIP):\n"); + // uint32_t k_smem_offset_r_debug; + // for (auto i = 0; i < 128; ++i) { + // for (auto j = 0; j < 16; ++j) { + // k_smem_offset_r_debug = + // k_smem.template + // get_permuted_offset(i, + // j); + // uint32_t a_frag[KTraits::INT32_ELEMS_PER_THREAD]; + // k_smem.load_fragment(k_smem_offset_r_debug, a_frag); + // auto frag_T = reinterpret_cast<__half *>(a_frag); + // for (auto i = 0ul; i < 4; ++i) { + // printf("%f ", (float)(*(frag_T + i))); + // } + // } + // printf("\n"); + // k_smem.template advance_offset_by_row<16, + // KTraits::UPCAST_STRIDE_K>( + // k_smem_offset_r_debug); + // } + // } + // if (global_idx == 0) { // printf("DEBUG Q ORIGINAL (HIP):\n"); From e67ea64039b39f010f049d45f7f19a26b3b61393 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Wed, 3 Sep 2025 04:28:23 -0400 Subject: [PATCH 063/109] Debug produce_kv_impl_cdna3_ --- .../flashinfer/attention/generic/prefill.cuh | 127 +++++++----------- .../tests/hip/test_single_prefill.cpp | 4 +- 2 files changed, 53 insertions(+), 78 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index 701cc34705..a332d9dc76 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -511,18 +511,9 @@ __device__ __forceinline__ void produce_kv_impl_cdna3_( constexpr uint32_t UPCAST_STRIDE = produce_v ? 
KTraits::UPCAST_STRIDE_V : KTraits::UPCAST_STRIDE_K; constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH; - constexpr uint32_t HALF_ELEMS_PER_THREAD = - KTraits::HALF_ELEMS_PER_THREAD; // 4 - - // CDNA3-specific constants - constexpr uint32_t SEQUENCES_PER_MMA_TILE = 16; - constexpr uint32_t SEQUENCES_PER_THREAD_GROUP = KV_THR_LAYOUT_ROW; // 4 - constexpr uint32_t THREAD_GROUPS_PER_MMA_TILE = - SEQUENCES_PER_MMA_TILE / SEQUENCES_PER_THREAD_GROUP; // 4 - constexpr uint32_t FEATURE_CHUNKS_PER_THREAD_GROUP = - NUM_MMA_D / HALF_ELEMS_PER_THREAD; // NUM_MMA_D/4 - constexpr uint32_t COLUMN_RESET_OFFSET = - FEATURE_CHUNKS_PER_THREAD_GROUP * KV_THR_LAYOUT_COL; + constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; + constexpr uint32_t FEATURES_PER_THREAD_ROW = + HALF_ELEMS_PER_THREAD * KV_THR_LAYOUT_COL; uint32_t row = lane_idx / KV_THR_LAYOUT_COL; uint32_t kv_idx = kv_idx_base + warp_idx * KV_THR_LAYOUT_ROW + row; @@ -531,55 +522,36 @@ __device__ __forceinline__ void produce_kv_impl_cdna3_( static_assert(NUM_MMA_KV * 4 % NUM_WARPS_Q == 0); #pragma unroll - for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) - { // MMA tile iterations - - // CDNA3: Load complete 16×HEAD_DIM tile per i iteration + for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) { + for (uint32_t j = 0; j < KTraits::HEAD_DIM_QK / FEATURES_PER_THREAD_ROW; + ++j) + { #pragma unroll - for (uint32_t k = 0; k < 4; ++k) { // 4 sequence groups - // #pragma unroll - // for (uint32_t j = 0; j < 4; ++j) - // { // Feature chunks - smem.template load_vector_async(*smem_offset, *gptr, - kv_idx < kv_len); - - // Advance to next feature chunk (same sequence group) + for (uint32_t k = 0; k < FEATURES_PER_THREAD_ROW; ++k) { + smem.template load_vector_async(*smem_offset, *gptr, + kv_idx < kv_len); + *smem_offset = + smem.template advance_offset_by_row<4, UPCAST_STRIDE>( + *smem_offset); + // FIXME: The below logic will not handle cases where HEAD_DIMS + // > 64 + *gptr += 
64 * upcast_size(); + } *smem_offset = - smem.template advance_offset_by_row<4, UPCAST_STRIDE>( + smem.template advance_offset_by_row<64, UPCAST_STRIDE>( *smem_offset); - *gptr += 4 * 16 * upcast_size(); - //} - - // Advance to next sequence group within same MMA tile - // if (k < THREAD_GROUPS_PER_MMA_TILE - 1) - // { // Don't advance after last group - // kv_idx += NUM_WARPS * KV_THR_LAYOUT_ROW; - // *smem_offset = - // smem.template advance_offset_by_row< - // NUM_WARPS * KV_THR_LAYOUT_ROW, UPCAST_STRIDE>( - // *smem_offset) - - // COLUMN_RESET_OFFSET; - // *gptr += NUM_WARPS * KV_THR_LAYOUT_ROW * stride_n - - // FEATURE_CHUNKS_PER_THREAD_GROUP * KV_THR_LAYOUT_COL - // * - // upcast_size(); - // } } // Final advance to next MMA tile kv_idx += NUM_WARPS * KV_THR_LAYOUT_ROW; - *smem_offset = - smem.template advance_offset_by_row(*smem_offset) - - COLUMN_RESET_OFFSET; - *gptr += NUM_WARPS * KV_THR_LAYOUT_ROW * stride_n - - FEATURE_CHUNKS_PER_THREAD_GROUP * KV_THR_LAYOUT_COL * - upcast_size(); - - // *smem_offset = - // smem.template advance_offset_by_row(*smem_offset); - // *gptr += 4*16 * upcast_size(); + *smem_offset = smem.template advance_offset_by_row<4, UPCAST_STRIDE>( + *smem_offset) - + 64; + // *gptr += NUM_WARPS * KV_THR_LAYOUT_ROW * stride_n - + // FEATURE_CHUNKS_PER_THREAD_GROUP * KV_THR_LAYOUT_COL * + // upcast_size(); + // FIXME: The below logic will not handle cases where HEAD_DIMS > 64 + *gptr += 64 * upcast_size(); } *smem_offset -= KTraits::CTA_TILE_KV * UPCAST_STRIDE; } @@ -2475,6 +2447,10 @@ SinglePrefillWithKVCacheDevice(const Params params, printf("k_stride_n : %d\n", k_stride_n); printf("k_stride_h : %d\n", k_stride_h); printf("kv_head_idx : %d\n", kv_head_idx); + printf("num_qo_heads : %d\n", num_qo_heads); + printf("num_kv_heads : %d\n", num_kv_heads); + printf("k_stride_n : %d\n", k_stride_n); + DTypeKV *k_ptr_tmp = k + (chunk_start + warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL) * @@ -2493,28 +2469,27 @@ 
SinglePrefillWithKVCacheDevice(const Params params, } // Test K loads - // if (global_idx == 0) { - // printf("\n DEBUG K ORIGINAL (HIP):\n"); - // uint32_t k_smem_offset_r_debug; - // for (auto i = 0; i < 128; ++i) { - // for (auto j = 0; j < 16; ++j) { - // k_smem_offset_r_debug = - // k_smem.template - // get_permuted_offset(i, - // j); - // uint32_t a_frag[KTraits::INT32_ELEMS_PER_THREAD]; - // k_smem.load_fragment(k_smem_offset_r_debug, a_frag); - // auto frag_T = reinterpret_cast<__half *>(a_frag); - // for (auto i = 0ul; i < 4; ++i) { - // printf("%f ", (float)(*(frag_T + i))); - // } - // } - // printf("\n"); - // k_smem.template advance_offset_by_row<16, - // KTraits::UPCAST_STRIDE_K>( - // k_smem_offset_r_debug); - // } - // } + if (global_idx == 0) { + printf("\n DEBUG K LDS ORIGINAL (HIP):\n"); + uint32_t k_smem_offset_r_debug; + for (auto i = 0; i < 128; ++i) { + for (auto j = 0; j < 16; ++j) { + k_smem_offset_r_debug = + k_smem.template get_permuted_offset(i, + j); + uint32_t a_frag[KTraits::INT32_ELEMS_PER_THREAD]; + k_smem.load_fragment(k_smem_offset_r_debug, a_frag); + auto frag_T = reinterpret_cast<__half *>(a_frag); + for (auto i = 0ul; i < 4; ++i) { + printf("%f ", (float)(*(frag_T + i))); + } + } + printf("\n"); + k_smem.template advance_offset_by_row<16, + KTraits::UPCAST_STRIDE_K>( + k_smem_offset_r_debug); + } + } // if (global_idx == 0) { // printf("DEBUG Q ORIGINAL (HIP):\n"); diff --git a/libflashinfer/tests/hip/test_single_prefill.cpp b/libflashinfer/tests/hip/test_single_prefill.cpp index aa171dc21c..a68e67fbf5 100644 --- a/libflashinfer/tests/hip/test_single_prefill.cpp +++ b/libflashinfer/tests/hip/test_single_prefill.cpp @@ -557,8 +557,8 @@ int main(int argc, char **argv) using DTypeIn = __half; using DTypeO = __half; bool use_fp16_qk_reduction = false; - size_t qo_len = 64; - size_t kv_len = 64; + size_t qo_len = 128; + size_t kv_len = 128; size_t num_heads = 1; size_t head_dim = 64; bool causal = false; From 
a6633f04c43998abe261b33b482d8d01cac36986 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Wed, 3 Sep 2025 06:00:12 -0400 Subject: [PATCH 064/109] Fixed produce_kv_impl_cdna3_ --- .../flashinfer/attention/generic/prefill.cuh | 53 +++++++------------ 1 file changed, 20 insertions(+), 33 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index a332d9dc76..e710e036d1 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -471,8 +471,7 @@ __device__ __forceinline__ void produce_kv_impl_cuda_( } else { uint32_t kv_idx = kv_idx_base + warp_idx * 8 + lane_idx / 4; - // NOTE: NUM_MMA_KV * 2 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 2 / - // num_warps + // NOTE: NUM_MMA_KV * 2 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 2 / num_warps static_assert(NUM_MMA_KV * 2 % NUM_WARPS_Q == 0); #pragma unroll for (uint32_t i = 0; i < NUM_MMA_KV * 2 / NUM_WARPS_Q; ++i) { @@ -512,46 +511,31 @@ __device__ __forceinline__ void produce_kv_impl_cdna3_( produce_v ? 
KTraits::UPCAST_STRIDE_V : KTraits::UPCAST_STRIDE_K; constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH; constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; - constexpr uint32_t FEATURES_PER_THREAD_ROW = - HALF_ELEMS_PER_THREAD * KV_THR_LAYOUT_COL; - - uint32_t row = lane_idx / KV_THR_LAYOUT_COL; - uint32_t kv_idx = kv_idx_base + warp_idx * KV_THR_LAYOUT_ROW + row; // NOTE: NUM_MMA_KV*4/NUM_WARPS_Q = NUM_WARPS_KV*NUM_MMA_KV*4/num_warps static_assert(NUM_MMA_KV * 4 % NUM_WARPS_Q == 0); + uint32_t kv_idx = kv_idx_base + warp_idx * 4 + lane_idx / KV_THR_LAYOUT_COL; + // NOTE: NUM_MMA_KV * 4 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 4 / num_warps + static_assert(NUM_MMA_KV * 4 % NUM_WARPS_Q == 0); #pragma unroll for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) { - for (uint32_t j = 0; j < KTraits::HEAD_DIM_QK / FEATURES_PER_THREAD_ROW; - ++j) - { #pragma unroll - for (uint32_t k = 0; k < FEATURES_PER_THREAD_ROW; ++k) { - smem.template load_vector_async(*smem_offset, *gptr, - kv_idx < kv_len); - *smem_offset = - smem.template advance_offset_by_row<4, UPCAST_STRIDE>( - *smem_offset); - // FIXME: The below logic will not handle cases where HEAD_DIMS - // > 64 - *gptr += 64 * upcast_size(); - } + for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(DTypeKV)); ++j) { + smem.template load_vector_async(*smem_offset, *gptr, + kv_idx < kv_len); *smem_offset = - smem.template advance_offset_by_row<64, UPCAST_STRIDE>( - *smem_offset); + smem.template advance_offset_by_column<16>(*smem_offset, j); + *gptr += 16 * upcast_size(); } - - // Final advance to next MMA tile - kv_idx += NUM_WARPS * KV_THR_LAYOUT_ROW; - *smem_offset = smem.template advance_offset_by_row<4, UPCAST_STRIDE>( - *smem_offset) - - 64; - // *gptr += NUM_WARPS * KV_THR_LAYOUT_ROW * stride_n - - // FEATURE_CHUNKS_PER_THREAD_GROUP * KV_THR_LAYOUT_COL * - // upcast_size(); - // FIXME: The below logic will not handle cases where HEAD_DIMS > 64 - *gptr += 64 * upcast_size(); 
+ kv_idx += NUM_WARPS * 4; + *smem_offset = smem.template advance_offset_by_row( + *smem_offset) - + sizeof(DTypeKV) * NUM_MMA_D; + *gptr += NUM_WARPS * 4 * stride_n - + sizeof(DTypeKV) * NUM_MMA_D * + upcast_size(); } *smem_offset -= KTraits::CTA_TILE_KV * UPCAST_STRIDE; } @@ -2450,6 +2434,9 @@ SinglePrefillWithKVCacheDevice(const Params params, printf("num_qo_heads : %d\n", num_qo_heads); printf("num_kv_heads : %d\n", num_kv_heads); printf("k_stride_n : %d\n", k_stride_n); + printf("KTraits::NUM_MMA_D_QK : %d\n", KTraits::NUM_MMA_D_QK); + printf("NUM_MMA_KV : %d\n", NUM_MMA_KV); + printf("NUM_MMA_Q : %d\n", NUM_MMA_Q); DTypeKV *k_ptr_tmp = k + (chunk_start + warp_idx * KV_THR_LAYOUT_ROW + From 9ef0a4b5820564ca008c40c66a6582c696fb4436 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Wed, 3 Sep 2025 14:56:53 -0400 Subject: [PATCH 065/109] Off-by-one error --- .../flashinfer/attention/generic/prefill.cuh | 186 +++--------------- libflashinfer/utils/cpu_reference_hip.h | 2 +- 2 files changed, 30 insertions(+), 158 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index e710e036d1..b8551c149c 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -471,7 +471,8 @@ __device__ __forceinline__ void produce_kv_impl_cuda_( } else { uint32_t kv_idx = kv_idx_base + warp_idx * 8 + lane_idx / 4; - // NOTE: NUM_MMA_KV * 2 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 2 / num_warps + // NOTE: NUM_MMA_KV * 2 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 2 / + // num_warps static_assert(NUM_MMA_KV * 2 % NUM_WARPS_Q == 0); #pragma unroll for (uint32_t i = 0; i < NUM_MMA_KV * 2 / NUM_WARPS_Q; ++i) { @@ -516,26 +517,27 @@ __device__ __forceinline__ void produce_kv_impl_cdna3_( static_assert(NUM_MMA_KV * 4 % NUM_WARPS_Q == 0); uint32_t kv_idx = kv_idx_base + warp_idx * 4 + lane_idx / 
KV_THR_LAYOUT_COL; - // NOTE: NUM_MMA_KV * 4 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 4 / num_warps + // NOTE: NUM_MMA_KV * 4 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 4 / + // num_warps static_assert(NUM_MMA_KV * 4 % NUM_WARPS_Q == 0); #pragma unroll for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) { #pragma unroll for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(DTypeKV)); ++j) { smem.template load_vector_async(*smem_offset, *gptr, - kv_idx < kv_len); + kv_idx < kv_len); *smem_offset = smem.template advance_offset_by_column<16>(*smem_offset, j); *gptr += 16 * upcast_size(); } kv_idx += NUM_WARPS * 4; - *smem_offset = smem.template advance_offset_by_row( - *smem_offset) - - sizeof(DTypeKV) * NUM_MMA_D; + *smem_offset = + smem.template advance_offset_by_row( + *smem_offset) - + sizeof(DTypeKV) * NUM_MMA_D; *gptr += NUM_WARPS * 4 * stride_n - - sizeof(DTypeKV) * NUM_MMA_D * - upcast_size(); + sizeof(DTypeKV) * NUM_MMA_D * + upcast_size(); } *smem_offset -= KTraits::CTA_TILE_KV * UPCAST_STRIDE; } @@ -1059,14 +1061,6 @@ __device__ __forceinline__ void compute_qk( *q_smem_offset_r = q_smem->template advance_offset_by_row<16, UPCAST_STRIDE_Q>( *q_smem_offset_r); - - // __half* a_frag_half = reinterpret_cast<__half*>(a_frag[mma_q]); - // debug_printer(0, "a_frag_half_0: ", - // float(a_frag_half[0])); debug_printer(0, "a_frag_half_1: - // ", float(a_frag_half[1])); debug_printer(0, - // "a_frag_half_2: ", float(a_frag_half[2])); - // debug_printer(0, "a_frag_half_3: ", - // float(a_frag_half[3])); } *q_smem_offset_r = @@ -1105,14 +1099,6 @@ __device__ __forceinline__ void compute_qk( #endif } - // __half* b_frag_half = reinterpret_cast<__half*>(b_frag); - // debug_printer(0, "b_frag_half_0: ", - // float(b_frag_half[0])); debug_printer(0, "b_frag_half_1: - // ", float(b_frag_half[1])); debug_printer(0, - // "b_frag_half_2: ", float(b_frag_half[2])); - // debug_printer(0, "b_frag_half_3: ", - // float(b_frag_half[3])); - *k_smem_offset_r = k_smem->template 
advance_offset_by_row<16, UPCAST_STRIDE_K>( *k_smem_offset_r); @@ -2348,13 +2334,13 @@ SinglePrefillWithKVCacheDevice(const Params params, (threadIdx.z * blockDim.y * blockDim.x + threadIdx.y * blockDim.x + threadIdx.x); - // if (global_idx == 0) { - // printf("partition_kv : %d\n", partition_kv); - // printf("kv_len : %d\n", kv_len); - // printf("max_chunk_size : %d\n", max_chunk_size); - // printf("chunk_end : %d\n", chunk_end); - // printf("chunk_start : %d\n", chunk_start); - // } + if (global_idx == 0) { + printf("partition_kv : %d\n", partition_kv); + printf("kv_len : %d\n", kv_len); + printf("max_chunk_size : %d\n", max_chunk_size); + printf("chunk_end : %d\n", chunk_end); + printf("chunk_start : %d\n", chunk_start); + } // // Test Q // if (global_idx == 0) { @@ -2379,53 +2365,8 @@ SinglePrefillWithKVCacheDevice(const Params params, // } // } - // for (auto mma_q = 0ul; mma_q < 4; ++mma_q) { - // uint32_t a_frag[KTraits::INT32_ELEMS_PER_THREAD]; - // qo_smem.load_fragment(q_smem_offset_r, a_frag); - // if (global_idx == 0) { - // auto frag_T = reinterpret_cast<__half *>(a_frag); - // printf("DEBUG: Q Frag in permuted_smem for mma_q %lu \n", - // mma_q); for (auto i = 0ul; i < 4; ++i) { - // printf("%f ", (float)(*(frag_T + i))); - // } - // printf("\n"); - // } - - // q_smem_offset_r = qo_smem.template advance_offset_by_column<4>( - // q_smem_offset_r, 0); - // } - - // uint32_t a_frag[KTraits::INT32_ELEMS_PER_THREAD]; - // qo_smem.load_fragment(q_smem_offset_r, a_frag); - // if (global_idx == 0) { - // auto frag_T = reinterpret_cast<__half *>(a_frag); - // printf("DEBUG: Q Frag \n"); - // for (auto i = 0ul; i < 4; ++i) { - // printf("%f ", (float)(*(frag_T + i))); - // } - // printf("\n"); - // } - - // memory::wait_group<0>(); - // block.sync(); - // q_smem_inplace_apply_rotary(qo_packed_idx_base, qo_len, - // kv_len, - // group_size, &qo_smem, - // &q_smem_offset_r, rope_freq, - // tid); - // block.sync(); - - // qo_smem.load_fragment(q_smem_offset_r, 
a_frag); - // if (global_idx == 0) { - // auto frag_T = reinterpret_cast<__half *>(a_frag); - // printf("DEBUG: LLAMA Rope transformed Q Frag \n"); - // for (auto i = 0ul; i < 4; ++i) { - // printf("%f ", (float)(*(frag_T + i))); - // } - // printf("\n"); - // } - - // Test K Global values + // Test K Global values: + // Prints the (NUM_MMA_KV*16) x (NUM_MMA_D*16) matrix from global mem. if (global_idx == 0) { printf("\n DEBUG K Global (HIP):\n"); printf("k_stride_n : %d\n", k_stride_n); @@ -2445,8 +2386,8 @@ SinglePrefillWithKVCacheDevice(const Params params, kv_head_idx * k_stride_h + (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(); - for (auto i = 0; i < 128; ++i) { - for (auto j = 0; j < 64; ++j) { + for (auto i = 0; i < NUM_MMA_KV * 16; ++i) { + for (auto j = 0; j < NUM_MMA_D_QK * 16; ++j) { auto fKval = (float)*(k_ptr_tmp); k_ptr_tmp += 1; printf("%f ", fKval); @@ -2455,12 +2396,16 @@ SinglePrefillWithKVCacheDevice(const Params params, } } - // Test K loads + // Test K LDS values: + // Prints the (NUM_MMA_KV*16) x (NUM_MMA_D*16) matrix from shared mem. + // Note that LDS is loaded collaboratively by all warps and not each + // warp accesses the whole K matrix loaded into LDS. 
Each warp will + // only access 1/4 of the K values loaded into LDS> if (global_idx == 0) { printf("\n DEBUG K LDS ORIGINAL (HIP):\n"); uint32_t k_smem_offset_r_debug; - for (auto i = 0; i < 128; ++i) { - for (auto j = 0; j < 16; ++j) { + for (auto i = 0; i < NUM_MMA_KV * 16; ++i) { + for (auto j = 0; j < NUM_MMA_D_QK * 4; ++j) { k_smem_offset_r_debug = k_smem.template get_permuted_offset(i, j); @@ -2477,79 +2422,6 @@ SinglePrefillWithKVCacheDevice(const Params params, k_smem_offset_r_debug); } } - - // if (global_idx == 0) { - // printf("DEBUG Q ORIGINAL (HIP):\n"); - - // for (uint32_t seq_idx = 0; seq_idx < 16; ++seq_idx) { - // printf("Q[%u] original: ", seq_idx); - - // // Load all feature groups for this sequence - // for (uint32_t feat_group = 0; feat_group < NUM_MMA_D_QK; - // ++feat_group) { - // uint32_t feat_offset = qo_smem.template - // get_permuted_offset( - // seq_idx, feat_group * HALF_ELEMS_PER_THREAD); - - // uint32_t q_frag[KTraits::INT32_ELEMS_PER_THREAD]; - // qo_smem.load_fragment(feat_offset, q_frag); - // auto frag_T = reinterpret_cast<__half *>(q_frag); - - // // Print 4 features from this group - // for (auto feat = 0ul; feat < HALF_ELEMS_PER_THREAD; - // ++feat) { - // printf("%f ", (float)(*(frag_T + feat))); - // } - // } - // printf("\n"); - // } - // } - - // memory::wait_group<0>(); - // block.sync(); - // q_smem_inplace_apply_rotary( - // qo_packed_idx_base, qo_len, kv_len, group_size, &qo_smem, - // &q_smem_offset_r, rope_freq, tid); - // block.sync(); - - // // Debug: Print Q fragments after RoPE - // if (global_idx == 0) { - // printf("DEBUG Q LLAMA ROPE (HIP):\n"); - - // // Reset q_smem_offset_r to start - // uint32_t q_smem_offset_r_debug = - // qo_smem.template get_permuted_offset( - // get_warp_idx_q(tid.y) * NUM_MMA_Q * 16 + - // lane_idx % 16, lane_idx / 16); - - // for (uint32_t seq_idx = 0; seq_idx < 16; ++seq_idx) { - // // Calculate offset for this sequence - // uint32_t seq_offset = qo_smem.template - // 
get_permuted_offset( - // seq_idx, 0); - - // printf("Q[%u] after RoPE: ", seq_idx); - - // // Load all feature groups for this sequence - // for (uint32_t feat_group = 0; feat_group < NUM_MMA_D_QK; - // ++feat_group) { - // uint32_t feat_offset = qo_smem.template - // get_permuted_offset( - // seq_idx, feat_group * HALF_ELEMS_PER_THREAD); - - // uint32_t q_frag[KTraits::INT32_ELEMS_PER_THREAD]; - // qo_smem.load_fragment(feat_offset, q_frag); - // auto frag_T = reinterpret_cast<__half *>(q_frag); - - // // Print 4 features from this group - // for (auto feat = 0ul; feat < HALF_ELEMS_PER_THREAD; - // ++feat) { - // printf("%f ", (float)(*(frag_T + feat))); - // } - // } - // printf("\n"); - // } - // } #endif #pragma unroll 1 diff --git a/libflashinfer/utils/cpu_reference_hip.h b/libflashinfer/utils/cpu_reference_hip.h index de8da43977..9abc3a04cc 100644 --- a/libflashinfer/utils/cpu_reference_hip.h +++ b/libflashinfer/utils/cpu_reference_hip.h @@ -227,7 +227,7 @@ single_mha(const std::vector &q, std::cout << std::endl; std::cout << "DEBUG: Original K (CPU): " << '\n'; - for (auto i = 0ul; i < 128; ++i) { + for (auto i = 0ul; i < 64; ++i) { for (int j = 0ul; j < 64; ++j) { std::cout << (float)k[info.get_kv_elem_offset(i, 0, j)] << " "; } From b9b34b2c0fd0aef408774086141234b249af33bd Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Fri, 5 Sep 2025 17:43:31 -0400 Subject: [PATCH 066/109] Changes to logits functions --- .../flashinfer/attention/generic/prefill.cuh | 44 ++++++++++++------- 1 file changed, 27 insertions(+), 17 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index b8551c149c..c60219250d 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -517,15 +517,31 @@ __device__ __forceinline__ void produce_kv_impl_cdna3_( static_assert(NUM_MMA_KV * 4 % NUM_WARPS_Q 
== 0); uint32_t kv_idx = kv_idx_base + warp_idx * 4 + lane_idx / KV_THR_LAYOUT_COL; - // NOTE: NUM_MMA_KV * 4 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 4 / - // num_warps + // NOTE: NUM_MMA_KV * 4 / NUM_WARPS_Q = NUM_WARPS_KV*NUM_MMA_KV*4/num_warps static_assert(NUM_MMA_KV * 4 % NUM_WARPS_Q == 0); #pragma unroll for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) { #pragma unroll for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(DTypeKV)); ++j) { +#if Debug1 + if (warp_idx == 0 && lane_idx == 0) { + printf("Gptr deref : %f\n", (float)(*(*gptr))); + } +#endif smem.template load_vector_async(*smem_offset, *gptr, kv_idx < kv_len); +#if Debug1 + if (warp_idx == 0 && lane_idx == 0) { + uint32_t tmp[KTraits::INT32_ELEMS_PER_THREAD]; + smem.load_fragment(*smem_offset, tmp); + auto frag_T = reinterpret_cast<__half *>(tmp); + printf("==============\n"); + for (auto i = 0ul; i < 4; ++i) { + printf("Shared %f \n", (float)(*(frag_T + i))); + } + printf("==============\n"); + } +#endif *smem_offset = smem.template advance_offset_by_column<16>(*smem_offset, j); *gptr += 16 * upcast_size(); @@ -1245,9 +1261,8 @@ __device__ __forceinline__ void logits_transform( const uint32_t q_idx = q[mma_q][reg_id % NAPTR]; const uint32_t qo_head_idx = kv_head_idx * group_size + r[mma_q][reg_id % NAPTR]; - const uint32_t kv_idx = kv_idx_base + mma_kv * 16 + - 2 * (lane_idx % TPR) + - 8 * (reg_id / 2) + reg_id % 2; + const uint32_t kv_idx = + kv_idx_base + mma_kv * 16 + (lane_idx % TPR); #else const uint32_t q_idx = q[mma_q][(reg_id % 4) / 2], kv_idx = kv_idx_base + mma_kv * 16 + @@ -1335,10 +1350,9 @@ logits_mask(const Params ¶ms, ++reg_id) { #if defined(PLATFORM_HIP_DEVICE) - const uint32_t q_idx = q[mma_q][(reg_id % NAPTR)], - kv_idx = kv_idx_base + mma_kv * 16 + - 2 * (lane_idx % TPR) + - 8 * (reg_id / 2) + reg_id % 2; + const uint32_t q_idx = q[mma_q][(reg_id % NAPTR)]; + const uint32_t kv_idx = + kv_idx_base + mma_kv * 16 + (lane_idx % TPR); const uint32_t qo_head_idx = 
kv_head_idx * group_size + r[mma_q][(reg_id % NAPTR)]; #else @@ -2328,13 +2342,9 @@ SinglePrefillWithKVCacheDevice(const Params params, memory::commit_group(); #if Debug - int global_idx = (blockIdx.z * gridDim.y * gridDim.x + - blockIdx.y * gridDim.x + blockIdx.x) * - (blockDim.z * blockDim.y * blockDim.x) + - (threadIdx.z * blockDim.y * blockDim.x + - threadIdx.y * blockDim.x + threadIdx.x); + __syncthreads(); - if (global_idx == 0) { + if (warp_idx == 0 && lane_idx == 0) { printf("partition_kv : %d\n", partition_kv); printf("kv_len : %d\n", kv_len); printf("max_chunk_size : %d\n", max_chunk_size); @@ -2367,7 +2377,7 @@ SinglePrefillWithKVCacheDevice(const Params params, // Test K Global values: // Prints the (NUM_MMA_KV*16) x (NUM_MMA_D*16) matrix from global mem. - if (global_idx == 0) { + if (warp_idx == 0 && lane_idx == 0) { printf("\n DEBUG K Global (HIP):\n"); printf("k_stride_n : %d\n", k_stride_n); printf("k_stride_h : %d\n", k_stride_h); @@ -2401,7 +2411,7 @@ SinglePrefillWithKVCacheDevice(const Params params, // Note that LDS is loaded collaboratively by all warps and not each // warp accesses the whole K matrix loaded into LDS. 
Each warp will // only access 1/4 of the K values loaded into LDS> - if (global_idx == 0) { + if (warp_idx == 0 && lane_idx == 0) { printf("\n DEBUG K LDS ORIGINAL (HIP):\n"); uint32_t k_smem_offset_r_debug; for (auto i = 0; i < NUM_MMA_KV * 16; ++i) { From 37bfad79a90f2e7226c052663276e96033958079 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Sat, 6 Sep 2025 12:37:19 -0400 Subject: [PATCH 067/109] Fix produce_kv --- .../flashinfer/attention/generic/prefill.cuh | 90 +++++-------------- 1 file changed, 22 insertions(+), 68 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index c60219250d..6c2d6d299d 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -84,7 +84,7 @@ struct SharedStorageQKVO alignas(16) DTypeQ q_smem[CTA_TILE_Q * HEAD_DIM_QK]; alignas(16) DTypeKV k_smem[CTA_TILE_KV * HEAD_DIM_QK]; #if Debug - alignas(16) DTypeKV qk_scratch[CTA_TILE_KV * HEAD_DIM_QK]; + alignas(16) DTypeKV qk_scratch[CTA_TILE_Q * CTA_TILE_KV]; #endif alignas(16) DTypeKV v_smem[CTA_TILE_KV * HEAD_DIM_VO]; }; @@ -523,25 +523,8 @@ __device__ __forceinline__ void produce_kv_impl_cdna3_( for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) { #pragma unroll for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(DTypeKV)); ++j) { -#if Debug1 - if (warp_idx == 0 && lane_idx == 0) { - printf("Gptr deref : %f\n", (float)(*(*gptr))); - } -#endif smem.template load_vector_async(*smem_offset, *gptr, kv_idx < kv_len); -#if Debug1 - if (warp_idx == 0 && lane_idx == 0) { - uint32_t tmp[KTraits::INT32_ELEMS_PER_THREAD]; - smem.load_fragment(*smem_offset, tmp); - auto frag_T = reinterpret_cast<__half *>(tmp); - printf("==============\n"); - for (auto i = 0ul; i < 4; ++i) { - printf("Shared %f \n", (float)(*(frag_T + i))); - } - printf("==============\n"); - } -#endif *smem_offset = smem.template 
advance_offset_by_column<16>(*smem_offset, j); *gptr += 16 * upcast_size(); @@ -550,9 +533,9 @@ __device__ __forceinline__ void produce_kv_impl_cdna3_( *smem_offset = smem.template advance_offset_by_row( *smem_offset) - - sizeof(DTypeKV) * NUM_MMA_D; + (sizeof(DTypeKV) * NUM_MMA_D * 2); *gptr += NUM_WARPS * 4 * stride_n - - sizeof(DTypeKV) * NUM_MMA_D * + sizeof(DTypeKV) * NUM_MMA_D * 2 * upcast_size(); } *smem_offset -= KTraits::CTA_TILE_KV * UPCAST_STRIDE; @@ -623,8 +606,7 @@ __device__ __forceinline__ void page_produce_kv( lane_idx = tid.x; if constexpr (KTraits::SWIZZLE_MODE_KV == SwizzleMode::k128B) { uint32_t kv_idx = kv_idx_base + warp_idx * 4 + lane_idx / 8; - // NOTE: NUM_MMA_KV * 4 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 4 / - // num_warps + // NOTE: NUM_MMA_KV * 4/NUM_WARPS_Q=NUM_WARPS_KV*NUM_MMA_KV*4/num_warps static_assert(NUM_MMA_KV * 4 % NUM_WARPS_Q == 0); #pragma unroll for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) { @@ -1134,47 +1116,6 @@ __device__ __forceinline__ void compute_qk( typename KTraits::DTypeQ>(s_frag[mma_q][mma_kv], a_frag[mma_q], b_frag); } - -#if Debug - if (mma_q == 0) { - __half *a_frag_half = - reinterpret_cast<__half *>(a_frag[mma_q]); - debug_printer( - 0, "a_frag_half_0: ", float(a_frag_half[0])); - debug_printer( - 0, "a_frag_half_1: ", float(a_frag_half[1])); - debug_printer( - 0, "a_frag_half_2: ", float(a_frag_half[2])); - debug_printer( - 0, "a_frag_half_3: ", float(a_frag_half[3])); - - __syncthreads(); - - // __half* b_frag_half = - // reinterpret_cast<__half*>(b_frag); - // debug_printer(0, "b_frag_half_0: ", - // float(b_frag_half[0])); debug_printer(0, - // "b_frag_half_1: ", float(b_frag_half[1])); - // debug_printer(0, "b_frag_half_2: ", - // float(b_frag_half[2])); debug_printer(0, - // "b_frag_half_3: ", float(b_frag_half[3])); - - // __syncthreads(); - - __half *s_frag_half = - reinterpret_cast<__half *>(s_frag[mma_q][mma_kv]); - debug_printer( - 0, "s_frag_half: ", float(s_frag_half[0])); - 
debug_printer( - 0, "s_frag_half: ", float(s_frag_half[1])); - debug_printer( - 0, "s_frag_half: ", float(s_frag_half[2])); - debug_printer( - 0, "s_frag_half: ", float(s_frag_half[3])); - - __syncthreads(); - } -#endif } else if (std::is_same_v) { @@ -2318,7 +2259,6 @@ SinglePrefillWithKVCacheDevice(const Params params, 8 * (lane_idx / 16) + lane_idx % 8, (lane_idx % 16) / 8); #endif - uint32_t v_smem_offset_r = v_smem.template get_permuted_offset( get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + @@ -2342,8 +2282,6 @@ SinglePrefillWithKVCacheDevice(const Params params, memory::commit_group(); #if Debug - __syncthreads(); - if (warp_idx == 0 && lane_idx == 0) { printf("partition_kv : %d\n", partition_kv); printf("kv_len : %d\n", kv_len); @@ -2450,8 +2388,24 @@ SinglePrefillWithKVCacheDevice(const Params params, } // compute attention score - // compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, - // &k_smem_offset_r, s_frag); + compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, + &k_smem_offset_r, s_frag); +#if DEBUG + + smem_t scratch( + smem_storage.qk_scratch); + // copy sfrag into scratch + for (auto mma_q = 0ul; mma_q < NUM_MMA_Q; ++mma_q) { + for (auto mma_kv = 0ul; mma_kv < NUM_MMA_KV; ++mma_kv) { + for (auto reg_id = 0ul; reg_id < HALF_ELEMS_PER_THREAD; + ++reg_id) + { + auto tmp = s_frag[mma_q][mma_kv]; + // store into scratch + } + } + } +#endif logits_transform( params, variant, /*batch_idx=*/0, qo_packed_idx_base, From 8abae3d8e84da538778956fb794aa244698801fd Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Mon, 8 Sep 2025 16:52:44 -0400 Subject: [PATCH 068/109] Test transforms debug --- .../flashinfer/attention/generic/prefill.cuh | 141 +++++++++++++----- .../tests/hip/test_single_prefill.cpp | 13 +- libflashinfer/utils/cpu_reference_hip.h | 125 +++++++++------- .../utils/flashinfer_prefill_ops.hip.h | 4 +- 4 files changed, 185 insertions(+), 98 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh 
b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index 6c2d6d299d..4a92a96c26 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -1231,6 +1231,15 @@ __device__ __forceinline__ void logits_transform( logitsTransformed = variant.LogitsTransform(params, logits, batch_idx, q_idx, kv_idx, qo_head_idx, kv_head_idx); +#if Debug + const uint32_t lane_idx = tid.x, + warp_idx = get_warp_idx(tid.y, tid.z); + + if (warp_idx == 0 && lane_idx == 0) { + printf("logits : %f logitsTransformed: %f\n", float(logits), + float(logitsTransformed)); + } +#endif #ifdef FP16_QK_REDUCTION_SUPPORTED if constexpr (std::is_same::value) { s_frag[mma_q][mma_kv][reg_id] = std::bit_cast( @@ -2281,7 +2290,7 @@ SinglePrefillWithKVCacheDevice(const Params params, v_smem, &v_smem_offset_w, &v_ptr, v_stride_n, 0, chunk_size, tid); memory::commit_group(); -#if Debug +#if Debug1 if (warp_idx == 0 && lane_idx == 0) { printf("partition_kv : %d\n", partition_kv); printf("kv_len : %d\n", kv_len); @@ -2290,28 +2299,27 @@ SinglePrefillWithKVCacheDevice(const Params params, printf("chunk_start : %d\n", chunk_start); } - // // Test Q - // if (global_idx == 0) { - // printf("\n DEBUG Q ORIGINAL (HIP):\n"); - // uint32_t q_smem_offset_r_debug; - // for (auto i = 0; i < 16; ++i) { - // for (auto j = 0; j < 16; ++j) { - // q_smem_offset_r_debug = - // qo_smem.template - // get_permuted_offset( - // i, j); - // uint32_t a_frag[KTraits::INT32_ELEMS_PER_THREAD]; - // qo_smem.load_fragment(q_smem_offset_r_debug, a_frag); - // auto frag_T = reinterpret_cast<__half *>(a_frag); - // for (auto i = 0ul; i < 4; ++i) { - // printf("%f ", (float)(*(frag_T + i))); - // } - // } - // printf("\n"); - // qo_smem.template advance_offset_by_row< - // 16, KTraits::UPCAST_STRIDE_Q>(q_smem_offset_r_debug); - // } - // } + // Test Q + if (warp_idx == 0 && lane_idx == 0) { + printf("\n DEBUG Q ORIGINAL (HIP):\n"); + uint32_t 
q_smem_offset_r_debug; + for (auto i = 0; i < NUM_MMA_Q * 16; ++i) { + for (auto j = 0; j < NUM_MMA_D_QK * 4; ++j) { + q_smem_offset_r_debug = + qo_smem.template get_permuted_offset( + i, j); + uint32_t a_frag[KTraits::INT32_ELEMS_PER_THREAD]; + qo_smem.load_fragment(q_smem_offset_r_debug, a_frag); + auto frag_T = reinterpret_cast<__half *>(a_frag); + for (auto i = 0ul; i < 4; ++i) { + printf("%f ", (float)(*(frag_T + i))); + } + } + printf("\n"); + qo_smem.template advance_offset_by_row< + 16, KTraits::UPCAST_STRIDE_Q>(q_smem_offset_r_debug); + } + } // Test K Global values: // Prints the (NUM_MMA_KV*16) x (NUM_MMA_D*16) matrix from global mem. @@ -2373,8 +2381,8 @@ SinglePrefillWithKVCacheDevice(const Params params, #endif #pragma unroll 1 - // for (uint32_t iter = 0; iter < num_iterations; ++iter) { - for (uint32_t iter = 0; iter < 1; ++iter) { + for (uint32_t iter = 0; iter < num_iterations; ++iter) { + // for (uint32_t iter = 0; iter < 1; ++iter) { memory::wait_group<1>(); block.sync(); @@ -2390,21 +2398,56 @@ SinglePrefillWithKVCacheDevice(const Params params, // compute attention score compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); -#if DEBUG - +#if Debug smem_t scratch( smem_storage.qk_scratch); // copy sfrag into scratch - for (auto mma_q = 0ul; mma_q < NUM_MMA_Q; ++mma_q) { - for (auto mma_kv = 0ul; mma_kv < NUM_MMA_KV; ++mma_kv) { - for (auto reg_id = 0ul; reg_id < HALF_ELEMS_PER_THREAD; - ++reg_id) - { + if (warp_idx == 0) { + for (auto mma_q = 0ul; mma_q < NUM_MMA_Q; ++mma_q) { + for (auto mma_kv = 0ul; mma_kv < NUM_MMA_KV; ++mma_kv) { auto tmp = s_frag[mma_q][mma_kv]; - // store into scratch + auto col = lane_idx % 16; + auto row = lane_idx / 16; + auto scratch_offset = + scratch + .template get_permuted_offset( + row, col); + scratch.template store_fragment(scratch_offset, tmp); } } } + + if (warp_idx == 0 && lane_idx == 0) { + auto _hScratch = reinterpret_cast<__half *>(scratch.base); + printf("compute_qk results (Warp 
0): \n"); + + for (auto k = 0ul; k < 4; ++k) { + for (auto i = 0ul; i < 16; ++i) { + for (auto j = 0ul; j < 16; ++j) { + printf("%f ", float(_hScratch[k * 64 + i + j * 4])); + } + printf("\n"); + } + } + } + + // if (warp_idx == 0 && lane_idx == 0) { + // printf("s_frag results: \n"); + // for (auto mma_q = 0ul; mma_q < NUM_MMA_Q; ++mma_q) { + // for (auto mma_kv = 0ul; mma_kv < NUM_MMA_KV; ++mma_kv) { + // for (auto reg_id = 0ul; reg_id < + // HALF_ELEMS_PER_THREAD; + // ++reg_id) + // { + // auto tmp = s_frag[mma_q][mma_kv][reg_id]; + // printf("s_frag[%lu][%lu][%lu] : %f ", mma_q, + // mma_kv, + // reg_id, float(tmp)); + // } + // printf("\n"); + // } + // } + // } #endif logits_transform( @@ -2426,7 +2469,37 @@ SinglePrefillWithKVCacheDevice(const Params params, qo_len, kv_len, chunk_end, group_size, s_frag, tid, kv_head_idx); } - +#if Debug + // TODO: + // smem_t scratch( + // smem_storage.qk_scratch); + // // copy sfrag into scratch + // for (auto mma_q = 0ul; mma_q < NUM_MMA_Q; ++mma_q) { + // for (auto mma_kv = 0ul; mma_kv < NUM_MMA_KV; ++mma_kv) { + // for (auto reg_id = 0ul; reg_id < HALF_ELEMS_PER_THREAD; + // ++reg_id) + // { + // auto tmp = s_frag[mma_q][mma_kv]; + // // store into scratch + // } + // } + // } + if (warp_idx == 0 && lane_idx == 0) { + printf("s_frag after logits transform and masking : \n"); + for (auto mma_q = 0ul; mma_q < NUM_MMA_Q; ++mma_q) { + for (auto mma_kv = 0ul; mma_kv < NUM_MMA_KV; ++mma_kv) { + for (auto reg_id = 0ul; reg_id < HALF_ELEMS_PER_THREAD; + ++reg_id) + { + auto tmp = s_frag[mma_q][mma_kv][reg_id]; + printf("s_frag[%lu][%lu][%lu] : %f ", mma_q, mma_kv, + reg_id, float(tmp)); + } + printf("\n"); + } + } + } +#endif // compute m,d states in online softmax update_mdo_states(variant, s_frag, o_frag, m, d); diff --git a/libflashinfer/tests/hip/test_single_prefill.cpp b/libflashinfer/tests/hip/test_single_prefill.cpp index a68e67fbf5..b471bcd752 100644 --- a/libflashinfer/tests/hip/test_single_prefill.cpp +++ 
b/libflashinfer/tests/hip/test_single_prefill.cpp @@ -242,12 +242,13 @@ void _TestSinglePrefillKernelCorrectness(size_t qo_len, hipMemcpyDeviceToHost)); // Print the first 10 elements of the output vector for debugging - // std::cout << "Output vector (first 10 elements):"; - // std::cout << "[" << std::endl; - // for (int i = 0; i < 10; ++i) { - // std::cout << fi::con::explicit_casting(o_h[i]) << " "; - // } - // std::cout << "]" << std::endl; + // std::cout << "Output vector (first 10 elements):"; + // std::cout << "[" << std::endl; + // for (int i = 0; i < 10; ++i) { + // std::cout << fi::con::explicit_casting(o_h[i]) << " + // "; + // } + // std::cout << "]" << std::endl; bool isEmpty = o_h.empty(); EXPECT_EQ(isEmpty, false) << "Output vector is empty"; diff --git a/libflashinfer/utils/cpu_reference_hip.h b/libflashinfer/utils/cpu_reference_hip.h index 9abc3a04cc..ee6bc43e7c 100644 --- a/libflashinfer/utils/cpu_reference_hip.h +++ b/libflashinfer/utils/cpu_reference_hip.h @@ -12,6 +12,7 @@ #include "utils_hip.h" +#include #include #include #include @@ -201,20 +202,26 @@ single_mha(const std::vector &q, bool causal = true, QKVLayout kv_layout = QKVLayout::kHND, PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, + float logits_soft_cap = 8.0f, float rope_scale = 1.f, - float rope_theta = 1e4) + float rope_theta = 1e4, + bool use_soft_cap = true) { assert(qo_len <= kv_len); assert(num_qo_heads % num_kv_heads == 0); float sm_scale = 1.f / std::sqrt(float(head_dim)); + // float sm_scale = 1.0; std::vector o(qo_len * num_qo_heads * head_dim); std::vector att(kv_len); std::vector q_rotary_local(head_dim); std::vector k_rotary_local(head_dim); + + float soft_cap_pre_tanh_scale = sm_scale / logits_soft_cap; + DISPATCH_head_dim(head_dim, HEAD_DIM, { tensor_info_t info(qo_len, kv_len, num_qo_heads, num_kv_heads, kv_layout, HEAD_DIM); -#if Debug +#if Debug1 std::cout << "DEBUG: Original Q (CPU): " << '\n'; for (auto i = 0ul; i < 16; ++i) { for (int j = 0; j < 64; 
++j) { @@ -235,44 +242,6 @@ single_mha(const std::vector &q, // q[info.get_q_elem_offset(q_idx, qo_head_idx, feat_idx) } std::cout << std::endl; - - // std::cout << "DEBUG K (CPU): " << '\n'; - // for (auto j = 0ul; j < 16; ++j) { - // for (auto i = 0ul; i < 64; ++i) { - // // k[info.get_kv_elem_offset(kv_idx, kv_head_idx, feat_idx) - // // std::cout << (float)k[info.get_kv_elem_offset(15, 0, j * 4 - // // + - // // i)] - // std::cout << (float)k[info.get_kv_elem_offset(j, 0, i)] <<" - // "; - // } - // std::cout << '\n'; - // } - - // std::cout << std::endl; - // std::cout << "num_qo_heads " << num_qo_heads << '\n'; - // std::cout << "qo_len " << qo_len << '\n'; - // for (size_t qo_head_idx = 0; qo_head_idx < num_qo_heads; - // ++qo_head_idx) - // { - // for (size_t q_idx = 0; q_idx < qo_len; ++q_idx) { - // q_rotary_local = - // std::move(cpu_reference::apply_llama_rope_debug( - // q.data() + - // info.get_q_elem_offset(q_idx, qo_head_idx, 0), - // head_dim, q_idx + kv_len - qo_len, rope_scale, - // rope_theta)); - // } - // } - - // std::cout << "DEBUG: LLAMA Rope Transformed Q (CPU): " << '\n'; - // for (auto i = 0ul; i < 4; ++i) { - // // q[info.get_q_elem_offset(q_idx, qo_head_idx, feat_idx) - // std::cout << (float)q_rotary_local[info.get_q_elem_offset(0, 0, - // i)] - // << " "; - // } - // std::cout << std::endl; #endif for (size_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) { @@ -291,20 +260,50 @@ single_mha(const std::vector &q, switch (pos_encoding_mode) { case PosEncodingMode::kNone: { -#if Debug - sm_scale = 1.0f; -#endif for (size_t feat_idx = 0; feat_idx < head_dim; ++feat_idx) { - att[kv_idx] += - fi::con::explicit_casting( - q[info.get_q_elem_offset(q_idx, qo_head_idx, - feat_idx)]) * - fi::con::explicit_casting( - k[info.get_kv_elem_offset( - kv_idx, kv_head_idx, feat_idx)]) * - sm_scale; + if (use_soft_cap) { + auto score = + fi::con::explicit_casting( + q[info.get_q_elem_offset( + q_idx, qo_head_idx, feat_idx)]) * + 
fi::con::explicit_casting( + k[info.get_kv_elem_offset( + kv_idx, kv_head_idx, feat_idx)]); + auto tscore = float( + std::tanh(score * soft_cap_pre_tanh_scale)); + att[kv_idx] += tscore; +#if Debug1 + if (qo_head_idx == 0 && q_idx == 0 && + kv_idx < 16) + { + std::cout << "score: " << score + << " Transformed : " << tscore + << " Final : " << att[kv_idx] + << '\n'; + } +#endif + } + else { + att[kv_idx] += + fi::con::explicit_casting( + q[info.get_q_elem_offset( + q_idx, qo_head_idx, feat_idx)]) * + fi::con::explicit_casting( + k[info.get_kv_elem_offset( + kv_idx, kv_head_idx, feat_idx)]) * + sm_scale; +#if Debug1 + if (qo_head_idx == 0 && q_idx == 0 && + kv_idx < 16) + { + std::cout + << "Post-transform: " << att[kv_idx] + << '\n'; + } +#endif + } } break; } @@ -318,8 +317,17 @@ single_mha(const std::vector &q, for (size_t feat_idx = 0; feat_idx < head_dim; ++feat_idx) { - att[kv_idx] += q_rotary_local[feat_idx] * - k_rotary_local[feat_idx] * sm_scale; + if (use_soft_cap) { + att[kv_idx] += q_rotary_local[feat_idx] * + k_rotary_local[feat_idx]; + att[kv_idx] = std::tanh( + att[kv_idx] * soft_cap_pre_tanh_scale); + } + else { + att[kv_idx] += q_rotary_local[feat_idx] * + k_rotary_local[feat_idx] * + sm_scale; + } } break; } @@ -330,15 +338,20 @@ single_mha(const std::vector &q, FLASHINFER_ERROR(err_msg.str()); } } -// apply mask -#if 0 + // apply mask if (causal && kv_idx > kv_len + q_idx - qo_len) { att[kv_idx] = -5e4; } -#endif max_val = std::max(max_val, att[kv_idx]); } +#if Debug1 + if (qo_head_idx == 0) { + for (auto i = 0ul; i < 16; ++i) { + std::cout << att[i] << '\n'; + } + } +#endif // exp minus max float denom = 0; for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { diff --git a/libflashinfer/utils/flashinfer_prefill_ops.hip.h b/libflashinfer/utils/flashinfer_prefill_ops.hip.h index 2866368dab..dc2d9e1853 100644 --- a/libflashinfer/utils/flashinfer_prefill_ops.hip.h +++ b/libflashinfer/utils/flashinfer_prefill_ops.hip.h @@ -149,14 +149,14 @@ hipError_t 
SinglePrefillWithKVCache( /*use_custom_mask=*/(MASK_MODE == MaskMode::kCustom), /*use_sliding_window=*/false, - /*use_logits_soft_cap=*/false, /*use_alibi=*/false>; + /*use_logits_soft_cap=*/true, /*use_alibi=*/false>; Params params(q, k, v, /*custom_mask=*/nullptr, o, lse, /*alibi_slopes=*/nullptr, num_qo_heads, num_kv_heads, qo_len, kv_len, qo_stride_n, qo_stride_h, kv_stride_n, kv_stride_h, head_dim, /*window_left=*/-1, - /*logits_soft_cap=*/0.f, sm_scale, + /*logits_soft_cap=*/8.f, sm_scale, rope_scale, rope_theta); return SinglePrefillWithKVCacheDispatched< HEAD_DIM, HEAD_DIM, POS_ENCODING_MODE, From 34e99f8374bdf4334733435b59c197e3ad25a07e Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Tue, 9 Sep 2025 18:46:28 -0400 Subject: [PATCH 069/109] wip fixes --- .../flashinfer/attention/generic/prefill.cuh | 206 +++++++----------- .../flashinfer/attention/generic/variants.cuh | 7 +- libflashinfer/utils/cpu_reference_hip.h | 84 ++----- .../utils/flashinfer_prefill_ops.hip.h | 6 +- 4 files changed, 110 insertions(+), 193 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index 4a92a96c26..1e9ee0f89d 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -159,11 +159,11 @@ struct KernelTraits // Presently we use 16x4 thread layout for all cases. static constexpr uint32_t KV_THR_LAYOUT_ROW = WARP_THREAD_ROWS; static constexpr uint32_t KV_THR_LAYOUT_COL = WARP_THREAD_COLS; - // The constant is defined based on the matrix layout of the "D/C" + // FIXME: [The comment is not correct] The constant is defined based on the matrix layout of the "D/C" // accumulator matrix in a D = A*B+C computation. On CDNA3 the D/C matrices // are distributed as four 4x16 bands across the 64 threads. Each thread // owns one element from four different rows. 
- static constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD = 4; + static constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD = 1; // Number of threads that collaboratively handle the same set of matrix rows // in attention score computation and cross-warp synchronization. // CUDA: 4 threads (each thread handles 2 elements from same row group) @@ -1231,7 +1231,7 @@ __device__ __forceinline__ void logits_transform( logitsTransformed = variant.LogitsTransform(params, logits, batch_idx, q_idx, kv_idx, qo_head_idx, kv_head_idx); -#if Debug +#if Debug1 const uint32_t lane_idx = tid.x, warp_idx = get_warp_idx(tid.y, tid.z); @@ -1331,16 +1331,14 @@ logits_mask(const Params ¶ms, template __device__ __forceinline__ void update_mdo_states( typename KTraits::AttentionVariant variant, - typename KTraits::DTypeQKAccum ( - *s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD], + typename KTraits::DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD], float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD], typename KTraits::DTypeQKAccum (*m)[KTraits::NUM_ACCUM_ROWS_PER_THREAD], float (*d)[KTraits::NUM_ACCUM_ROWS_PER_THREAD]) { using DTypeQKAccum = typename KTraits::DTypeQKAccum; using AttentionVariant = typename KTraits::AttentionVariant; - constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD = - KTraits::NUM_ACCUM_ROWS_PER_THREAD; + constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD = KTraits::NUM_ACCUM_ROWS_PER_THREAD; constexpr bool use_softmax = AttentionVariant::use_softmax; if constexpr (use_softmax) { @@ -1352,72 +1350,56 @@ __device__ __forceinline__ void update_mdo_states( for (uint32_t j = 0; j < NUM_ACCUM_ROWS_PER_THREAD; ++j) { float m_prev = m[mma_q][j]; #pragma unroll - for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; - ++mma_kv) + for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { #if defined(PLATFORM_HIP_DEVICE) - m[mma_q][j] = - max(m[mma_q][j], s_frag[mma_q][mma_kv][j]); + m[mma_q][j] = max(max(s_frag[mma_q][mma_kv][0], 
s_frag[mma_q][mma_kv][1]), + max(s_frag[mma_q][mma_kv][2], s_frag[mma_q][mma_kv][3])); #else float m_local = - max(max(s_frag[mma_q][mma_kv][j * 2 + 0], - s_frag[mma_q][mma_kv][j * 2 + 1]), - max(s_frag[mma_q][mma_kv][j * 2 + 4], - s_frag[mma_q][mma_kv][j * 2 + 5])); + max(max(s_frag[mma_q][mma_kv][j * 2 + 0], s_frag[mma_q][mma_kv][j * 2 + 1]), + max(s_frag[mma_q][mma_kv][j * 2 + 4], s_frag[mma_q][mma_kv][j * 2 + 5])); m[mma_q][j] = max(m[mma_q][j], m_local); #endif } #if defined(PLATFORM_HIP_DEVICE) - // Butterfly reduction across all threads in the band (16 - // threads) for CDNA3's 64-thread wavefront - m[mma_q][j] = - max(m[mma_q][j], gpu_iface::math::shfl_xor_sync( - m[mma_q][j], 0x8)); // 16 apart - m[mma_q][j] = - max(m[mma_q][j], gpu_iface::math::shfl_xor_sync( - m[mma_q][j], 0x4)); // 8 apart - m[mma_q][j] = - max(m[mma_q][j], gpu_iface::math::shfl_xor_sync( - m[mma_q][j], 0x2)); // 4 apart - m[mma_q][j] = - max(m[mma_q][j], gpu_iface::math::shfl_xor_sync( - m[mma_q][j], 0x1)); // 2 apart - - float o_scale = gpu_iface::math::ptx_exp2( - m_prev * sm_scale - m[mma_q][j] * sm_scale); + // Butterfly reduction across all threads in the band + m[mma_q][j] = max(m[mma_q][j], gpu_iface::math::shfl_xor_sync(m[mma_q][j], 0x10)); // 32 apart + m[mma_q][j] = max(m[mma_q][j], gpu_iface::math::shfl_xor_sync(m[mma_q][j], 0x8)); // 16 apart + + float o_scale = gpu_iface::math::ptx_exp2(m_prev * sm_scale - m[mma_q][j] * sm_scale); d[mma_q][j] *= o_scale; // Scale output fragments for this specific row #pragma unroll - for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; - ++mma_d) + for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { - o_frag[mma_q][mma_d][j] *= o_scale; // Direct indexing + o_frag[mma_q][mma_d][0] *= o_scale; + o_frag[mma_q][mma_d][1] *= o_scale; + o_frag[mma_q][mma_d][2] *= o_scale; + o_frag[mma_q][mma_d][3] *= o_scale; } // Convert logits to probabilities for this row #pragma unroll - for (uint32_t mma_kv = 0; mma_kv < 
KTraits::NUM_MMA_KV; - ++mma_kv) + for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { - s_frag[mma_q][mma_kv][j] = gpu_iface::math::ptx_exp2( - s_frag[mma_q][mma_kv][j] * sm_scale - - m[mma_q][j] * sm_scale); + s_frag[mma_q][mma_kv][0] = gpu_iface::math::ptx_exp2( + s_frag[mma_q][mma_kv][0] * sm_scale - m[mma_q][j] * sm_scale); + s_frag[mma_q][mma_kv][1] = gpu_iface::math::ptx_exp2( + s_frag[mma_q][mma_kv][1] * sm_scale - m[mma_q][j] * sm_scale); + s_frag[mma_q][mma_kv][2] = gpu_iface::math::ptx_exp2( + s_frag[mma_q][mma_kv][2] * sm_scale - m[mma_q][j] * sm_scale); + s_frag[mma_q][mma_kv][3] = gpu_iface::math::ptx_exp2( + s_frag[mma_q][mma_kv][3] * sm_scale - m[mma_q][j] * sm_scale); } #else - m[mma_q][j] = - max(m[mma_q][j], - gpu_iface::math::shfl_xor_sync(m[mma_q][j], 0x2)); - m[mma_q][j] = - max(m[mma_q][j], - gpu_iface::math::shfl_xor_sync(m[mma_q][j], 0x1)); - - float o_scale = gpu_iface::math::ptx_exp2( - m_prev * sm_scale - m[mma_q][j] * sm_scale); + m[mma_q][j] = max(m[mma_q][j], gpu_iface::math::shfl_xor_sync(m[mma_q][j], 0x2)); + m[mma_q][j] = max(m[mma_q][j], gpu_iface::math::shfl_xor_sync(m[mma_q][j], 0x1)); + float o_scale = gpu_iface::math::ptx_exp2(m_prev * sm_scale - m[mma_q][j] * sm_scale); d[mma_q][j] *= o_scale; #pragma unroll - for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; - ++mma_d) + for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { o_frag[mma_q][mma_d][j * 2 + 0] *= o_scale; o_frag[mma_q][mma_d][j * 2 + 1] *= o_scale; @@ -1425,15 +1407,10 @@ __device__ __forceinline__ void update_mdo_states( o_frag[mma_q][mma_d][j * 2 + 5] *= o_scale; } #pragma unroll - for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; - ++mma_kv) + for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { - s_frag[mma_q][mma_kv][j * 2 + 0] = - gpu_iface::math::ptx_exp2( - s_frag[mma_q][mma_kv][j * 2 + 0] * sm_scale - - m[mma_q][j] * sm_scale); - s_frag[mma_q][mma_kv][j * 2 + 1] = - gpu_iface::math::ptx_exp2( + 
s_frag[mma_q][mma_kv][j * 2 + 0] = gpu_iface::math::ptx_exp2(s_frag[mma_q][mma_kv][j * 2 + 0] * sm_scale - m[mma_q][j] * sm_scale); + s_frag[mma_q][mma_kv][j * 2 + 1] = gpu_iface::math::ptx_exp2( s_frag[mma_q][mma_kv][j * 2 + 1] * sm_scale - m[mma_q][j] * sm_scale); s_frag[mma_q][mma_kv][j * 2 + 4] = @@ -2164,8 +2141,7 @@ SinglePrefillWithKVCacheDevice(const Params params, const uint32_t window_left = variant.window_left; DTypeQKAccum s_frag[NUM_MMA_Q][NUM_MMA_KV][HALF_ELEMS_PER_THREAD]; - alignas( - 16) float o_frag[NUM_MMA_Q][NUM_MMA_D_VO][HALF_ELEMS_PER_THREAD]; + alignas(16) float o_frag[NUM_MMA_Q][NUM_MMA_D_VO][HALF_ELEMS_PER_THREAD]; DTypeQKAccum m[NUM_MMA_Q][NUM_ACCUM_ROWS_PER_THREAD]; float d[NUM_MMA_Q][NUM_ACCUM_ROWS_PER_THREAD]; float rope_freq[NUM_MMA_D_QK / 2][4]; @@ -2290,7 +2266,7 @@ SinglePrefillWithKVCacheDevice(const Params params, v_smem, &v_smem_offset_w, &v_ptr, v_stride_n, 0, chunk_size, tid); memory::commit_group(); -#if Debug1 +#if Debug if (warp_idx == 0 && lane_idx == 0) { printf("partition_kv : %d\n", partition_kv); printf("kv_len : %d\n", kv_len); @@ -2303,7 +2279,7 @@ SinglePrefillWithKVCacheDevice(const Params params, if (warp_idx == 0 && lane_idx == 0) { printf("\n DEBUG Q ORIGINAL (HIP):\n"); uint32_t q_smem_offset_r_debug; - for (auto i = 0; i < NUM_MMA_Q * 16; ++i) { + for (auto i = 0; i < NUM_MMA_Q * 16 * 4; ++i) { for (auto j = 0; j < NUM_MMA_D_QK * 4; ++j) { q_smem_offset_r_debug = qo_smem.template get_permuted_offset( @@ -2323,6 +2299,7 @@ SinglePrefillWithKVCacheDevice(const Params params, // Test K Global values: // Prints the (NUM_MMA_KV*16) x (NUM_MMA_D*16) matrix from global mem. 
+ if (warp_idx == 0 && lane_idx == 0) { printf("\n DEBUG K Global (HIP):\n"); printf("k_stride_n : %d\n", k_stride_n); @@ -2334,7 +2311,7 @@ SinglePrefillWithKVCacheDevice(const Params params, printf("KTraits::NUM_MMA_D_QK : %d\n", KTraits::NUM_MMA_D_QK); printf("NUM_MMA_KV : %d\n", NUM_MMA_KV); printf("NUM_MMA_Q : %d\n", NUM_MMA_Q); - +#if 0 DTypeKV *k_ptr_tmp = k + (chunk_start + warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL) * @@ -2350,17 +2327,18 @@ SinglePrefillWithKVCacheDevice(const Params params, } printf("\n"); } +#endif } // Test K LDS values: // Prints the (NUM_MMA_KV*16) x (NUM_MMA_D*16) matrix from shared mem. // Note that LDS is loaded collaboratively by all warps and not each // warp accesses the whole K matrix loaded into LDS. Each warp will - // only access 1/4 of the K values loaded into LDS> + // only access 1/4 of the K values loaded into LDS. if (warp_idx == 0 && lane_idx == 0) { printf("\n DEBUG K LDS ORIGINAL (HIP):\n"); uint32_t k_smem_offset_r_debug; - for (auto i = 0; i < NUM_MMA_KV * 16; ++i) { + for (auto i = 0; i < NUM_MMA_KV * 16 * 2; ++i) { for (auto j = 0; j < NUM_MMA_D_QK * 4; ++j) { k_smem_offset_r_debug = k_smem.template get_permuted_offset(i, @@ -2373,8 +2351,7 @@ SinglePrefillWithKVCacheDevice(const Params params, } } printf("\n"); - k_smem.template advance_offset_by_row<16, - KTraits::UPCAST_STRIDE_K>( + k_smem.template advance_offset_by_row<16, KTraits::UPCAST_STRIDE_K>( k_smem_offset_r_debug); } } @@ -2399,57 +2376,20 @@ SinglePrefillWithKVCacheDevice(const Params params, compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); #if Debug - smem_t scratch( - smem_storage.qk_scratch); - // copy sfrag into scratch - if (warp_idx == 0) { + if (warp_idx == 0 && lane_idx == 0) { + printf("s_frag results after compute_qk: \n"); for (auto mma_q = 0ul; mma_q < NUM_MMA_Q; ++mma_q) { for (auto mma_kv = 0ul; mma_kv < NUM_MMA_KV; ++mma_kv) { - auto tmp = s_frag[mma_q][mma_kv]; - auto col = lane_idx % 16; - 
auto row = lane_idx / 16; - auto scratch_offset = - scratch - .template get_permuted_offset( - row, col); - scratch.template store_fragment(scratch_offset, tmp); - } - } - } - - if (warp_idx == 0 && lane_idx == 0) { - auto _hScratch = reinterpret_cast<__half *>(scratch.base); - printf("compute_qk results (Warp 0): \n"); - - for (auto k = 0ul; k < 4; ++k) { - for (auto i = 0ul; i < 16; ++i) { - for (auto j = 0ul; j < 16; ++j) { - printf("%f ", float(_hScratch[k * 64 + i + j * 4])); + for (auto reg_id = 0ul; reg_id < HALF_ELEMS_PER_THREAD; ++reg_id) + { + auto tmp = s_frag[mma_q][mma_kv][reg_id]; + printf("s_frag[%lu][%lu][%lu] : %f ", mma_q, mma_kv, reg_id, float(tmp)); } printf("\n"); } } } - - // if (warp_idx == 0 && lane_idx == 0) { - // printf("s_frag results: \n"); - // for (auto mma_q = 0ul; mma_q < NUM_MMA_Q; ++mma_q) { - // for (auto mma_kv = 0ul; mma_kv < NUM_MMA_KV; ++mma_kv) { - // for (auto reg_id = 0ul; reg_id < - // HALF_ELEMS_PER_THREAD; - // ++reg_id) - // { - // auto tmp = s_frag[mma_q][mma_kv][reg_id]; - // printf("s_frag[%lu][%lu][%lu] : %f ", mma_q, - // mma_kv, - // reg_id, float(tmp)); - // } - // printf("\n"); - // } - // } - // } #endif - logits_transform( params, variant, /*batch_idx=*/0, qo_packed_idx_base, chunk_start + @@ -2457,6 +2397,27 @@ SinglePrefillWithKVCacheDevice(const Params params, NUM_MMA_KV * 16, qo_len, kv_len, group_size, s_frag, tid, kv_head_idx); +#if Debug + float soft_cap_pre_tanh_scale = + params.sm_scale * + gpu_iface::math::ptx_rcp(params.logits_soft_cap); + if (warp_idx == 0 && lane_idx == 0) { + printf("params.sm_scale %f, params.logits_soft_cap %f\n", params.sm_scale, params.logits_soft_cap); + printf("s_frag after logits transform (scaled by %f) : \n", soft_cap_pre_tanh_scale); + for (auto mma_q = 0ul; mma_q < NUM_MMA_Q; ++mma_q) { + for (auto mma_kv = 0ul; mma_kv < NUM_MMA_KV; ++mma_kv) { + for (auto reg_id = 0ul; reg_id < HALF_ELEMS_PER_THREAD; + ++reg_id) + { + auto tmp = s_frag[mma_q][mma_kv][reg_id]; + 
printf("s_frag[%lu][%lu][%lu] : %f ", mma_q, mma_kv, + reg_id, float(tmp)); + } + printf("\n"); + } + } + } +#endif // apply mask if (MASK_MODE == MaskMode::kCustom || (iter >= mask_iteration || iter < window_iteration)) @@ -2469,23 +2430,10 @@ SinglePrefillWithKVCacheDevice(const Params params, qo_len, kv_len, chunk_end, group_size, s_frag, tid, kv_head_idx); } + #if Debug - // TODO: - // smem_t scratch( - // smem_storage.qk_scratch); - // // copy sfrag into scratch - // for (auto mma_q = 0ul; mma_q < NUM_MMA_Q; ++mma_q) { - // for (auto mma_kv = 0ul; mma_kv < NUM_MMA_KV; ++mma_kv) { - // for (auto reg_id = 0ul; reg_id < HALF_ELEMS_PER_THREAD; - // ++reg_id) - // { - // auto tmp = s_frag[mma_q][mma_kv]; - // // store into scratch - // } - // } - // } if (warp_idx == 0 && lane_idx == 0) { - printf("s_frag after logits transform and masking : \n"); + printf("s_frag after logits masking\n"); for (auto mma_q = 0ul; mma_q < NUM_MMA_Q; ++mma_q) { for (auto mma_kv = 0ul; mma_kv < NUM_MMA_KV; ++mma_kv) { for (auto reg_id = 0ul; reg_id < HALF_ELEMS_PER_THREAD; @@ -2503,6 +2451,12 @@ SinglePrefillWithKVCacheDevice(const Params params, // compute m,d states in online softmax update_mdo_states(variant, s_frag, o_frag, m, d); +#if Debug + if (warp_idx == 0 && lane_idx == 0) { + printf("Max value for first 32 cols of row 0 %f\n", m[0][0]); + } +#endif + block.sync(); produce_kv( k_smem, &k_smem_offset_w, &k_ptr, k_stride_n, diff --git a/libflashinfer/include/flashinfer/attention/generic/variants.cuh b/libflashinfer/include/flashinfer/attention/generic/variants.cuh index 019c441224..3c488dd5c1 100644 --- a/libflashinfer/include/flashinfer/attention/generic/variants.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/variants.cuh @@ -89,8 +89,11 @@ struct DefaultAttention : AttentionVariantBase float(int(kv_idx) - int(qo_idx)); } if constexpr (use_logits_soft_cap) { - logits = float( - gpu_iface::math::tanh(logits * soft_cap_pre_tanh_scale)); +#if Debug + logits = 
float(logits * soft_cap_pre_tanh_scale); +#else + logits = float(gpu_iface::math::tanh(logits * soft_cap_pre_tanh_scale)); +#endif } return logits; }) diff --git a/libflashinfer/utils/cpu_reference_hip.h b/libflashinfer/utils/cpu_reference_hip.h index ee6bc43e7c..6e1b7876d4 100644 --- a/libflashinfer/utils/cpu_reference_hip.h +++ b/libflashinfer/utils/cpu_reference_hip.h @@ -205,12 +205,12 @@ single_mha(const std::vector &q, float logits_soft_cap = 8.0f, float rope_scale = 1.f, float rope_theta = 1e4, - bool use_soft_cap = true) + bool use_soft_cap = false) { assert(qo_len <= kv_len); assert(num_qo_heads % num_kv_heads == 0); float sm_scale = 1.f / std::sqrt(float(head_dim)); - // float sm_scale = 1.0; + //float sm_scale = 1.0; std::vector o(qo_len * num_qo_heads * head_dim); std::vector att(kv_len); std::vector q_rotary_local(head_dim); @@ -260,50 +260,16 @@ single_mha(const std::vector &q, switch (pos_encoding_mode) { case PosEncodingMode::kNone: { - for (size_t feat_idx = 0; feat_idx < head_dim; - ++feat_idx) + for (size_t feat_idx = 0; feat_idx < head_dim; ++feat_idx) { - if (use_soft_cap) { - auto score = - fi::con::explicit_casting( - q[info.get_q_elem_offset( - q_idx, qo_head_idx, feat_idx)]) * - fi::con::explicit_casting( - k[info.get_kv_elem_offset( - kv_idx, kv_head_idx, feat_idx)]); - auto tscore = float( - std::tanh(score * soft_cap_pre_tanh_scale)); - att[kv_idx] += tscore; -#if Debug1 - if (qo_head_idx == 0 && q_idx == 0 && - kv_idx < 16) - { - std::cout << "score: " << score - << " Transformed : " << tscore - << " Final : " << att[kv_idx] - << '\n'; - } -#endif - } - else { - att[kv_idx] += - fi::con::explicit_casting( - q[info.get_q_elem_offset( - q_idx, qo_head_idx, feat_idx)]) * - fi::con::explicit_casting( - k[info.get_kv_elem_offset( - kv_idx, kv_head_idx, feat_idx)]) * - sm_scale; -#if Debug1 - if (qo_head_idx == 0 && q_idx == 0 && - kv_idx < 16) - { - std::cout - << "Post-transform: " << att[kv_idx] - << '\n'; - } -#endif - } + att[kv_idx] += 
+ fi::con::explicit_casting( + q[info.get_q_elem_offset( + q_idx, qo_head_idx, feat_idx)]) * + fi::con::explicit_casting( + k[info.get_kv_elem_offset( + kv_idx, kv_head_idx, feat_idx)]) * + sm_scale; } break; } @@ -317,17 +283,9 @@ single_mha(const std::vector &q, for (size_t feat_idx = 0; feat_idx < head_dim; ++feat_idx) { - if (use_soft_cap) { - att[kv_idx] += q_rotary_local[feat_idx] * - k_rotary_local[feat_idx]; - att[kv_idx] = std::tanh( - att[kv_idx] * soft_cap_pre_tanh_scale); - } - else { - att[kv_idx] += q_rotary_local[feat_idx] * - k_rotary_local[feat_idx] * - sm_scale; - } + att[kv_idx] += q_rotary_local[feat_idx] * + k_rotary_local[feat_idx] * + sm_scale; } break; } @@ -345,13 +303,17 @@ single_mha(const std::vector &q, max_val = std::max(max_val, att[kv_idx]); } -#if Debug1 - if (qo_head_idx == 0) { - for (auto i = 0ul; i < 16; ++i) { - std::cout << att[i] << '\n'; +#if Debug + if (qo_head_idx == 0 && q_idx == 0) { + // for qo_len = 128, each warp on the GPU will store 128/4, + // that is, 32 attention scores. For CDNA3, these 32 scores + // are spread across 4 threads. 
+ for(auto i = 0ul; i < 32; ++i) { + std::cout << " >>>>> scaled att " << att[i] << '\n'; } + std::cout << "Max value for warp 0 = " << *std::max_element(att.begin(),att.begin()+32) << '\n'; } -#endif +#endif // exp minus max float denom = 0; for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { diff --git a/libflashinfer/utils/flashinfer_prefill_ops.hip.h b/libflashinfer/utils/flashinfer_prefill_ops.hip.h index dc2d9e1853..4f3564365e 100644 --- a/libflashinfer/utils/flashinfer_prefill_ops.hip.h +++ b/libflashinfer/utils/flashinfer_prefill_ops.hip.h @@ -56,8 +56,7 @@ hipError_t SinglePrefillWithKVCacheCustomMask( float rope_theta = 1e4, hipStream_t stream = nullptr) { - const float sm_scale = - maybe_sm_scale.value_or(1.f / std::sqrt(float(head_dim))); + const float sm_scale = 1.f; auto [qo_stride_n, qo_stride_h, kv_stride_n, kv_stride_h] = get_qkv_strides( kv_layout, kv_len, num_qo_heads, num_kv_heads, head_dim); DISPATCH_use_fp16_qk_reduction( @@ -130,8 +129,7 @@ hipError_t SinglePrefillWithKVCache( float rope_theta = 1e4, hipStream_t stream = nullptr) { - const float sm_scale = - maybe_sm_scale.value_or(1.f / std::sqrt(float(head_dim))); + const float sm_scale = 1.f; const MaskMode mask_mode = causal ? 
MaskMode::kCausal : MaskMode::kNone; auto [qo_stride_n, qo_stride_h, kv_stride_n, kv_stride_h] = get_qkv_strides( kv_layout, kv_len, num_qo_heads, num_kv_heads, head_dim); From e9aec3d3593d0622971ab1c14da65f5002e34a2d Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Wed, 10 Sep 2025 15:03:10 -0400 Subject: [PATCH 070/109] sfrag debug writer --- .../flashinfer/attention/generic/prefill.cuh | 325 +++++++++++++----- .../include/gpu_iface/backend/hip/mma_hip.h | 8 +- libflashinfer/utils/cpu_reference_hip.h | 39 ++- 3 files changed, 277 insertions(+), 95 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index 1e9ee0f89d..fdbdd534da 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -159,10 +159,10 @@ struct KernelTraits // Presently we use 16x4 thread layout for all cases. static constexpr uint32_t KV_THR_LAYOUT_ROW = WARP_THREAD_ROWS; static constexpr uint32_t KV_THR_LAYOUT_COL = WARP_THREAD_COLS; - // FIXME: [The comment is not correct] The constant is defined based on the matrix layout of the "D/C" - // accumulator matrix in a D = A*B+C computation. On CDNA3 the D/C matrices - // are distributed as four 4x16 bands across the 64 threads. Each thread - // owns one element from four different rows. + // FIXME: [The comment is not correct] The constant is defined based on the + // matrix layout of the "D/C" accumulator matrix in a D = A*B+C computation. + // On CDNA3 the D/C matrices are distributed as four 4x16 bands across the + // 64 threads. Each thread owns one element from four different rows. static constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD = 1; // Number of threads that collaboratively handle the same set of matrix rows // in attention score computation and cross-warp synchronization. 
@@ -260,31 +260,6 @@ struct KernelTraits namespace { -template -__device__ __forceinline__ void -debug_printer(uint32_t threadid, const char *var_name, T val) -{ - int global_idx = (blockIdx.z * gridDim.y * gridDim.x + - blockIdx.y * gridDim.x + blockIdx.x) * - (blockDim.z * blockDim.y * blockDim.x) + - (threadIdx.z * blockDim.y * blockDim.x + - threadIdx.y * blockDim.x + threadIdx.x); - - if (global_idx == 0 || global_idx == 16 || global_idx == 32 || - global_idx == 48) - { - if constexpr (std::is_integral_v) { - printf("%s : %d\n", var_name, (int)val); - } - else if constexpr (std::is_floating_point_v) { - printf("%s : %f\n", var_name, (float)val); - } - else { - printf("%s : (unsupported type)\n", var_name); - } - } -} - template __device__ __forceinline__ uint32_t get_warp_idx_q(const uint32_t tid_y = threadIdx.y) @@ -1331,14 +1306,18 @@ logits_mask(const Params ¶ms, template __device__ __forceinline__ void update_mdo_states( typename KTraits::AttentionVariant variant, - typename KTraits::DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD], + typename KTraits::DTypeQKAccum ( + *s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD], float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD], typename KTraits::DTypeQKAccum (*m)[KTraits::NUM_ACCUM_ROWS_PER_THREAD], - float (*d)[KTraits::NUM_ACCUM_ROWS_PER_THREAD]) + float (*d)[KTraits::NUM_ACCUM_ROWS_PER_THREAD], + uint32_t warp_idx = 0, + uint32_t lane_idx = 0) { using DTypeQKAccum = typename KTraits::DTypeQKAccum; using AttentionVariant = typename KTraits::AttentionVariant; - constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD = KTraits::NUM_ACCUM_ROWS_PER_THREAD; + constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD = + KTraits::NUM_ACCUM_ROWS_PER_THREAD; constexpr bool use_softmax = AttentionVariant::use_softmax; if constexpr (use_softmax) { @@ -1350,29 +1329,40 @@ __device__ __forceinline__ void update_mdo_states( for (uint32_t j = 0; j < NUM_ACCUM_ROWS_PER_THREAD; ++j) { float 
m_prev = m[mma_q][j]; #pragma unroll - for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) + for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; + ++mma_kv) { #if defined(PLATFORM_HIP_DEVICE) - m[mma_q][j] = max(max(s_frag[mma_q][mma_kv][0], s_frag[mma_q][mma_kv][1]), - max(s_frag[mma_q][mma_kv][2], s_frag[mma_q][mma_kv][3])); + auto m_local = max(max(s_frag[mma_q][mma_kv][0], + s_frag[mma_q][mma_kv][1]), + max(s_frag[mma_q][mma_kv][2], + s_frag[mma_q][mma_kv][3])); + m[mma_q][j] = max(m[mma_q][j], m_local); #else float m_local = - max(max(s_frag[mma_q][mma_kv][j * 2 + 0], s_frag[mma_q][mma_kv][j * 2 + 1]), - max(s_frag[mma_q][mma_kv][j * 2 + 4], s_frag[mma_q][mma_kv][j * 2 + 5])); + max(max(s_frag[mma_q][mma_kv][j * 2 + 0], + s_frag[mma_q][mma_kv][j * 2 + 1]), + max(s_frag[mma_q][mma_kv][j * 2 + 4], + s_frag[mma_q][mma_kv][j * 2 + 5])); m[mma_q][j] = max(m[mma_q][j], m_local); #endif } #if defined(PLATFORM_HIP_DEVICE) - // Butterfly reduction across all threads in the band - m[mma_q][j] = max(m[mma_q][j], gpu_iface::math::shfl_xor_sync(m[mma_q][j], 0x10)); // 32 apart - m[mma_q][j] = max(m[mma_q][j], gpu_iface::math::shfl_xor_sync(m[mma_q][j], 0x8)); // 16 apart - - float o_scale = gpu_iface::math::ptx_exp2(m_prev * sm_scale - m[mma_q][j] * sm_scale); + // Butterfly reduction across all threads in the band + m[mma_q][j] = + max(m[mma_q][j], gpu_iface::math::shfl_xor_sync( + m[mma_q][j], 0x10)); // 32 apart + m[mma_q][j] = + max(m[mma_q][j], gpu_iface::math::shfl_xor_sync( + m[mma_q][j], 0x8)); // 16 apart + float o_scale = gpu_iface::math::ptx_exp2( + m_prev * sm_scale - m[mma_q][j] * sm_scale); d[mma_q][j] *= o_scale; // Scale output fragments for this specific row #pragma unroll - for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) + for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; + ++mma_d) { o_frag[mma_q][mma_d][0] *= o_scale; o_frag[mma_q][mma_d][1] *= o_scale; @@ -1382,24 +1372,35 @@ __device__ __forceinline__ void 
update_mdo_states( // Convert logits to probabilities for this row #pragma unroll - for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) + for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; + ++mma_kv) { s_frag[mma_q][mma_kv][0] = gpu_iface::math::ptx_exp2( - s_frag[mma_q][mma_kv][0] * sm_scale - m[mma_q][j] * sm_scale); + s_frag[mma_q][mma_kv][0] * sm_scale - + m[mma_q][j] * sm_scale); s_frag[mma_q][mma_kv][1] = gpu_iface::math::ptx_exp2( - s_frag[mma_q][mma_kv][1] * sm_scale - m[mma_q][j] * sm_scale); + s_frag[mma_q][mma_kv][1] * sm_scale - + m[mma_q][j] * sm_scale); s_frag[mma_q][mma_kv][2] = gpu_iface::math::ptx_exp2( - s_frag[mma_q][mma_kv][2] * sm_scale - m[mma_q][j] * sm_scale); + s_frag[mma_q][mma_kv][2] * sm_scale - + m[mma_q][j] * sm_scale); s_frag[mma_q][mma_kv][3] = gpu_iface::math::ptx_exp2( - s_frag[mma_q][mma_kv][3] * sm_scale - m[mma_q][j] * sm_scale); + s_frag[mma_q][mma_kv][3] * sm_scale - + m[mma_q][j] * sm_scale); } -#else - m[mma_q][j] = max(m[mma_q][j], gpu_iface::math::shfl_xor_sync(m[mma_q][j], 0x2)); - m[mma_q][j] = max(m[mma_q][j], gpu_iface::math::shfl_xor_sync(m[mma_q][j], 0x1)); - float o_scale = gpu_iface::math::ptx_exp2(m_prev * sm_scale - m[mma_q][j] * sm_scale); +#else // CUDA PATH + m[mma_q][j] = + max(m[mma_q][j], + gpu_iface::math::shfl_xor_sync(m[mma_q][j], 0x2)); + m[mma_q][j] = + max(m[mma_q][j], + gpu_iface::math::shfl_xor_sync(m[mma_q][j], 0x1)); + float o_scale = gpu_iface::math::ptx_exp2( + m_prev * sm_scale - m[mma_q][j] * sm_scale); d[mma_q][j] *= o_scale; #pragma unroll - for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) + for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; + ++mma_d) { o_frag[mma_q][mma_d][j * 2 + 0] *= o_scale; o_frag[mma_q][mma_d][j * 2 + 1] *= o_scale; @@ -1407,10 +1408,15 @@ __device__ __forceinline__ void update_mdo_states( o_frag[mma_q][mma_d][j * 2 + 5] *= o_scale; } #pragma unroll - for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) + for 
(uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; + ++mma_kv) { - s_frag[mma_q][mma_kv][j * 2 + 0] = gpu_iface::math::ptx_exp2(s_frag[mma_q][mma_kv][j * 2 + 0] * sm_scale - m[mma_q][j] * sm_scale); - s_frag[mma_q][mma_kv][j * 2 + 1] = gpu_iface::math::ptx_exp2( + s_frag[mma_q][mma_kv][j * 2 + 0] = + gpu_iface::math::ptx_exp2( + s_frag[mma_q][mma_kv][j * 2 + 0] * sm_scale - + m[mma_q][j] * sm_scale); + s_frag[mma_q][mma_kv][j * 2 + 1] = + gpu_iface::math::ptx_exp2( s_frag[mma_q][mma_kv][j * 2 + 1] * sm_scale - m[mma_q][j] * sm_scale); s_frag[mma_q][mma_kv][j * 2 + 4] = @@ -1500,12 +1506,13 @@ __device__ __forceinline__ void compute_sfm_v( typename KTraits::DTypeQKAccum ( *s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD], float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD], - float (*d)[KTraits::NUM_ACCUM_ROWS_PER_THREAD]) + float (*d)[KTraits::NUM_ACCUM_ROWS_PER_THREAD], + uint32_t warp_idx = 0, + uint32_t lane_idx = 0) { constexpr uint32_t UPCAST_STRIDE_V = KTraits::UPCAST_STRIDE_V; constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; constexpr uint32_t INT32_ELEMS_PER_THREAD = KTraits::INT32_ELEMS_PER_THREAD; - constexpr uint32_t V_SMEM_COLUMN_ADVANCE = 16 / KTraits::HALF_ELEMS_PER_THREAD; @@ -1533,6 +1540,9 @@ __device__ __forceinline__ void compute_sfm_v( { mma::m16k16_rowsum_f16f16f32(d[mma_q], s_frag_f16[mma_q][mma_kv]); + if (warp_idx == 0 && lane_idx == 1) { + printf("D values : %f\n", *d[mma_q]); + } } else { #if defined(PLATFORM_HIP_DEVICE) @@ -2018,6 +2028,74 @@ __device__ __forceinline__ void write_o_reg_gmem( } // namespace +template +__device__ __forceinline__ void debug_write_sfrag_to_scratch( + typename KTraits::DTypeQKAccum ( + *s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD], + smem_t *scratch, + const dim3 tid = threadIdx) +{ + using DTypeQKAccum = typename KTraits::DTypeQKAccum; + constexpr uint32_t NUM_MMA_Q = KTraits::NUM_MMA_Q; + constexpr uint32_t NUM_MMA_KV = 
KTraits::NUM_MMA_KV; + constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; + const uint32_t warp_idx = get_warp_idx(tid.y, tid.z), + lane_idx = tid.x; + + // For CDNA3 with 16×4 thread layout: + uint32_t row = lane_idx % 16; + uint32_t col = lane_idx / 16; + + // Total matrix dimensions + constexpr uint32_t total_cols = NUM_MMA_KV * 16; + uint32_t offset = + scratch->template get_permuted_offset(row, col); + auto halfCastedBase = reinterpret_cast<__half *>(scratch->base); + + // Write all thread's fragments to shared memory + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { + for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) { + + // if(lane_idx == 0 && warp_idx == 0) { + // printf("debug_write_sfrag_to_scratch..............\n"); + // for (auto reg_id = 0ul; reg_id < HALF_ELEMS_PER_THREAD; + // ++reg_id) + // { + // auto tmp = s_frag[mma_q][mma_kv][reg_id]; + // printf("s_frag[%u][%u][%lu] : %f ", mma_q, mma_kv, + // reg_id, float(tmp)); + // } + // printf("\n"); + // } + + for (auto reg_id = 0ul; reg_id < HALF_ELEMS_PER_THREAD; ++reg_id) { + auto tmp = s_frag[mma_q][mma_kv][reg_id]; + *(halfCastedBase + offset * 4 + reg_id) = tmp; + } + + // if(lane_idx == 0 && warp_idx == 0) { + // uint32_t a_frag[KTraits::INT32_ELEMS_PER_THREAD]; + // scratch->load_fragment(offset, a_frag); + // auto frag_T = reinterpret_cast<__half *>(a_frag); + // for (auto i = 0ul; i < 4; ++i) { + // printf("----scratch-----> %f \n", (float)(*(frag_T + + // i))); + // } + + // printf("KTraits::UPCAST_STRIDE_O %d\n", + // KTraits::UPCAST_STRIDE_O); + // } + offset = + scratch->template advance_offset_by_column<4>(offset, mma_kv); + } + offset = + scratch + ->template advance_offset_by_row<16, KTraits::UPCAST_STRIDE_O>( + offset); + } + __syncthreads(); +} + /*! * \brief FlashAttention prefill CUDA kernel for a single request. * \tparam partition_kv Whether to split kv_len into chunks. 
@@ -2141,7 +2219,8 @@ SinglePrefillWithKVCacheDevice(const Params params, const uint32_t window_left = variant.window_left; DTypeQKAccum s_frag[NUM_MMA_Q][NUM_MMA_KV][HALF_ELEMS_PER_THREAD]; - alignas(16) float o_frag[NUM_MMA_Q][NUM_MMA_D_VO][HALF_ELEMS_PER_THREAD]; + alignas( + 16) float o_frag[NUM_MMA_Q][NUM_MMA_D_VO][HALF_ELEMS_PER_THREAD]; DTypeQKAccum m[NUM_MMA_Q][NUM_ACCUM_ROWS_PER_THREAD]; float d[NUM_MMA_Q][NUM_ACCUM_ROWS_PER_THREAD]; float rope_freq[NUM_MMA_D_QK / 2][4]; @@ -2267,6 +2346,10 @@ SinglePrefillWithKVCacheDevice(const Params params, memory::commit_group(); #if Debug + + smem_t scratch( + smem_storage.qk_scratch); + if (warp_idx == 0 && lane_idx == 0) { printf("partition_kv : %d\n", partition_kv); printf("kv_len : %d\n", kv_len); @@ -2274,7 +2357,7 @@ SinglePrefillWithKVCacheDevice(const Params params, printf("chunk_end : %d\n", chunk_end); printf("chunk_start : %d\n", chunk_start); } - +#if 0 // Test Q if (warp_idx == 0 && lane_idx == 0) { printf("\n DEBUG Q ORIGINAL (HIP):\n"); @@ -2296,7 +2379,7 @@ SinglePrefillWithKVCacheDevice(const Params params, 16, KTraits::UPCAST_STRIDE_Q>(q_smem_offset_r_debug); } } - +#endif // Test K Global values: // Prints the (NUM_MMA_KV*16) x (NUM_MMA_D*16) matrix from global mem. @@ -2335,10 +2418,11 @@ SinglePrefillWithKVCacheDevice(const Params params, // Note that LDS is loaded collaboratively by all warps and not each // warp accesses the whole K matrix loaded into LDS. Each warp will // only access 1/4 of the K values loaded into LDS. 
+#if 0 if (warp_idx == 0 && lane_idx == 0) { printf("\n DEBUG K LDS ORIGINAL (HIP):\n"); uint32_t k_smem_offset_r_debug; - for (auto i = 0; i < NUM_MMA_KV * 16 * 2; ++i) { + for (auto i = 0; i < NUM_MMA_KV * 16; ++i) { for (auto j = 0; j < NUM_MMA_D_QK * 4; ++j) { k_smem_offset_r_debug = k_smem.template get_permuted_offset(i, @@ -2356,6 +2440,7 @@ SinglePrefillWithKVCacheDevice(const Params params, } } #endif +#endif #pragma unroll 1 for (uint32_t iter = 0; iter < num_iterations; ++iter) { @@ -2375,20 +2460,47 @@ SinglePrefillWithKVCacheDevice(const Params params, // compute attention score compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); -#if Debug +#if Debug1 + debug_write_sfrag_to_scratch(s_frag, &scratch, tid); + if (warp_idx == 0 && lane_idx == 0) { printf("s_frag results after compute_qk: \n"); - for (auto mma_q = 0ul; mma_q < NUM_MMA_Q; ++mma_q) { - for (auto mma_kv = 0ul; mma_kv < NUM_MMA_KV; ++mma_kv) { - for (auto reg_id = 0ul; reg_id < HALF_ELEMS_PER_THREAD; ++reg_id) - { - auto tmp = s_frag[mma_q][mma_kv][reg_id]; - printf("s_frag[%lu][%lu][%lu] : %f ", mma_q, mma_kv, reg_id, float(tmp)); + uint32_t scratch_offset_r_debug; + for (auto i = 0; i < NUM_MMA_KV * 16 * 2; ++i) { + for (auto j = 0; j < NUM_MMA_D_QK * 4; ++j) { + scratch_offset_r_debug = + scratch + .template get_permuted_offset( + i, j); + uint32_t a_frag[KTraits::INT32_ELEMS_PER_THREAD]; + scratch.load_fragment(scratch_offset_r_debug, a_frag); + auto frag_T = reinterpret_cast<__half *>(a_frag); + for (auto i = 0ul; i < 4; ++i) { + printf("%f ", (float)(*(frag_T + i))); } - printf("\n"); } + printf("\n"); + scratch.template advance_offset_by_row< + 16, KTraits::UPCAST_STRIDE_K>(scratch_offset_r_debug); } } + + // if (warp_idx == 0 && lane_idx == 0 && iter == 0) { + // printf("s_frag results after compute_qk: \n"); + // for (auto mma_q = 0ul; mma_q < NUM_MMA_Q * 16 * 4; ++mma_q) { + // for (auto mma_kv = 0ul; mma_kv < NUM_MMA_KV*4; ++mma_kv) + // { + // for (auto 
reg_id = 0ul; reg_id < + // HALF_ELEMS_PER_THREAD; ++reg_id) + // { + // auto tmp = s_frag[mma_q][mma_kv][reg_id]; + // printf("s_frag[%lu][%lu][%lu] : %f ", mma_q, + // mma_kv, reg_id, float(tmp)); + // } + // printf("\n"); + // } + // } + // } #endif logits_transform( params, variant, /*batch_idx=*/0, qo_packed_idx_base, @@ -2397,13 +2509,15 @@ SinglePrefillWithKVCacheDevice(const Params params, NUM_MMA_KV * 16, qo_len, kv_len, group_size, s_frag, tid, kv_head_idx); -#if Debug +#if Debug1 float soft_cap_pre_tanh_scale = params.sm_scale * gpu_iface::math::ptx_rcp(params.logits_soft_cap); - if (warp_idx == 0 && lane_idx == 0) { - printf("params.sm_scale %f, params.logits_soft_cap %f\n", params.sm_scale, params.logits_soft_cap); - printf("s_frag after logits transform (scaled by %f) : \n", soft_cap_pre_tanh_scale); + if (warp_idx == 0 && lane_idx == 0 && iter == 0) { + printf("params.sm_scale %f, params.logits_soft_cap %f\n", + params.sm_scale, params.logits_soft_cap); + printf("s_frag after logits transform (scaled by %f) : \n", + soft_cap_pre_tanh_scale); for (auto mma_q = 0ul; mma_q < NUM_MMA_Q; ++mma_q) { for (auto mma_kv = 0ul; mma_kv < NUM_MMA_KV; ++mma_kv) { for (auto reg_id = 0ul; reg_id < HALF_ELEMS_PER_THREAD; @@ -2418,6 +2532,37 @@ SinglePrefillWithKVCacheDevice(const Params params, } } #endif + +#if Debug + debug_write_sfrag_to_scratch(s_frag, &scratch, tid); + if (warp_idx == 0 && lane_idx == 0) { + float soft_cap_pre_tanh_scale = + params.sm_scale * + gpu_iface::math::ptx_rcp(params.logits_soft_cap); + printf("params.sm_scale %f, params.logits_soft_cap %f\n", + params.sm_scale, params.logits_soft_cap); + printf("s_frag after logits transform (scaled by %f) : \n", + soft_cap_pre_tanh_scale); + uint32_t scratch_offset_r_debug; + for (auto i = 0; i < NUM_MMA_KV * 16 * 2; ++i) { + for (auto j = 0; j < NUM_MMA_D_QK * 4; ++j) { + scratch_offset_r_debug = + scratch + .template get_permuted_offset( + i, j); + uint32_t a_frag[KTraits::INT32_ELEMS_PER_THREAD]; + 
scratch.load_fragment(scratch_offset_r_debug, a_frag); + auto frag_T = reinterpret_cast<__half *>(a_frag); + for (auto i = 0ul; i < 4; ++i) { + printf("%f ", (float)(*(frag_T + i))); + } + } + printf("\n"); + scratch.template advance_offset_by_row< + 16, KTraits::UPCAST_STRIDE_K>(scratch_offset_r_debug); + } + } +#endif // apply mask if (MASK_MODE == MaskMode::kCustom || (iter >= mask_iteration || iter < window_iteration)) @@ -2432,7 +2577,7 @@ SinglePrefillWithKVCacheDevice(const Params params, } #if Debug - if (warp_idx == 0 && lane_idx == 0) { + if (warp_idx == 0 && lane_idx == 0 && iter == 0) { printf("s_frag after logits masking\n"); for (auto mma_q = 0ul; mma_q < NUM_MMA_Q; ++mma_q) { for (auto mma_kv = 0ul; mma_kv < NUM_MMA_KV; ++mma_kv) { @@ -2449,11 +2594,33 @@ SinglePrefillWithKVCacheDevice(const Params params, } #endif // compute m,d states in online softmax - update_mdo_states(variant, s_frag, o_frag, m, d); + update_mdo_states(variant, s_frag, o_frag, m, d, warp_idx, + lane_idx); #if Debug - if (warp_idx == 0 && lane_idx == 0) { + if (warp_idx == 0 && lane_idx == 0 && iter == 0) { printf("Max value for first 32 cols of row 0 %f\n", m[0][0]); + printf("D value for first 32 cols of row 0 %f\n", d[0][0]); + } +#endif + +#if Debug + if (warp_idx == 0 && lane_idx == 0 && iter == 0) { + printf("gpu_iface::math::ptx_exp2(0) = %f\n", + gpu_iface::math::ptx_exp2(0.f)); + printf("s_frag after exp - max att\n"); + for (auto mma_q = 0ul; mma_q < NUM_MMA_Q; ++mma_q) { + for (auto mma_kv = 0ul; mma_kv < NUM_MMA_KV; ++mma_kv) { + for (auto reg_id = 0ul; reg_id < HALF_ELEMS_PER_THREAD; + ++reg_id) + { + auto tmp = s_frag[mma_q][mma_kv][reg_id]; + printf("s_frag[%lu][%lu][%lu] : %f ", mma_q, mma_kv, + reg_id, float(tmp)); + } + printf("\n"); + } + } } #endif @@ -2466,8 +2633,8 @@ SinglePrefillWithKVCacheDevice(const Params params, block.sync(); // compute sfm*v - compute_sfm_v(&v_smem, &v_smem_offset_r, s_frag, o_frag, - d); + compute_sfm_v(&v_smem, &v_smem_offset_r, 
s_frag, o_frag, d, + warp_idx, lane_idx); block.sync(); produce_kv( diff --git a/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h b/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h index b7bb15a2e0..fdb907abf2 100644 --- a/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h +++ b/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h @@ -184,12 +184,12 @@ __device__ __forceinline__ void m16k16_rowsum_f16f16f32(float *d, DType *s_frag) static_assert(sizeof(DType) == 2, "DType must be 16-bit type"); f16x4 a = reinterpret_cast(s_frag)[0]; f16x4 b = {f16(1.0f), f16(1.0f), f16(1.0f), f16(1.0f)}; - f32x4 c = {0.f, 0.f, 0.f, 0.f}; + f32x4 c = {d[0], 0.f, 0.f, 0.f}; f32x4 out = __builtin_amdgcn_mfma_f32_16x16x16f16(a, b, c, 0, 0, 0); d[0] = out.x; - d[1] = out.y; - d[2] = out.z; - d[3] = out.w; + // d[1] = out.y; + // d[2] = out.z; + // d[3] = out.w; } // TODO (rimaddur) : After release 2025.08 diff --git a/libflashinfer/utils/cpu_reference_hip.h b/libflashinfer/utils/cpu_reference_hip.h index 6e1b7876d4..bd5cc328a4 100644 --- a/libflashinfer/utils/cpu_reference_hip.h +++ b/libflashinfer/utils/cpu_reference_hip.h @@ -210,7 +210,7 @@ single_mha(const std::vector &q, assert(qo_len <= kv_len); assert(num_qo_heads % num_kv_heads == 0); float sm_scale = 1.f / std::sqrt(float(head_dim)); - //float sm_scale = 1.0; + // float sm_scale = 1.0; std::vector o(qo_len * num_qo_heads * head_dim); std::vector att(kv_len); std::vector q_rotary_local(head_dim); @@ -260,12 +260,13 @@ single_mha(const std::vector &q, switch (pos_encoding_mode) { case PosEncodingMode::kNone: { - for (size_t feat_idx = 0; feat_idx < head_dim; ++feat_idx) + for (size_t feat_idx = 0; feat_idx < head_dim; + ++feat_idx) { att[kv_idx] += fi::con::explicit_casting( - q[info.get_q_elem_offset( - q_idx, qo_head_idx, feat_idx)]) * + q[info.get_q_elem_offset(q_idx, qo_head_idx, + feat_idx)]) * fi::con::explicit_casting( k[info.get_kv_elem_offset( kv_idx, kv_head_idx, feat_idx)]) * @@ -284,8 +285,7 @@ 
single_mha(const std::vector &q, ++feat_idx) { att[kv_idx] += q_rotary_local[feat_idx] * - k_rotary_local[feat_idx] * - sm_scale; + k_rotary_local[feat_idx] * sm_scale; } break; } @@ -303,17 +303,17 @@ single_mha(const std::vector &q, max_val = std::max(max_val, att[kv_idx]); } -#if Debug - if (qo_head_idx == 0 && q_idx == 0) { +#if Debug + if (qo_head_idx == 0) { // for qo_len = 128, each warp on the GPU will store 128/4, // that is, 32 attention scores. For CDNA3, these 32 scores // are spread across 4 threads. - for(auto i = 0ul; i < 32; ++i) { - std::cout << " >>>>> scaled att " << att[i] << '\n'; + for (auto i = 0ul; i < 128; ++i) { + std::cout << att[i] << " "; } - std::cout << "Max value for warp 0 = " << *std::max_element(att.begin(),att.begin()+32) << '\n'; + std::cout << std::endl; } -#endif +#endif // exp minus max float denom = 0; for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { @@ -321,6 +321,21 @@ single_mha(const std::vector &q, denom += att[kv_idx]; } +#if Debug1 + if (qo_head_idx == 0 && q_idx == 1) { + std::cout << "D vaulued CPU q0 " << denom << '\n'; + } + if (qo_head_idx == 0 && q_idx == 0) { + // for qo_len = 128, each warp on the GPU will store 128/4, + // that is, 32 attention scores. For CDNA3, these 32 scores + // are spread across 4 threads. 
+ for (auto i = 0ul; i < 64; ++i) { + std::cout << " >>>>> after exp - max att " << att[i] + << '\n'; + } + } +#endif + // divide by denom for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { att[kv_idx] /= denom; From 7da105d4a7b3be7d567cd16ef7bc0982a75db6b5 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Wed, 10 Sep 2025 15:24:30 -0400 Subject: [PATCH 071/109] sfrag more debugging --- .../include/flashinfer/attention/generic/prefill.cuh | 5 ++--- libflashinfer/utils/cpu_reference_hip.h | 11 ++++------- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index fdbdd534da..ee417716d8 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -2053,9 +2053,8 @@ __device__ __forceinline__ void debug_write_sfrag_to_scratch( auto halfCastedBase = reinterpret_cast<__half *>(scratch->base); // Write all thread's fragments to shared memory - for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { - for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) { - + for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) { + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { // if(lane_idx == 0 && warp_idx == 0) { // printf("debug_write_sfrag_to_scratch..............\n"); // for (auto reg_id = 0ul; reg_id < HALF_ELEMS_PER_THREAD; diff --git a/libflashinfer/utils/cpu_reference_hip.h b/libflashinfer/utils/cpu_reference_hip.h index bd5cc328a4..049731d1bd 100644 --- a/libflashinfer/utils/cpu_reference_hip.h +++ b/libflashinfer/utils/cpu_reference_hip.h @@ -322,17 +322,14 @@ single_mha(const std::vector &q, } #if Debug1 - if (qo_head_idx == 0 && q_idx == 1) { - std::cout << "D vaulued CPU q0 " << denom << '\n'; - } - if (qo_head_idx == 0 && q_idx == 0) { + if (qo_head_idx == 0) { // for qo_len = 128, each warp on the GPU will store 128/4, // that is, 32 
attention scores. For CDNA3, these 32 scores // are spread across 4 threads. - for (auto i = 0ul; i < 64; ++i) { - std::cout << " >>>>> after exp - max att " << att[i] - << '\n'; + for (auto i = 0ul; i < 128; ++i) { + std::cout << att[i] << " "; } + std::cout << std::endl; } #endif From 1b794266543e35335cb8c09eec724e5cfd70b9f3 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Thu, 11 Sep 2025 16:11:47 -0400 Subject: [PATCH 072/109] Revert update_mod_changes --- .../flashinfer/attention/generic/prefill.cuh | 109 ++++++++---------- .../include/gpu_iface/backend/hip/mma_hip.h | 8 +- libflashinfer/utils/cpu_reference_hip.h | 2 +- 3 files changed, 53 insertions(+), 66 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index ee417716d8..2aa8fa9a89 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -163,7 +163,7 @@ struct KernelTraits // matrix layout of the "D/C" accumulator matrix in a D = A*B+C computation. // On CDNA3 the D/C matrices are distributed as four 4x16 bands across the // 64 threads. Each thread owns one element from four different rows. - static constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD = 1; + static constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD = 4; // Number of threads that collaboratively handle the same set of matrix rows // in attention score computation and cross-warp synchronization. 
// CUDA: 4 threads (each thread handles 2 elements from same row group) @@ -1339,35 +1339,33 @@ __device__ __forceinline__ void update_mdo_states( s_frag[mma_q][mma_kv][3])); m[mma_q][j] = max(m[mma_q][j], m_local); #else - float m_local = - max(max(s_frag[mma_q][mma_kv][j * 2 + 0], - s_frag[mma_q][mma_kv][j * 2 + 1]), - max(s_frag[mma_q][mma_kv][j * 2 + 4], - s_frag[mma_q][mma_kv][j * 2 + 5])); - m[mma_q][j] = max(m[mma_q][j], m_local); + m[mma_q][j] = + max(m[mma_q][j], s_frag[mma_q][mma_kv][j]); #endif } #if defined(PLATFORM_HIP_DEVICE) // Butterfly reduction across all threads in the band m[mma_q][j] = max(m[mma_q][j], gpu_iface::math::shfl_xor_sync( - m[mma_q][j], 0x10)); // 32 apart + m[mma_q][j], 0x8)); // 16 apart m[mma_q][j] = max(m[mma_q][j], gpu_iface::math::shfl_xor_sync( - m[mma_q][j], 0x8)); // 16 apart + m[mma_q][j], 0x4)); // 8 apart + m[mma_q][j] = + max(m[mma_q][j], gpu_iface::math::shfl_xor_sync( + m[mma_q][j], 0x2)); // 4 apart + m[mma_q][j] = + max(m[mma_q][j], gpu_iface::math::shfl_xor_sync( + m[mma_q][j], 0x1)); // 2 apart float o_scale = gpu_iface::math::ptx_exp2( m_prev * sm_scale - m[mma_q][j] * sm_scale); - d[mma_q][j] *= o_scale; // Scale output fragments for this specific row #pragma unroll for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { - o_frag[mma_q][mma_d][0] *= o_scale; - o_frag[mma_q][mma_d][1] *= o_scale; - o_frag[mma_q][mma_d][2] *= o_scale; - o_frag[mma_q][mma_d][3] *= o_scale; + o_frag[mma_q][mma_d][j] *= o_scale; } // Convert logits to probabilities for this row @@ -1375,17 +1373,8 @@ __device__ __forceinline__ void update_mdo_states( for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { - s_frag[mma_q][mma_kv][0] = gpu_iface::math::ptx_exp2( - s_frag[mma_q][mma_kv][0] * sm_scale - - m[mma_q][j] * sm_scale); - s_frag[mma_q][mma_kv][1] = gpu_iface::math::ptx_exp2( - s_frag[mma_q][mma_kv][1] * sm_scale - - m[mma_q][j] * sm_scale); - s_frag[mma_q][mma_kv][2] = gpu_iface::math::ptx_exp2( - 
s_frag[mma_q][mma_kv][2] * sm_scale - - m[mma_q][j] * sm_scale); - s_frag[mma_q][mma_kv][3] = gpu_iface::math::ptx_exp2( - s_frag[mma_q][mma_kv][3] * sm_scale - + s_frag[mma_q][mma_kv][j] = gpu_iface::math::ptx_exp2( + s_frag[mma_q][mma_kv][j] * sm_scale - m[mma_q][j] * sm_scale); } #else // CUDA PATH @@ -2053,44 +2042,42 @@ __device__ __forceinline__ void debug_write_sfrag_to_scratch( auto halfCastedBase = reinterpret_cast<__half *>(scratch->base); // Write all thread's fragments to shared memory - for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) { - for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { - // if(lane_idx == 0 && warp_idx == 0) { - // printf("debug_write_sfrag_to_scratch..............\n"); - // for (auto reg_id = 0ul; reg_id < HALF_ELEMS_PER_THREAD; - // ++reg_id) - // { - // auto tmp = s_frag[mma_q][mma_kv][reg_id]; - // printf("s_frag[%u][%u][%lu] : %f ", mma_q, mma_kv, - // reg_id, float(tmp)); - // } - // printf("\n"); - // } + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { + for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) { + if (lane_idx == 1 && warp_idx == 0) { + printf("debug_write_sfrag_to_scratch..............\n"); + for (auto reg_id = 0ul; reg_id < HALF_ELEMS_PER_THREAD; + ++reg_id) + { + auto tmp = s_frag[mma_q][mma_kv][reg_id]; + printf("s_frag[%u][%u][%lu] : %f ", mma_q, mma_kv, reg_id, + float(tmp)); + } + printf("\n"); + } for (auto reg_id = 0ul; reg_id < HALF_ELEMS_PER_THREAD; ++reg_id) { auto tmp = s_frag[mma_q][mma_kv][reg_id]; *(halfCastedBase + offset * 4 + reg_id) = tmp; } - // if(lane_idx == 0 && warp_idx == 0) { - // uint32_t a_frag[KTraits::INT32_ELEMS_PER_THREAD]; - // scratch->load_fragment(offset, a_frag); - // auto frag_T = reinterpret_cast<__half *>(a_frag); - // for (auto i = 0ul; i < 4; ++i) { - // printf("----scratch-----> %f \n", (float)(*(frag_T + - // i))); - // } + if (lane_idx == 1 && warp_idx == 0) { + uint32_t a_frag[KTraits::INT32_ELEMS_PER_THREAD]; + scratch->load_fragment(offset, 
a_frag); + auto frag_T = reinterpret_cast<__half *>(a_frag); + for (auto i = 0ul; i < 4; ++i) { + printf("----scratch-----> %f \n", (float)(*(frag_T + i))); + } - // printf("KTraits::UPCAST_STRIDE_O %d\n", - // KTraits::UPCAST_STRIDE_O); - // } - offset = - scratch->template advance_offset_by_column<4>(offset, mma_kv); + printf("KTraits::UPCAST_STRIDE_O %d\n", + KTraits::UPCAST_STRIDE_O); + } + + offset = scratch->template advance_offset_by_row< + 16, KTraits::UPCAST_STRIDE_O>(offset); } - offset = - scratch - ->template advance_offset_by_row<16, KTraits::UPCAST_STRIDE_O>( - offset); + offset = scratch->template advance_offset_by_column<4>(offset, mma_q) - + 16 * NUM_MMA_KV * KTraits::UPCAST_STRIDE_O; } __syncthreads(); } @@ -2459,14 +2446,14 @@ SinglePrefillWithKVCacheDevice(const Params params, // compute attention score compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); -#if Debug1 +#if Debug debug_write_sfrag_to_scratch(s_frag, &scratch, tid); if (warp_idx == 0 && lane_idx == 0) { printf("s_frag results after compute_qk: \n"); uint32_t scratch_offset_r_debug; - for (auto i = 0; i < NUM_MMA_KV * 16 * 2; ++i) { - for (auto j = 0; j < NUM_MMA_D_QK * 4; ++j) { + for (auto i = 0; i < NUM_MMA_KV * 16; ++i) { + for (auto j = 0; j < NUM_MMA_D_QK * 2; ++j) { scratch_offset_r_debug = scratch .template get_permuted_offset( @@ -2532,7 +2519,7 @@ SinglePrefillWithKVCacheDevice(const Params params, } #endif -#if Debug +#if Debug1 debug_write_sfrag_to_scratch(s_frag, &scratch, tid); if (warp_idx == 0 && lane_idx == 0) { float soft_cap_pre_tanh_scale = @@ -2544,7 +2531,7 @@ SinglePrefillWithKVCacheDevice(const Params params, soft_cap_pre_tanh_scale); uint32_t scratch_offset_r_debug; for (auto i = 0; i < NUM_MMA_KV * 16 * 2; ++i) { - for (auto j = 0; j < NUM_MMA_D_QK * 4; ++j) { + for (auto j = 0; j < NUM_MMA_D_QK * 2; ++j) { scratch_offset_r_debug = scratch .template get_permuted_offset( @@ -2575,7 +2562,7 @@ SinglePrefillWithKVCacheDevice(const Params 
params, kv_head_idx); } -#if Debug +#if Debug1 if (warp_idx == 0 && lane_idx == 0 && iter == 0) { printf("s_frag after logits masking\n"); for (auto mma_q = 0ul; mma_q < NUM_MMA_Q; ++mma_q) { diff --git a/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h b/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h index fdb907abf2..7b32c54343 100644 --- a/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h +++ b/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h @@ -184,12 +184,12 @@ __device__ __forceinline__ void m16k16_rowsum_f16f16f32(float *d, DType *s_frag) static_assert(sizeof(DType) == 2, "DType must be 16-bit type"); f16x4 a = reinterpret_cast(s_frag)[0]; f16x4 b = {f16(1.0f), f16(1.0f), f16(1.0f), f16(1.0f)}; - f32x4 c = {d[0], 0.f, 0.f, 0.f}; + f32x4 c = {d[0], d[1], d[2], d[3]}; f32x4 out = __builtin_amdgcn_mfma_f32_16x16x16f16(a, b, c, 0, 0, 0); d[0] = out.x; - // d[1] = out.y; - // d[2] = out.z; - // d[3] = out.w; + d[1] = out.y; + d[2] = out.z; + d[3] = out.w; } // TODO (rimaddur) : After release 2025.08 diff --git a/libflashinfer/utils/cpu_reference_hip.h b/libflashinfer/utils/cpu_reference_hip.h index 049731d1bd..bcb096d456 100644 --- a/libflashinfer/utils/cpu_reference_hip.h +++ b/libflashinfer/utils/cpu_reference_hip.h @@ -309,7 +309,7 @@ single_mha(const std::vector &q, // that is, 32 attention scores. For CDNA3, these 32 scores // are spread across 4 threads. 
for (auto i = 0ul; i < 128; ++i) { - std::cout << att[i] << " "; + std::cout << att[i] / sm_scale << " "; } std::cout << std::endl; } From 2ae831c8476a3922fa3f5072f33ee1e3fc227f9d Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Fri, 12 Sep 2025 11:17:51 -0400 Subject: [PATCH 073/109] Debugger for sfrag using pandas --- .../generic/default_prefill_params.cuh | 7 +- .../flashinfer/attention/generic/prefill.cuh | 138 +++++++----------- .../tests/hip/test_single_prefill.cpp | 69 ++++++--- libflashinfer/utils/cpu_reference_hip.h | 2 +- .../utils/flashinfer_prefill_ops.hip.h | 3 +- 5 files changed, 111 insertions(+), 108 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/default_prefill_params.cuh b/libflashinfer/include/flashinfer/attention/generic/default_prefill_params.cuh index 2b558d03cf..6d89c85889 100644 --- a/libflashinfer/include/flashinfer/attention/generic/default_prefill_params.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/default_prefill_params.cuh @@ -57,6 +57,7 @@ struct SinglePrefillParams float sm_scale; float rope_rcp_scale; float rope_rcp_theta; + uint32_t debug_thread_id; uint32_t partition_kv; @@ -91,7 +92,8 @@ struct SinglePrefillParams float logits_soft_cap, float sm_scale, float rope_scale, - float rope_theta) + float rope_theta, + uint32_t debug_thread_id) : q(q), k(k), v(v), maybe_custom_mask(maybe_custom_mask), o(o), lse(lse), maybe_alibi_slopes(maybe_alibi_slopes), group_size(num_qo_heads / num_kv_heads), num_qo_heads(num_qo_heads), @@ -101,7 +103,8 @@ struct SinglePrefillParams v_stride_n(kv_stride_n), v_stride_h(kv_stride_h), head_dim(head_dim), window_left(window_left), logits_soft_cap(logits_soft_cap), sm_scale(sm_scale), rope_rcp_scale(1. / rope_scale), - rope_rcp_theta(1. / rope_theta), partition_kv(false) + rope_rcp_theta(1. 
/ rope_theta), debug_thread_id(debug_thread_id), + partition_kv(false) { } diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index 2aa8fa9a89..e83e60c1e0 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -1529,9 +1529,6 @@ __device__ __forceinline__ void compute_sfm_v( { mma::m16k16_rowsum_f16f16f32(d[mma_q], s_frag_f16[mma_q][mma_kv]); - if (warp_idx == 0 && lane_idx == 1) { - printf("D values : %f\n", *d[mma_q]); - } } else { #if defined(PLATFORM_HIP_DEVICE) @@ -2021,8 +2018,8 @@ template __device__ __forceinline__ void debug_write_sfrag_to_scratch( typename KTraits::DTypeQKAccum ( *s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD], - smem_t *scratch, - const dim3 tid = threadIdx) + const dim3 tid = threadIdx, + uint32_t debug_thread_id = 0) { using DTypeQKAccum = typename KTraits::DTypeQKAccum; constexpr uint32_t NUM_MMA_Q = KTraits::NUM_MMA_Q; @@ -2032,52 +2029,18 @@ __device__ __forceinline__ void debug_write_sfrag_to_scratch( lane_idx = tid.x; // For CDNA3 with 16×4 thread layout: - uint32_t row = lane_idx % 16; - uint32_t col = lane_idx / 16; - - // Total matrix dimensions - constexpr uint32_t total_cols = NUM_MMA_KV * 16; - uint32_t offset = - scratch->template get_permuted_offset(row, col); - auto halfCastedBase = reinterpret_cast<__half *>(scratch->base); + uint32_t row = lane_idx / 16; + uint32_t col = lane_idx % 16; // Write all thread's fragments to shared memory for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) { - if (lane_idx == 1 && warp_idx == 0) { - printf("debug_write_sfrag_to_scratch..............\n"); - for (auto reg_id = 0ul; reg_id < HALF_ELEMS_PER_THREAD; - ++reg_id) - { - auto tmp = s_frag[mma_q][mma_kv][reg_id]; - printf("s_frag[%u][%u][%lu] : %f ", mma_q, mma_kv, reg_id, - 
float(tmp)); - } - printf("\n"); - } - - for (auto reg_id = 0ul; reg_id < HALF_ELEMS_PER_THREAD; ++reg_id) { - auto tmp = s_frag[mma_q][mma_kv][reg_id]; - *(halfCastedBase + offset * 4 + reg_id) = tmp; + if (lane_idx == debug_thread_id && warp_idx == 0) { + printf("%.6f %.6f %.6f %.6f\n", s_frag[mma_q][mma_kv][0], + s_frag[mma_q][mma_kv][1], s_frag[mma_q][mma_kv][2], + s_frag[mma_q][mma_kv][3]); } - - if (lane_idx == 1 && warp_idx == 0) { - uint32_t a_frag[KTraits::INT32_ELEMS_PER_THREAD]; - scratch->load_fragment(offset, a_frag); - auto frag_T = reinterpret_cast<__half *>(a_frag); - for (auto i = 0ul; i < 4; ++i) { - printf("----scratch-----> %f \n", (float)(*(frag_T + i))); - } - - printf("KTraits::UPCAST_STRIDE_O %d\n", - KTraits::UPCAST_STRIDE_O); - } - - offset = scratch->template advance_offset_by_row< - 16, KTraits::UPCAST_STRIDE_O>(offset); } - offset = scratch->template advance_offset_by_column<4>(offset, mma_q) - - 16 * NUM_MMA_KV * KTraits::UPCAST_STRIDE_O; } __syncthreads(); } @@ -2336,13 +2299,13 @@ SinglePrefillWithKVCacheDevice(const Params params, smem_t scratch( smem_storage.qk_scratch); - if (warp_idx == 0 && lane_idx == 0) { - printf("partition_kv : %d\n", partition_kv); - printf("kv_len : %d\n", kv_len); - printf("max_chunk_size : %d\n", max_chunk_size); - printf("chunk_end : %d\n", chunk_end); - printf("chunk_start : %d\n", chunk_start); - } + // if (warp_idx == 0 && lane_idx == 0) { + // printf("partition_kv : %d\n", partition_kv); + // printf("kv_len : %d\n", kv_len); + // printf("max_chunk_size : %d\n", max_chunk_size); + // printf("chunk_end : %d\n", chunk_end); + // printf("chunk_start : %d\n", chunk_start); + // } #if 0 // Test Q if (warp_idx == 0 && lane_idx == 0) { @@ -2370,16 +2333,16 @@ SinglePrefillWithKVCacheDevice(const Params params, // Prints the (NUM_MMA_KV*16) x (NUM_MMA_D*16) matrix from global mem. 
if (warp_idx == 0 && lane_idx == 0) { - printf("\n DEBUG K Global (HIP):\n"); - printf("k_stride_n : %d\n", k_stride_n); - printf("k_stride_h : %d\n", k_stride_h); - printf("kv_head_idx : %d\n", kv_head_idx); - printf("num_qo_heads : %d\n", num_qo_heads); - printf("num_kv_heads : %d\n", num_kv_heads); - printf("k_stride_n : %d\n", k_stride_n); - printf("KTraits::NUM_MMA_D_QK : %d\n", KTraits::NUM_MMA_D_QK); - printf("NUM_MMA_KV : %d\n", NUM_MMA_KV); - printf("NUM_MMA_Q : %d\n", NUM_MMA_Q); + // printf("\n DEBUG K Global (HIP):\n"); + // printf("k_stride_n : %d\n", k_stride_n); + // printf("k_stride_h : %d\n", k_stride_h); + // printf("kv_head_idx : %d\n", kv_head_idx); + // printf("num_qo_heads : %d\n", num_qo_heads); + // printf("num_kv_heads : %d\n", num_kv_heads); + // printf("k_stride_n : %d\n", k_stride_n); + // printf("KTraits::NUM_MMA_D_QK : %d\n", KTraits::NUM_MMA_D_QK); + // printf("NUM_MMA_KV : %d\n", NUM_MMA_KV); + // printf("NUM_MMA_Q : %d\n", NUM_MMA_Q); #if 0 DTypeKV *k_ptr_tmp = k + (chunk_start + warp_idx * KV_THR_LAYOUT_ROW + @@ -2447,29 +2410,32 @@ SinglePrefillWithKVCacheDevice(const Params params, compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); #if Debug - debug_write_sfrag_to_scratch(s_frag, &scratch, tid); + debug_write_sfrag_to_scratch(s_frag, tid, + params.debug_thread_id); - if (warp_idx == 0 && lane_idx == 0) { - printf("s_frag results after compute_qk: \n"); - uint32_t scratch_offset_r_debug; - for (auto i = 0; i < NUM_MMA_KV * 16; ++i) { - for (auto j = 0; j < NUM_MMA_D_QK * 2; ++j) { - scratch_offset_r_debug = - scratch - .template get_permuted_offset( - i, j); - uint32_t a_frag[KTraits::INT32_ELEMS_PER_THREAD]; - scratch.load_fragment(scratch_offset_r_debug, a_frag); - auto frag_T = reinterpret_cast<__half *>(a_frag); - for (auto i = 0ul; i < 4; ++i) { - printf("%f ", (float)(*(frag_T + i))); - } - } - printf("\n"); - scratch.template advance_offset_by_row< - 16, 
KTraits::UPCAST_STRIDE_K>(scratch_offset_r_debug); - } - } + // if (warp_idx == 0 && lane_idx == 0) { + // printf("s_frag results after compute_qk: \n"); + // uint32_t scratch_offset_r_debug; + // for (auto i = 0; i < NUM_MMA_KV * 16; ++i) { + // for (auto j = 0; j < NUM_MMA_D_QK * 2; ++j) { + // scratch_offset_r_debug = + // scratch + // .template + // get_permuted_offset( + // i, j); + // uint32_t a_frag[KTraits::INT32_ELEMS_PER_THREAD]; + // scratch.load_fragment(scratch_offset_r_debug, + // a_frag); auto frag_T = reinterpret_cast<__half + // *>(a_frag); for (auto i = 0ul; i < 4; ++i) { + // printf("%f ", (float)(*(frag_T + i))); + // } + // } + // printf("\n"); + // scratch.template advance_offset_by_row< + // 16, + // KTraits::UPCAST_STRIDE_K>(scratch_offset_r_debug); + // } + // } // if (warp_idx == 0 && lane_idx == 0 && iter == 0) { // printf("s_frag results after compute_qk: \n"); @@ -2583,14 +2549,14 @@ SinglePrefillWithKVCacheDevice(const Params params, update_mdo_states(variant, s_frag, o_frag, m, d, warp_idx, lane_idx); -#if Debug +#if Debug1 if (warp_idx == 0 && lane_idx == 0 && iter == 0) { printf("Max value for first 32 cols of row 0 %f\n", m[0][0]); printf("D value for first 32 cols of row 0 %f\n", d[0][0]); } #endif -#if Debug +#if Debug1 if (warp_idx == 0 && lane_idx == 0 && iter == 0) { printf("gpu_iface::math::ptx_exp2(0) = %f\n", gpu_iface::math::ptx_exp2(0.f)); diff --git a/libflashinfer/tests/hip/test_single_prefill.cpp b/libflashinfer/tests/hip/test_single_prefill.cpp index b471bcd752..d32bd066ff 100644 --- a/libflashinfer/tests/hip/test_single_prefill.cpp +++ b/libflashinfer/tests/hip/test_single_prefill.cpp @@ -31,10 +31,10 @@ void _TestComputeQKCorrectness(size_t qo_len, float rtol = 1e-3, float atol = 1e-3) { - std::cout << "Testing compute_qk: qo_len=" << qo_len - << ", kv_len=" << kv_len << ", num_qo_heads=" << num_qo_heads - << ", num_kv_heads=" << num_kv_heads << ", head_dim=" << head_dim - << std::endl; + // std::cout << "Testing 
compute_qk: qo_len=" << qo_len + // << ", kv_len=" << kv_len << ", num_qo_heads=" << num_qo_heads + // << ", num_kv_heads=" << num_kv_heads << ", head_dim=" << head_dim + // << std::endl; // Generate test data (same as original test) std::vector q(qo_len * num_qo_heads * head_dim); @@ -78,13 +78,13 @@ void _TestComputeQKCorrectness(size_t qo_len, float *qk_scores_d; FI_GPU_CALL(hipMalloc(&qk_scores_d, qk_output_size * sizeof(float))); - std::cout << "Debug: Kernel launch parameters:" << std::endl; - std::cout << " qo_len=" << qo_len << ", kv_len=" << kv_len << std::endl; - std::cout << " num_qo_heads=" << num_qo_heads - << ", num_kv_heads=" << num_kv_heads << std::endl; - std::cout << " head_dim=" << head_dim << std::endl; - std::cout << " qk_output_size=" << qk_output_size << std::endl; - std::cout << " Launching ComputeQKStubCaller..." << std::endl; + // std::cout << "Debug: Kernel launch parameters:" << std::endl; + // std::cout << " qo_len=" << qo_len << ", kv_len=" << kv_len << std::endl; + // std::cout << " num_qo_heads=" << num_qo_heads + // << ", num_kv_heads=" << num_kv_heads << std::endl; + // std::cout << " head_dim=" << head_dim << std::endl; + // std::cout << " qk_output_size=" << qk_output_size << std::endl; + // std::cout << " Launching ComputeQKStubCaller..." 
<< std::endl; // Call ComputeQKStubCaller instead of SinglePrefillWithKVCache hipError_t status = @@ -94,8 +94,8 @@ void _TestComputeQKCorrectness(size_t qo_len, num_qo_heads, num_kv_heads, qo_len, kv_len, head_dim, causal, kv_layout, pos_encoding_mode, use_fp16_qk_reduction); - std::cout << " Kernel launch status: " << hipGetErrorString(status) - << std::endl; + // std::cout << " Kernel launch status: " << hipGetErrorString(status) + // << std::endl; EXPECT_EQ(status, hipSuccess) << "ComputeQKStubCaller kernel launch failed, error message: " << hipGetErrorString(status); @@ -138,7 +138,7 @@ void _TestComputeQKCorrectness(size_t qo_len, } } } - +#if 0 // Calculate and report accuracy float result_accuracy = 1.0f - float(num_results_error_atol) / float(comparison_size); @@ -165,7 +165,7 @@ void _TestComputeQKCorrectness(size_t qo_len, EXPECT_GT(result_accuracy, 0.80) << "compute_qk accuracy too low"; // Start with 80% EXPECT_FALSE(nan_detected) << "NaN detected in compute_qk results"; - +#endif // Cleanup FI_GPU_CALL(hipFree(q_d)); FI_GPU_CALL(hipFree(k_d)); @@ -187,6 +187,7 @@ void _TestSinglePrefillKernelCorrectness(size_t qo_len, QKVLayout kv_layout, PosEncodingMode pos_encoding_mode, bool use_fp16_qk_reduction, + uint32_t debug_thread_id, float rtol = 1e-3, float atol = 1e-3) { @@ -231,7 +232,7 @@ void _TestSinglePrefillKernelCorrectness(size_t qo_len, q_d, k_d, v_d, o_d, tmp_d, /*lse=*/nullptr, num_qo_heads, num_kv_heads, qo_len, kv_len, head_dim, causal, kv_layout, pos_encoding_mode, - use_fp16_qk_reduction); + use_fp16_qk_reduction, debug_thread_id); EXPECT_EQ(status, hipSuccess) << "SinglePrefillWithKVCache kernel launch failed, error message: " @@ -280,6 +281,7 @@ void _TestSinglePrefillKernelCorrectness(size_t qo_len, // for(auto i: att_out) { // std::cout << i << "\n"; // } +#if 0 float result_accuracy = 1. 
- float(num_results_error_atol) / float(o_ref.size()); std::cout << "num_qo_heads=" << num_qo_heads @@ -293,7 +295,7 @@ void _TestSinglePrefillKernelCorrectness(size_t qo_len, EXPECT_GT(result_accuracy, 0.90) << "Result correctness test failed."; EXPECT_FALSE(nan_detected) << "Nan detected in the result."; - +#endif FI_GPU_CALL(hipFree(q_d)); FI_GPU_CALL(hipFree(k_d)); FI_GPU_CALL(hipFree(v_d)); @@ -557,6 +559,7 @@ int main(int argc, char **argv) // return RUN_ALL_TESTS(); using DTypeIn = __half; using DTypeO = __half; + uint32_t debug_thread_id = 0; bool use_fp16_qk_reduction = false; size_t qo_len = 128; size_t kv_len = 128; @@ -566,10 +569,40 @@ int main(int argc, char **argv) size_t pos_encoding_mode = 0; // 1 == kRopeLLama size_t kv_layout = 0; + for (int i = 1; i < argc; i++) { + std::string arg = argv[i]; + + if (arg == "--thread" && i + 1 < argc) { + debug_thread_id = std::stoi(argv[++i]); + std::cout << "Debug thread ID set to: " << debug_thread_id + << std::endl; + } + else if (arg == "--qo_len" && i + 1 < argc) { + qo_len = std::stoi(argv[++i]); + } + else if (arg == "--kv_len" && i + 1 < argc) { + kv_len = std::stoi(argv[++i]); + } + else if (arg == "--heads" && i + 1 < argc) { + num_heads = std::stoi(argv[++i]); + } + else if (arg == "--help") { + std::cout + << "Usage: " << argv[0] << " [options]\n" + << "Options:\n" + << " --thread Debug thread ID (0-255 for 4 warps)\n" + << " --qo_len Query/Output length (default: 128)\n" + << " --kv_len Key/Value length (default: 128)\n" + << " --heads Number of heads (default: 1)\n" + << " --help Show this help message\n"; + return 0; + } + } + _TestSinglePrefillKernelCorrectness( qo_len, kv_len, num_heads, num_heads, head_dim, causal, QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode), - use_fp16_qk_reduction); + use_fp16_qk_reduction, debug_thread_id); } // int main(int argc, char **argv) diff --git a/libflashinfer/utils/cpu_reference_hip.h b/libflashinfer/utils/cpu_reference_hip.h index 
bcb096d456..cd933426fd 100644 --- a/libflashinfer/utils/cpu_reference_hip.h +++ b/libflashinfer/utils/cpu_reference_hip.h @@ -303,7 +303,7 @@ single_mha(const std::vector &q, max_val = std::max(max_val, att[kv_idx]); } -#if Debug +#if Debug1 if (qo_head_idx == 0) { // for qo_len = 128, each warp on the GPU will store 128/4, // that is, 32 attention scores. For CDNA3, these 32 scores diff --git a/libflashinfer/utils/flashinfer_prefill_ops.hip.h b/libflashinfer/utils/flashinfer_prefill_ops.hip.h index 4f3564365e..a08ce6754e 100644 --- a/libflashinfer/utils/flashinfer_prefill_ops.hip.h +++ b/libflashinfer/utils/flashinfer_prefill_ops.hip.h @@ -124,6 +124,7 @@ hipError_t SinglePrefillWithKVCache( QKVLayout kv_layout = QKVLayout::kNHD, PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, bool use_fp16_qk_reduction = false, + uint32_t debug_thread_id = 0, std::optional maybe_sm_scale = std::nullopt, float rope_scale = 1.f, float rope_theta = 1e4, @@ -155,7 +156,7 @@ hipError_t SinglePrefillWithKVCache( head_dim, /*window_left=*/-1, /*logits_soft_cap=*/8.f, sm_scale, - rope_scale, rope_theta); + rope_scale, rope_theta, debug_thread_id); return SinglePrefillWithKVCacheDispatched< HEAD_DIM, HEAD_DIM, POS_ENCODING_MODE, USE_FP16_QK_REDUCTION, MASK_MODE, AttentionVariant, From 20d7e2a83cf04e4cd43cc3eb7dcb56a4c9a237f7 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Fri, 12 Sep 2025 13:57:18 -0400 Subject: [PATCH 074/109] Tester scripts --- .gitignore | 3 + sfrag_tester_script.py | 200 +++++++++++++++++++++++++++++++++++++++++ test_prefill_sfrag.sh | 12 +++ 3 files changed, 215 insertions(+) create mode 100644 sfrag_tester_script.py create mode 100644 test_prefill_sfrag.sh diff --git a/.gitignore b/.gitignore index 9f58c56551..a8c2776f19 100644 --- a/.gitignore +++ b/.gitignore @@ -188,3 +188,6 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. 
For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ + +*.csv +*.pkl diff --git a/sfrag_tester_script.py b/sfrag_tester_script.py new file mode 100644 index 0000000000..8d9b0790ac --- /dev/null +++ b/sfrag_tester_script.py @@ -0,0 +1,200 @@ +#!/usr/bin/env python3 +import re +import sys + +import numpy as np +import pandas as pd + + +def parse_sfrag_log(log_file_path, num_threads=64, num_mma_q=2, num_mma_kv=4): + """ + Parse s_frag debug output from multiple thread runs into a 32x128 DataFrame. + + Args: + log_file_path: Path to the concatenated log file + num_threads: Number of threads (default 64) + num_mma_q: NUM_MMA_Q value (default 2) + num_mma_kv: NUM_MMA_KV value (default 4) + + Returns: + DataFrame with shape (32, 128) containing the s_frag values + """ + + # Initialize the result matrix + matrix = np.zeros((32, 128)) + + # Read the log file + with open(log_file_path, "r") as f: + lines = f.readlines() + + # Track current thread and value position + current_thread = -1 + value_idx = 0 + values = [] + + for line in lines: + line = line.strip() + + # Check if this is a thread ID line + if line.startswith("Debug thread ID set to:"): + if current_thread >= 0 and values: + # Process the previous thread's data + populate_matrix(matrix, current_thread, values) + + # Extract thread ID + current_thread = int(line.split(":")[-1].strip()) + values = [] + value_idx = 0 + + # Otherwise, it should be a line of float values + elif line and current_thread >= 0: + # Parse the float values from the line + line_values = [float(x) for x in line.split()] + values.extend(line_values) + + # Don't forget to process the last thread + if current_thread >= 0 and values: + populate_matrix(matrix, current_thread, values) + + # Create DataFrame with appropriate column and row labels + df = pd.DataFrame(matrix) + df.index = [f"Row_{i}" for i in range(32)] + df.columns = [f"Col_{i}" for i in range(128)] + + return df + + 
+def populate_matrix(matrix, thread_id, values): + """ + Populate the matrix with values from a specific thread according to the + pattern the values are printed. + + Args: + matrix: The 32x128 numpy array to populate + thread_id: The thread ID (0-63) + values: List of 64 float values from this thread + """ + + if len(values) != 64: + print(f"Warning: Thread {thread_id} has {len(values)} values instead of 64") + return + + # Calculate base row and column for this thread + row_base = (thread_id // 16) * 4 + col_base = thread_id % 16 + + # Split values into two calls (32 values each) + first_call = values[:32] + second_call = values[32:] + + # Process first call (columns 0-63) + process_call_values(matrix, first_call, row_base, col_base, col_offset=0) + + # Process second call (columns 64-127) + process_call_values(matrix, second_call, row_base, col_base, col_offset=64) + + +def process_call_values(matrix, values, row_base, col_base, col_offset): + """ + Process 32 values from one call according to the nested loop pattern. 
+ + Args: + matrix: The matrix to populate + values: 32 values from one call + row_base: Base row for this thread + col_base: Base column for this thread + col_offset: Column offset (0 for first call, 64 for second call) + """ + + value_idx = 0 + current_row = row_base + current_col = col_base + col_offset + + # Outer loop: 2 iterations (NUM_MMA_Q) + for mma_q in range(2): + # Middle loop: 4 iterations (NUM_MMA_KV) + for mma_kv in range(4): + # Inner loop: 4 values + for i in range(4): + if value_idx < len(values): + # Place values in consecutive rows, same column + matrix[current_row + i, current_col] = values[value_idx] + value_idx += 1 + + # After inner loop, move to next column set + current_col += 16 + + # After middle loop, reset column and move to next row set + current_col = col_base + col_offset + current_row += 16 + + +def print_matrix_info(df): + """Print summary information about the populated matrix.""" + print(f"Matrix shape: {df.shape}") + print(f"Non-zero elements: {(df != 0).sum().sum()}") + print(f"Matrix statistics:") + print(f" Min: {df.min().min():.6f}") + print(f" Max: {df.max().max():.6f}") + print(f" Mean: {df.mean().mean():.6f}") + print(f" Std: {df.values.std():.6f}") + + +def save_results(df, output_prefix="sfrag_matrix"): + """Save the DataFrame in multiple formats.""" + # Save as CSV + csv_file = f"{output_prefix}.csv" + df.to_csv(csv_file) + print(f"Saved matrix to {csv_file}") + + # Save as pickle for exact preservation + pickle_file = f"{output_prefix}.pkl" + df.to_pickle(pickle_file) + print(f"Saved matrix to {pickle_file}") + + # Save a heatmap visualization + try: + import matplotlib.pyplot as plt + import seaborn as sns + + plt.figure(figsize=(20, 6)) + sns.heatmap( + df.iloc[:32, :64], cmap="RdBu_r", center=0, cbar_kws={"label": "Value"} + ) + plt.title("S_FRAG Matrix Heatmap (First 32x64 block)") + plt.xlabel("Column") + plt.ylabel("Row") + plt.tight_layout() + plt.savefig(f"{output_prefix}_heatmap.png", dpi=150) + plt.close() 
+ print(f"Saved heatmap to {output_prefix}_heatmap.png") + except ImportError: + print("Matplotlib/Seaborn not available, skipping heatmap") + + +def main(): + if len(sys.argv) < 2: + print("Usage: python parse_sfrag.py [output_prefix]") + sys.exit(1) + + log_file = sys.argv[1] + output_prefix = sys.argv[2] if len(sys.argv) > 2 else "sfrag_matrix" + + print(f"Parsing log file: {log_file}") + + # Parse the log file + df = parse_sfrag_log(log_file) + + # Print summary information + print_matrix_info(df) + + # Save results + save_results(df, output_prefix) + + # Print a sample of the matrix + print("\nSample of the matrix (first 8x8 block):") + print(df.iloc[:8, :8]) + + +if __name__ == "__main__": + main() diff --git a/test_prefill_sfrag.sh b/test_prefill_sfrag.sh new file mode 100644 index 0000000000..a6de32ef3c --- /dev/null +++ b/test_prefill_sfrag.sh @@ -0,0 +1,12 @@ +#!/bin/bash +# run_all_threads.sh + +OUTPUT_FILE="sfrag_combined.log" +> $OUTPUT_FILE # Clear the file + +for thread in {0..63}; do + echo "Running thread $thread..." + ./a.out --thread $thread >> $OUTPUT_FILE 2>&1 +done + +echo "All threads complete. Output in $OUTPUT_FILE" From 3b32ec0304ab9b1e422e245601353c9a00156b97 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Fri, 12 Sep 2025 16:05:01 -0400 Subject: [PATCH 075/109] Add warp-level debug prints for sfrag. 
--- .../generic/default_prefill_params.cuh | 6 ++++-- .../flashinfer/attention/generic/prefill.cuh | 9 +++++---- libflashinfer/tests/hip/test_single_prefill.cpp | 11 +++++++++-- .../utils/flashinfer_prefill_ops.hip.h | 17 +++++++++-------- 4 files changed, 27 insertions(+), 16 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/default_prefill_params.cuh b/libflashinfer/include/flashinfer/attention/generic/default_prefill_params.cuh index 6d89c85889..55942d775d 100644 --- a/libflashinfer/include/flashinfer/attention/generic/default_prefill_params.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/default_prefill_params.cuh @@ -58,6 +58,7 @@ struct SinglePrefillParams float rope_rcp_scale; float rope_rcp_theta; uint32_t debug_thread_id; + uint32_t debug_warp_id; uint32_t partition_kv; @@ -93,7 +94,8 @@ struct SinglePrefillParams float sm_scale, float rope_scale, float rope_theta, - uint32_t debug_thread_id) + uint32_t debug_thread_id, + uint32_t debug_warp_id) : q(q), k(k), v(v), maybe_custom_mask(maybe_custom_mask), o(o), lse(lse), maybe_alibi_slopes(maybe_alibi_slopes), group_size(num_qo_heads / num_kv_heads), num_qo_heads(num_qo_heads), @@ -104,7 +106,7 @@ struct SinglePrefillParams window_left(window_left), logits_soft_cap(logits_soft_cap), sm_scale(sm_scale), rope_rcp_scale(1. / rope_scale), rope_rcp_theta(1. 
/ rope_theta), debug_thread_id(debug_thread_id), - partition_kv(false) + debug_warp_id(debug_warp_id), partition_kv(false) { } diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index e83e60c1e0..5ff40d3ce0 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -2019,7 +2019,8 @@ __device__ __forceinline__ void debug_write_sfrag_to_scratch( typename KTraits::DTypeQKAccum ( *s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD], const dim3 tid = threadIdx, - uint32_t debug_thread_id = 0) + uint32_t debug_thread_id = 0, + uint32_t debug_warp_id = 0) { using DTypeQKAccum = typename KTraits::DTypeQKAccum; constexpr uint32_t NUM_MMA_Q = KTraits::NUM_MMA_Q; @@ -2035,7 +2036,7 @@ __device__ __forceinline__ void debug_write_sfrag_to_scratch( // Write all thread's fragments to shared memory for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) { - if (lane_idx == debug_thread_id && warp_idx == 0) { + if (lane_idx == debug_thread_id && warp_idx == debug_warp_id) { printf("%.6f %.6f %.6f %.6f\n", s_frag[mma_q][mma_kv][0], s_frag[mma_q][mma_kv][1], s_frag[mma_q][mma_kv][2], s_frag[mma_q][mma_kv][3]); @@ -2410,8 +2411,8 @@ SinglePrefillWithKVCacheDevice(const Params params, compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); #if Debug - debug_write_sfrag_to_scratch(s_frag, tid, - params.debug_thread_id); + debug_write_sfrag_to_scratch( + s_frag, tid, params.debug_thread_id, params.debug_warp_id); // if (warp_idx == 0 && lane_idx == 0) { // printf("s_frag results after compute_qk: \n"); diff --git a/libflashinfer/tests/hip/test_single_prefill.cpp b/libflashinfer/tests/hip/test_single_prefill.cpp index d32bd066ff..5bb86ce1ac 100644 --- a/libflashinfer/tests/hip/test_single_prefill.cpp +++ 
b/libflashinfer/tests/hip/test_single_prefill.cpp @@ -188,6 +188,7 @@ void _TestSinglePrefillKernelCorrectness(size_t qo_len, PosEncodingMode pos_encoding_mode, bool use_fp16_qk_reduction, uint32_t debug_thread_id, + uint32_t debug_warp_id, float rtol = 1e-3, float atol = 1e-3) { @@ -232,7 +233,7 @@ void _TestSinglePrefillKernelCorrectness(size_t qo_len, q_d, k_d, v_d, o_d, tmp_d, /*lse=*/nullptr, num_qo_heads, num_kv_heads, qo_len, kv_len, head_dim, causal, kv_layout, pos_encoding_mode, - use_fp16_qk_reduction, debug_thread_id); + use_fp16_qk_reduction, debug_thread_id, debug_warp_id); EXPECT_EQ(status, hipSuccess) << "SinglePrefillWithKVCache kernel launch failed, error message: " @@ -560,6 +561,7 @@ int main(int argc, char **argv) using DTypeIn = __half; using DTypeO = __half; uint32_t debug_thread_id = 0; + uint32_t debug_warp_id = 0; bool use_fp16_qk_reduction = false; size_t qo_len = 128; size_t kv_len = 128; @@ -577,6 +579,10 @@ int main(int argc, char **argv) std::cout << "Debug thread ID set to: " << debug_thread_id << std::endl; } + else if (arg == "--warp" && i + 1 < argc) { + debug_warp_id = std::stoi(argv[++i]); + std::cout << "Debug warp ID set to: " << debug_warp_id << std::endl; + } else if (arg == "--qo_len" && i + 1 < argc) { qo_len = std::stoi(argv[++i]); } @@ -591,6 +597,7 @@ int main(int argc, char **argv) << "Usage: " << argv[0] << " [options]\n" << "Options:\n" << " --thread Debug thread ID (0-255 for 4 warps)\n" + << " --warp Debug warp ID (0-3 for 4 warps)\n" << " --qo_len Query/Output length (default: 128)\n" << " --kv_len Key/Value length (default: 128)\n" << " --heads Number of heads (default: 1)\n" @@ -602,7 +609,7 @@ int main(int argc, char **argv) _TestSinglePrefillKernelCorrectness( qo_len, kv_len, num_heads, num_heads, head_dim, causal, QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode), - use_fp16_qk_reduction, debug_thread_id); + use_fp16_qk_reduction, debug_thread_id, debug_warp_id); } // int main(int argc, char **argv) 
diff --git a/libflashinfer/utils/flashinfer_prefill_ops.hip.h b/libflashinfer/utils/flashinfer_prefill_ops.hip.h index a08ce6754e..5c282c4b8a 100644 --- a/libflashinfer/utils/flashinfer_prefill_ops.hip.h +++ b/libflashinfer/utils/flashinfer_prefill_ops.hip.h @@ -125,6 +125,7 @@ hipError_t SinglePrefillWithKVCache( PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, bool use_fp16_qk_reduction = false, uint32_t debug_thread_id = 0, + uint32_t debug_warp_id = 0, std::optional maybe_sm_scale = std::nullopt, float rope_scale = 1.f, float rope_theta = 1e4, @@ -149,14 +150,14 @@ hipError_t SinglePrefillWithKVCache( MaskMode::kCustom), /*use_sliding_window=*/false, /*use_logits_soft_cap=*/true, /*use_alibi=*/false>; - Params params(q, k, v, /*custom_mask=*/nullptr, o, lse, - /*alibi_slopes=*/nullptr, num_qo_heads, - num_kv_heads, qo_len, kv_len, qo_stride_n, - qo_stride_h, kv_stride_n, kv_stride_h, - head_dim, - /*window_left=*/-1, - /*logits_soft_cap=*/8.f, sm_scale, - rope_scale, rope_theta, debug_thread_id); + Params params( + q, k, v, /*custom_mask=*/nullptr, o, lse, + /*alibi_slopes=*/nullptr, num_qo_heads, + num_kv_heads, qo_len, kv_len, qo_stride_n, + qo_stride_h, kv_stride_n, kv_stride_h, head_dim, + /*window_left=*/-1, + /*logits_soft_cap=*/8.f, sm_scale, rope_scale, + rope_theta, debug_thread_id, debug_warp_id); return SinglePrefillWithKVCacheDispatched< HEAD_DIM, HEAD_DIM, POS_ENCODING_MODE, USE_FP16_QK_REDUCTION, MASK_MODE, AttentionVariant, From 6041be1eaddaf0d14816725fa0b0cbefbe0e55d3 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Sat, 13 Sep 2025 14:51:05 -0400 Subject: [PATCH 076/109] More debugging --- .../flashinfer/attention/generic/prefill.cuh | 60 ++------- libflashinfer/utils/cpu_reference_hip.h | 4 +- sfrag_tester_script.py | 118 +++++++++++++----- 3 files changed, 105 insertions(+), 77 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh 
index 5ff40d3ce0..2b36d3cddc 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -2025,14 +2025,9 @@ __device__ __forceinline__ void debug_write_sfrag_to_scratch( using DTypeQKAccum = typename KTraits::DTypeQKAccum; constexpr uint32_t NUM_MMA_Q = KTraits::NUM_MMA_Q; constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; - constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; const uint32_t warp_idx = get_warp_idx(tid.y, tid.z), lane_idx = tid.x; - // For CDNA3 with 16×4 thread layout: - uint32_t row = lane_idx / 16; - uint32_t col = lane_idx % 16; - // Write all thread's fragments to shared memory for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) { @@ -2406,54 +2401,25 @@ SinglePrefillWithKVCacheDevice(const Params params, rope_freq, tid); block.sync(); } +#if Debug + if (warp_idx == 0 && lane_idx == 0) { + uint32_t b_frag[KTraits::INT32_ELEMS_PER_THREAD]; + k_smem.load_fragment(k_smem_offset_r, b_frag); + auto frag_T = reinterpret_cast<__half *>(b_frag); + for (auto reg_id = 0ul; reg_id < 4; ++reg_id) { + for (auto i = 0ul; i < 4; ++i) { + printf("%f ", (float)(*(frag_T + i))); + } + } + } +#endif // compute attention score compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); -#if Debug +#if Debug1 debug_write_sfrag_to_scratch( s_frag, tid, params.debug_thread_id, params.debug_warp_id); - - // if (warp_idx == 0 && lane_idx == 0) { - // printf("s_frag results after compute_qk: \n"); - // uint32_t scratch_offset_r_debug; - // for (auto i = 0; i < NUM_MMA_KV * 16; ++i) { - // for (auto j = 0; j < NUM_MMA_D_QK * 2; ++j) { - // scratch_offset_r_debug = - // scratch - // .template - // get_permuted_offset( - // i, j); - // uint32_t a_frag[KTraits::INT32_ELEMS_PER_THREAD]; - // scratch.load_fragment(scratch_offset_r_debug, - // a_frag); auto frag_T = reinterpret_cast<__half - 
// *>(a_frag); for (auto i = 0ul; i < 4; ++i) { - // printf("%f ", (float)(*(frag_T + i))); - // } - // } - // printf("\n"); - // scratch.template advance_offset_by_row< - // 16, - // KTraits::UPCAST_STRIDE_K>(scratch_offset_r_debug); - // } - // } - - // if (warp_idx == 0 && lane_idx == 0 && iter == 0) { - // printf("s_frag results after compute_qk: \n"); - // for (auto mma_q = 0ul; mma_q < NUM_MMA_Q * 16 * 4; ++mma_q) { - // for (auto mma_kv = 0ul; mma_kv < NUM_MMA_KV*4; ++mma_kv) - // { - // for (auto reg_id = 0ul; reg_id < - // HALF_ELEMS_PER_THREAD; ++reg_id) - // { - // auto tmp = s_frag[mma_q][mma_kv][reg_id]; - // printf("s_frag[%lu][%lu][%lu] : %f ", mma_q, - // mma_kv, reg_id, float(tmp)); - // } - // printf("\n"); - // } - // } - // } #endif logits_transform( params, variant, /*batch_idx=*/0, qo_packed_idx_base, diff --git a/libflashinfer/utils/cpu_reference_hip.h b/libflashinfer/utils/cpu_reference_hip.h index cd933426fd..75d7faf05e 100644 --- a/libflashinfer/utils/cpu_reference_hip.h +++ b/libflashinfer/utils/cpu_reference_hip.h @@ -223,7 +223,7 @@ single_mha(const std::vector &q, kv_layout, HEAD_DIM); #if Debug1 std::cout << "DEBUG: Original Q (CPU): " << '\n'; - for (auto i = 0ul; i < 16; ++i) { + for (auto i = 0ul; i < 128; ++i) { for (int j = 0; j < 64; ++j) { std::cout << (float)q[info.get_q_elem_offset(i, 0, j)] << " "; } @@ -234,7 +234,7 @@ single_mha(const std::vector &q, std::cout << std::endl; std::cout << "DEBUG: Original K (CPU): " << '\n'; - for (auto i = 0ul; i < 64; ++i) { + for (auto i = 0ul; i < 128; ++i) { for (int j = 0ul; j < 64; ++j) { std::cout << (float)k[info.get_kv_elem_offset(i, 0, j)] << " "; } diff --git a/sfrag_tester_script.py b/sfrag_tester_script.py index 8d9b0790ac..32c239cc39 100644 --- a/sfrag_tester_script.py +++ b/sfrag_tester_script.py @@ -6,30 +6,33 @@ import pandas as pd -def parse_sfrag_log(log_file_path, num_threads=64, num_mma_q=2, num_mma_kv=4): +def parse_sfrag_log( + log_file_path, num_threads=64, 
num_warps=4, num_mma_q=2, num_mma_kv=4 +): """ - Parse s_frag debug output from multiple thread runs into a 32x128 DataFrame. + Parse s_frag debug output from multiple thread/warp runs into a 128x128 DataFrame. Args: log_file_path: Path to the concatenated log file - num_threads: Number of threads (default 64) + num_threads: Number of threads per warp (default 64) + num_warps: Number of warps (default 4) num_mma_q: NUM_MMA_Q value (default 2) num_mma_kv: NUM_MMA_KV value (default 4) Returns: - DataFrame with shape (32, 128) containing the s_frag values + DataFrame with shape (128, 128) containing the s_frag values """ - # Initialize the result matrix - matrix = np.zeros((32, 128)) + # Initialize the full result matrix (128x128) + matrix = np.zeros((128, 128)) # Read the log file with open(log_file_path, "r") as f: lines = f.readlines() - # Track current thread and value position + # Track current thread, warp and value position current_thread = -1 - value_idx = 0 + current_warp = -1 values = [] for line in lines: @@ -37,50 +40,62 @@ def parse_sfrag_log(log_file_path, num_threads=64, num_mma_q=2, num_mma_kv=4): # Check if this is a thread ID line if line.startswith("Debug thread ID set to:"): - if current_thread >= 0 and values: + if current_thread >= 0 and current_warp >= 0 and values: # Process the previous thread's data - populate_matrix(matrix, current_thread, values) + populate_matrix_with_warp(matrix, current_thread, current_warp, values) # Extract thread ID current_thread = int(line.split(":")[-1].strip()) values = [] - value_idx = 0 + + # Check if this is a warp ID line + elif line.startswith("Debug warp ID set to:"): + current_warp = int(line.split(":")[-1].strip()) # Otherwise, it should be a line of float values - elif line and current_thread >= 0: + elif line and current_thread >= 0 and current_warp >= 0: # Parse the float values from the line - line_values = [float(x) for x in line.split()] - values.extend(line_values) + try: + line_values = [float(x) 
for x in line.split()] + values.extend(line_values) + except ValueError: + # Skip lines that can't be parsed as floats + continue # Don't forget to process the last thread - if current_thread >= 0 and values: - populate_matrix(matrix, current_thread, values) + if current_thread >= 0 and current_warp >= 0 and values: + populate_matrix_with_warp(matrix, current_thread, current_warp, values) # Create DataFrame with appropriate column and row labels df = pd.DataFrame(matrix) - df.index = [f"Row_{i}" for i in range(32)] + df.index = [f"Row_{i}" for i in range(128)] df.columns = [f"Col_{i}" for i in range(128)] return df -def populate_matrix(matrix, thread_id, values): +def populate_matrix_with_warp(matrix, thread_id, warp_id, values): """ - Populate the matrix with values from a specific thread according to the - pattern the values are printed. + Populate the matrix with values from a specific thread and warp. Args: - matrix: The 32x128 numpy array to populate + matrix: The 128x128 numpy array to populate thread_id: The thread ID (0-63) + warp_id: The warp ID (0-3) values: List of 64 float values from this thread """ if len(values) != 64: - print(f"Warning: Thread {thread_id} has {len(values)} values instead of 64") + print( + f"Warning: Thread {thread_id} Warp {warp_id} has {len(values)} values instead of 64" + ) return - # Calculate base row and column for this thread - row_base = (thread_id // 16) * 4 + # Calculate base row and column for this thread within its warp + # Each warp handles 32 rows (warp 0: rows 0-31, warp 1: rows 32-63, etc.) 
+ warp_row_offset = warp_id * 32 + thread_row_base = (thread_id // 16) * 4 + row_base = warp_row_offset + thread_row_base col_base = thread_id % 16 # Split values into two calls (32 values each) @@ -139,6 +154,17 @@ def print_matrix_info(df): print(f" Mean: {df.mean().mean():.6f}") print(f" Std: {df.values.std():.6f}") + # Check which warps have been populated + print("\nWarp population check:") + for warp in range(4): + start_row = warp * 32 + end_row = (warp + 1) * 32 + warp_data = df.iloc[start_row:end_row, :] + non_zero = (warp_data != 0).sum().sum() + print( + f" Warp {warp} (rows {start_row}-{end_row-1}): {non_zero} non-zero elements" + ) + def save_results(df, output_prefix="sfrag_matrix"): """Save the DataFrame in multiple formats.""" @@ -157,17 +183,47 @@ def save_results(df, output_prefix="sfrag_matrix"): import matplotlib.pyplot as plt import seaborn as sns - plt.figure(figsize=(20, 6)) + plt.figure(figsize=(20, 20)) sns.heatmap( - df.iloc[:32, :64], cmap="RdBu_r", center=0, cbar_kws={"label": "Value"} + df, + cmap="RdBu_r", + center=0, + cbar_kws={"label": "Value"}, + xticklabels=16, + yticklabels=16, # Show every 16th label ) - plt.title("S_FRAG Matrix Heatmap (First 32x64 block)") + plt.title("S_FRAG Matrix Heatmap (Full 128x128)") plt.xlabel("Column") plt.ylabel("Row") + + # Add grid lines to show warp boundaries + for i in range(1, 4): + plt.axhline(y=i * 32, color="black", linewidth=2, alpha=0.5) + plt.tight_layout() plt.savefig(f"{output_prefix}_heatmap.png", dpi=150) plt.close() print(f"Saved heatmap to {output_prefix}_heatmap.png") + + # Also save individual warp heatmaps + fig, axes = plt.subplots(2, 2, figsize=(20, 20)) + for warp in range(4): + ax = axes[warp // 2, warp % 2] + start_row = warp * 32 + end_row = (warp + 1) * 32 + warp_data = df.iloc[start_row:end_row, :] + sns.heatmap( + warp_data, cmap="RdBu_r", center=0, ax=ax, cbar_kws={"label": "Value"} + ) + ax.set_title(f"Warp {warp} (Rows {start_row}-{end_row-1})") + 
ax.set_xlabel("Column") + ax.set_ylabel("Row (relative to warp)") + + plt.tight_layout() + plt.savefig(f"{output_prefix}_warps_heatmap.png", dpi=150) + plt.close() + print(f"Saved per-warp heatmap to {output_prefix}_warps_heatmap.png") + except ImportError: print("Matplotlib/Seaborn not available, skipping heatmap") @@ -178,7 +234,7 @@ def main(): sys.exit(1) log_file = sys.argv[1] - output_prefix = sys.argv[2] if len(sys.argv) > 2 else "sfrag_matrix" + output_prefix = sys.argv[2] if len(sys.argv) > 2 else "sfrag_matrix_full" print(f"Parsing log file: {log_file}") @@ -195,6 +251,12 @@ def main(): print("\nSample of the matrix (first 8x8 block):") print(df.iloc[:8, :8]) + print("\nSample from each warp (first 4x4 block):") + for warp in range(4): + start_row = warp * 32 + print(f"\nWarp {warp}:") + print(df.iloc[start_row : start_row + 4, :4]) + if __name__ == "__main__": main() From 06faaabe198503b1bb232a8c622e2f1f3a18c80c Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Sat, 13 Sep 2025 16:32:35 -0400 Subject: [PATCH 077/109] Fixed compute_qk --- .../flashinfer/attention/generic/prefill.cuh | 58 ++++++++++++------- test_prefill_sfrag.sh | 11 ++-- 2 files changed, 42 insertions(+), 27 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index 2b36d3cddc..d16fc2d71c 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -1065,11 +1065,7 @@ __device__ __forceinline__ void compute_qk( #endif } else { -#if defined(PLATFORM_HIP_DEVICE) - k_smem->load_fragment(*k_smem_offset_r, b_frag); -#else k_smem->load_fragment(*k_smem_offset_r, b_frag); -#endif } *k_smem_offset_r = @@ -1129,8 +1125,13 @@ __device__ __forceinline__ void compute_qk( } } *q_smem_offset_r -= KTraits::NUM_MMA_D_QK * QK_SMEM_COLUMN_ADVANCE; + +#if defined(PLATFORM_HIP_DEVICE) + *k_smem_offset_r -= 
KTraits::NUM_MMA_D_QK * (QK_SMEM_COLUMN_ADVANCE); +#elif defined(PLATFORM_CUDA_DEVICE) *k_smem_offset_r -= KTraits::NUM_MMA_D_QK * sizeof(typename KTraits::DTypeKV); +#endif } template @@ -2363,9 +2364,27 @@ SinglePrefillWithKVCacheDevice(const Params params, // Note that LDS is loaded collaboratively by all warps and not each // warp accesses the whole K matrix loaded into LDS. Each warp will // only access 1/4 of the K values loaded into LDS. +#endif + +#pragma unroll 1 + for (uint32_t iter = 0; iter < num_iterations; ++iter) { + // for (uint32_t iter = 0; iter < 1; ++iter) { + memory::wait_group<1>(); + block.sync(); + + if constexpr (KTraits::POS_ENCODING_MODE == + PosEncodingMode::kRoPELlama) + { + k_smem_inplace_apply_rotary( + chunk_start + iter * CTA_TILE_KV, &k_smem, &k_smem_offset_r, + rope_freq, tid); + block.sync(); + } +#if Debug1 + #if 0 if (warp_idx == 0 && lane_idx == 0) { - printf("\n DEBUG K LDS ORIGINAL (HIP):\n"); + printf("\n DEBUG K LDS ORIGINAL (HIP) Iter %d:\n", iter); uint32_t k_smem_offset_r_debug; for (auto i = 0; i < NUM_MMA_KV * 16; ++i) { for (auto j = 0; j < NUM_MMA_D_QK * 4; ++j) { @@ -2385,23 +2404,8 @@ SinglePrefillWithKVCacheDevice(const Params params, } } #endif -#endif - -#pragma unroll 1 - for (uint32_t iter = 0; iter < num_iterations; ++iter) { - // for (uint32_t iter = 0; iter < 1; ++iter) { - memory::wait_group<1>(); - block.sync(); - if constexpr (KTraits::POS_ENCODING_MODE == - PosEncodingMode::kRoPELlama) - { - k_smem_inplace_apply_rotary( - chunk_start + iter * CTA_TILE_KV, &k_smem, &k_smem_offset_r, - rope_freq, tid); - block.sync(); - } -#if Debug +#if 1 if (warp_idx == 0 && lane_idx == 0) { uint32_t b_frag[KTraits::INT32_ELEMS_PER_THREAD]; k_smem.load_fragment(k_smem_offset_r, b_frag); @@ -2411,13 +2415,23 @@ SinglePrefillWithKVCacheDevice(const Params params, printf("%f ", (float)(*(frag_T + i))); } } + printf("\n------------\n"); + k_smem.load_fragment(k_smem_offset_r, b_frag); + frag_T = reinterpret_cast<__half 
*>(b_frag); + for (auto reg_id = 0ul; reg_id < 4; ++reg_id) { + for (auto i = 0ul; i < 4; ++i) { + printf("%f ", (float)(*(frag_T + i))); + } + } + printf("\n-----===============-------\n"); } +#endif #endif // compute attention score compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); -#if Debug1 +#if Debug debug_write_sfrag_to_scratch( s_frag, tid, params.debug_thread_id, params.debug_warp_id); #endif diff --git a/test_prefill_sfrag.sh b/test_prefill_sfrag.sh index a6de32ef3c..f0bd5d0c3e 100644 --- a/test_prefill_sfrag.sh +++ b/test_prefill_sfrag.sh @@ -1,12 +1,13 @@ #!/bin/bash -# run_all_threads.sh -OUTPUT_FILE="sfrag_combined.log" +OUTPUT_FILE="sfrag_full.log" > $OUTPUT_FILE # Clear the file -for thread in {0..63}; do - echo "Running thread $thread..." - ./a.out --thread $thread >> $OUTPUT_FILE 2>&1 +for warp in $(seq 0 3); do + for thread in $(seq 0 63); do + echo "Running thread ${thread}... warp ${warp}" + ./a.out --thread $thread --warp $warp >> $OUTPUT_FILE 2>&1 + done done echo "All threads complete. 
Output in $OUTPUT_FILE" From 496aaafe9962348dbc54518bacbba70a6c5f8eb3 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Mon, 15 Sep 2025 13:00:47 -0400 Subject: [PATCH 078/109] wip debugging of softmax --- .../flashinfer/attention/generic/prefill.cuh | 144 +++++------------- .../include/gpu_iface/backend/hip/mma_hip.h | 1 + libflashinfer/utils/cpu_reference_hip.h | 12 +- 3 files changed, 48 insertions(+), 109 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index d16fc2d71c..609122eb75 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -700,7 +700,7 @@ __device__ __forceinline__ void init_states( for (uint32_t reg_id = 0; reg_id < KTraits::HALF_ELEMS_PER_THREAD; ++reg_id) { - o_frag[mma_q][mma_d][reg_id] = 0.f; + o_frag[mma_q][mma_d][reg_id] = 1.f; } } } @@ -712,7 +712,7 @@ __device__ __forceinline__ void init_states( for (uint32_t j = 0; j < NUM_ACCUM_ROWS_PER_THREAD; ++j) { m[mma_q][j] = typename KTraits::DTypeQKAccum(-gpu_iface::math::inf); - d[mma_q][j] = 1.f; + d[mma_q][j] = 0.f; } } } @@ -1329,22 +1329,14 @@ __device__ __forceinline__ void update_mdo_states( #pragma unroll for (uint32_t j = 0; j < NUM_ACCUM_ROWS_PER_THREAD; ++j) { float m_prev = m[mma_q][j]; +#if defined(PLATFORM_HIP_DEVICE) #pragma unroll for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { -#if defined(PLATFORM_HIP_DEVICE) - auto m_local = max(max(s_frag[mma_q][mma_kv][0], - s_frag[mma_q][mma_kv][1]), - max(s_frag[mma_q][mma_kv][2], - s_frag[mma_q][mma_kv][3])); - m[mma_q][j] = max(m[mma_q][j], m_local); -#else m[mma_q][j] = max(m[mma_q][j], s_frag[mma_q][mma_kv][j]); -#endif } -#if defined(PLATFORM_HIP_DEVICE) // Butterfly reduction across all threads in the band m[mma_q][j] = max(m[mma_q][j], gpu_iface::math::shfl_xor_sync( @@ -1378,7 +1370,18 @@ __device__ __forceinline__ 
void update_mdo_states( s_frag[mma_q][mma_kv][j] * sm_scale - m[mma_q][j] * sm_scale); } -#else // CUDA PATH +#elif (PLATFORM_CUDA_DEVICE) +#pragma unroll + for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; + ++mma_kv) + { + auto m_local = max(max(s_frag[mma_q][mma_kv][0], + s_frag[mma_q][mma_kv][1]), + max(s_frag[mma_q][mma_kv][2], + s_frag[mma_q][mma_kv][3])); + m[mma_q][j] = max(m[mma_q][j], m_local); + } + m[mma_q][j] = max(m[mma_q][j], gpu_iface::math::shfl_xor_sync(m[mma_q][j], 0x2)); @@ -1496,9 +1499,7 @@ __device__ __forceinline__ void compute_sfm_v( typename KTraits::DTypeQKAccum ( *s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD], float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD], - float (*d)[KTraits::NUM_ACCUM_ROWS_PER_THREAD], - uint32_t warp_idx = 0, - uint32_t lane_idx = 0) + float (*d)[KTraits::NUM_ACCUM_ROWS_PER_THREAD]) { constexpr uint32_t UPCAST_STRIDE_V = KTraits::UPCAST_STRIDE_V; constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; @@ -2431,7 +2432,7 @@ SinglePrefillWithKVCacheDevice(const Params params, // compute attention score compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); -#if Debug +#if Debug1 debug_write_sfrag_to_scratch( s_frag, tid, params.debug_thread_id, params.debug_warp_id); #endif @@ -2443,58 +2444,12 @@ SinglePrefillWithKVCacheDevice(const Params params, qo_len, kv_len, group_size, s_frag, tid, kv_head_idx); #if Debug1 - float soft_cap_pre_tanh_scale = - params.sm_scale * - gpu_iface::math::ptx_rcp(params.logits_soft_cap); - if (warp_idx == 0 && lane_idx == 0 && iter == 0) { - printf("params.sm_scale %f, params.logits_soft_cap %f\n", - params.sm_scale, params.logits_soft_cap); - printf("s_frag after logits transform (scaled by %f) : \n", - soft_cap_pre_tanh_scale); - for (auto mma_q = 0ul; mma_q < NUM_MMA_Q; ++mma_q) { - for (auto mma_kv = 0ul; mma_kv < NUM_MMA_KV; ++mma_kv) { - for (auto reg_id = 0ul; reg_id < HALF_ELEMS_PER_THREAD; - 
++reg_id) - { - auto tmp = s_frag[mma_q][mma_kv][reg_id]; - printf("s_frag[%lu][%lu][%lu] : %f ", mma_q, mma_kv, - reg_id, float(tmp)); - } - printf("\n"); - } - } - } + debug_write_sfrag_to_scratch( + s_frag, tid, params.debug_thread_id, params.debug_warp_id); #endif #if Debug1 debug_write_sfrag_to_scratch(s_frag, &scratch, tid); - if (warp_idx == 0 && lane_idx == 0) { - float soft_cap_pre_tanh_scale = - params.sm_scale * - gpu_iface::math::ptx_rcp(params.logits_soft_cap); - printf("params.sm_scale %f, params.logits_soft_cap %f\n", - params.sm_scale, params.logits_soft_cap); - printf("s_frag after logits transform (scaled by %f) : \n", - soft_cap_pre_tanh_scale); - uint32_t scratch_offset_r_debug; - for (auto i = 0; i < NUM_MMA_KV * 16 * 2; ++i) { - for (auto j = 0; j < NUM_MMA_D_QK * 2; ++j) { - scratch_offset_r_debug = - scratch - .template get_permuted_offset( - i, j); - uint32_t a_frag[KTraits::INT32_ELEMS_PER_THREAD]; - scratch.load_fragment(scratch_offset_r_debug, a_frag); - auto frag_T = reinterpret_cast<__half *>(a_frag); - for (auto i = 0ul; i < 4; ++i) { - printf("%f ", (float)(*(frag_T + i))); - } - } - printf("\n"); - scratch.template advance_offset_by_row< - 16, KTraits::UPCAST_STRIDE_K>(scratch_offset_r_debug); - } - } #endif // apply mask if (MASK_MODE == MaskMode::kCustom || @@ -2510,53 +2465,18 @@ SinglePrefillWithKVCacheDevice(const Params params, } #if Debug1 - if (warp_idx == 0 && lane_idx == 0 && iter == 0) { - printf("s_frag after logits masking\n"); - for (auto mma_q = 0ul; mma_q < NUM_MMA_Q; ++mma_q) { - for (auto mma_kv = 0ul; mma_kv < NUM_MMA_KV; ++mma_kv) { - for (auto reg_id = 0ul; reg_id < HALF_ELEMS_PER_THREAD; - ++reg_id) - { - auto tmp = s_frag[mma_q][mma_kv][reg_id]; - printf("s_frag[%lu][%lu][%lu] : %f ", mma_q, mma_kv, - reg_id, float(tmp)); - } - printf("\n"); - } - } - } + debug_write_sfrag_to_scratch( + s_frag, tid, params.debug_thread_id, params.debug_warp_id); #endif + // compute m,d states in online softmax 
update_mdo_states(variant, s_frag, o_frag, m, d, warp_idx, lane_idx); #if Debug1 - if (warp_idx == 0 && lane_idx == 0 && iter == 0) { - printf("Max value for first 32 cols of row 0 %f\n", m[0][0]); - printf("D value for first 32 cols of row 0 %f\n", d[0][0]); - } -#endif - -#if Debug1 - if (warp_idx == 0 && lane_idx == 0 && iter == 0) { - printf("gpu_iface::math::ptx_exp2(0) = %f\n", - gpu_iface::math::ptx_exp2(0.f)); - printf("s_frag after exp - max att\n"); - for (auto mma_q = 0ul; mma_q < NUM_MMA_Q; ++mma_q) { - for (auto mma_kv = 0ul; mma_kv < NUM_MMA_KV; ++mma_kv) { - for (auto reg_id = 0ul; reg_id < HALF_ELEMS_PER_THREAD; - ++reg_id) - { - auto tmp = s_frag[mma_q][mma_kv][reg_id]; - printf("s_frag[%lu][%lu][%lu] : %f ", mma_q, mma_kv, - reg_id, float(tmp)); - } - printf("\n"); - } - } - } + debug_write_sfrag_to_scratch( + s_frag, tid, params.debug_thread_id, params.debug_warp_id); #endif - block.sync(); produce_kv( k_smem, &k_smem_offset_w, &k_ptr, k_stride_n, @@ -2566,8 +2486,20 @@ SinglePrefillWithKVCacheDevice(const Params params, block.sync(); // compute sfm*v - compute_sfm_v(&v_smem, &v_smem_offset_r, s_frag, o_frag, d, - warp_idx, lane_idx); + compute_sfm_v(&v_smem, &v_smem_offset_r, s_frag, o_frag, + d); +#if Debug + if (lane_idx == params.debug_thread_id && + warp_idx == params.debug_warp_id) + { + for (auto mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { + printf("%f\n", d[mma_q][0]); + printf("%f\n", d[mma_q][1]); + printf("%f\n", d[mma_q][2]); + printf("%f\n", d[mma_q][3]); + } + } +#endif block.sync(); produce_kv( diff --git a/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h b/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h index 7b32c54343..320a2bf818 100644 --- a/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h +++ b/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h @@ -182,6 +182,7 @@ template __device__ __forceinline__ void m16k16_rowsum_f16f16f32(float *d, DType *s_frag) { static_assert(sizeof(DType) == 2, "DType must be 16-bit 
type"); + transpose_4x4_half_registers(reinterpret_cast(s_frag)); f16x4 a = reinterpret_cast(s_frag)[0]; f16x4 b = {f16(1.0f), f16(1.0f), f16(1.0f), f16(1.0f)}; f32x4 c = {d[0], d[1], d[2], d[3]}; diff --git a/libflashinfer/utils/cpu_reference_hip.h b/libflashinfer/utils/cpu_reference_hip.h index 75d7faf05e..94ed05ea76 100644 --- a/libflashinfer/utils/cpu_reference_hip.h +++ b/libflashinfer/utils/cpu_reference_hip.h @@ -228,8 +228,6 @@ single_mha(const std::vector &q, std::cout << (float)q[info.get_q_elem_offset(i, 0, j)] << " "; } std::cout << std::endl; - // q[info.get_q_elem_offset(q_idx, qo_head_idx, feat_idx) - // std::cout << (float)q[info.get_q_elem_offset(0, 0, i)] << " "; } std::cout << std::endl; @@ -239,7 +237,6 @@ single_mha(const std::vector &q, std::cout << (float)k[info.get_kv_elem_offset(i, 0, j)] << " "; } std::cout << std::endl; - // q[info.get_q_elem_offset(q_idx, qo_head_idx, feat_idx) } std::cout << std::endl; #endif @@ -333,6 +330,15 @@ single_mha(const std::vector &q, } #endif +#if Debug1 + if (qo_head_idx == 0) { + for (auto i = 0ul; i < 128; ++i) { + std::cout << denom << " "; + } + std::cout << std::endl; + } +#endif + // divide by denom for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { att[kv_idx] /= denom; From 4a06c3223e6e849381c361961730efe4ec797813 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Mon, 15 Sep 2025 13:27:22 -0400 Subject: [PATCH 079/109] Formatting --- .../include/flashinfer/attention/generic/prefill.cuh | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index 609122eb75..dcc3993e52 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -965,11 +965,10 @@ __device__ __forceinline__ void k_smem_inplace_apply_rotary( static_assert(KTraits::NUM_MMA_D_QK % (2 * 
KTraits::NUM_WARPS_Q) == 0); // horizontal axis: y // vertical axis: z - // | (warp_idx_z, warp_idx_x) | 1-16 | 16-32 | 32-48 | 48-64 - // | ... | 1-16*NUM_MMA_KV | (0, 0) | (0, 1) | (0, 2) | - // (0, 3) | ... | 16*NUM_MMA_KV-32*NUM_MMA_KV | (1, 0) | (1, 1) | (1, - // 2) | (1, 3) | ... - // ... + // | (warp_idx_z, warp_idx_x) | 1-16 | 16-32 | 32-48 | 48-64 + // | ... | 1-16*NUM_MMA_KV | (0, 0) | (0, 1) | (0, 2) | (0, 3) + // | ... | 16*NUM_MMA_KV-32*NUM_MMA_KV | (1, 0) | (1, 1) | (1, 2) | (1, 3) + // | ... ... uint32_t kv_idx = kv_idx_base + (warp_idx_z * KTraits::NUM_MMA_KV * 16) + lane_idx / THREADS_PER_BMATRIX_ROW_SET; From 9082889f85a2f5f24bf87b967a4c64218710bf93 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Tue, 16 Sep 2025 11:03:02 -0400 Subject: [PATCH 080/109] Update clang-format to match upstream. --- .clang-format | 31 +++++++------------------------ 1 file changed, 7 insertions(+), 24 deletions(-) diff --git a/.clang-format b/.clang-format index ab3229bbcc..21f4968225 100644 --- a/.clang-format +++ b/.clang-format @@ -1,24 +1,7 @@ -BasedOnStyle: LLVM -IndentWidth: 4 -AccessModifierOffset: -4 -AlignEscapedNewlines: Right -AllowAllParametersOfDeclarationOnNextLine: false -BinPackParameters: false -BreakBeforeBraces: Custom -BraceWrapping: - AfterCaseLabel: true - AfterClass: true - AfterControlStatement: MultiLine - AfterEnum: true - AfterFunction: true - AfterNamespace: true - AfterObjCDeclaration: false - AfterStruct: true - AfterUnion: true - AfterExternBlock: true - BeforeCatch: false - BeforeElse: true - IndentBraces: false - SplitEmptyFunction: true - SplitEmptyRecord: true - SplitEmptyNamespace: true +--- +BasedOnStyle: Google +DerivePointerAlignment: false +ColumnLimit: 100 +PointerAlignment: Left +# InsertNewlineAtEOF: true +... 
From 2c995368358b5d167a1611e2197ab13dee8ada4a Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Tue, 16 Sep 2025 11:04:45 -0400 Subject: [PATCH 081/109] Reformat libflashinfer --- libflashinfer/include/flashinfer/allocator.h | 68 +- .../flashinfer/attention/generic/allocator.h | 68 +- .../flashinfer/attention/generic/exception.h | 54 +- .../flashinfer/attention/generic/heap.h | 64 +- .../include/flashinfer/attention/heap.h | 62 +- libflashinfer/include/flashinfer/exception.h | 54 +- libflashinfer/include/flashinfer/fp16.h | 293 +- .../include/flashinfer/hip/activation.hip.h | 64 +- .../flashinfer/hip/attention/cascade.hip.h | 1209 +++---- .../flashinfer/hip/attention/decode.hip.h | 1991 +++++------- .../hip/attention/default_decode_params.hip.h | 458 +-- .../include/flashinfer/hip/attention/heap.h | 64 +- .../flashinfer/hip/attention/mask.hip.h | 16 +- .../flashinfer/hip/attention/scheduler.hip.h | 2647 +++++++-------- .../flashinfer/hip/attention/state.hip.h | 100 +- .../hip/attention/variant_helper.hip.h | 100 +- .../flashinfer/hip/attention/variants.hip.h | 154 +- .../flashinfer/hip/attention_impl.hip.h | 2 +- .../include/flashinfer/hip/cp_async.hip.h | 97 +- .../include/flashinfer/hip/fastdiv.hip.h | 150 +- .../include/flashinfer/hip/hip_platform.h | 19 +- .../include/flashinfer/hip/layout.hip.h | 180 +- .../include/flashinfer/hip/math.hip.h | 71 +- .../include/flashinfer/hip/norm.hip.h | 522 ++- .../include/flashinfer/hip/page.hip.h | 1226 ++++--- .../include/flashinfer/hip/pos_enc.hip.h | 1642 ++++------ .../include/flashinfer/hip/utils.hip.h | 656 ++-- .../include/flashinfer/hip/vec_dtypes.hip.h | 2886 +++++++--------- .../include/gpu_iface/backend/hip/math_hip.h | 71 +- .../gpu_iface/backend/hip/memory_ops_hip.h | 131 +- .../include/gpu_iface/backend/hip/mma_hip.h | 281 +- .../gpu_iface/backend/hip/vec_dtypes_hip.h | 2889 +++++++---------- .../include/gpu_iface/conversion_utils.h | 53 +- .../include/gpu_iface/cooperative_groups.h | 20 +- 
libflashinfer/include/gpu_iface/enums.hpp | 29 +- libflashinfer/include/gpu_iface/error.hpp | 52 +- libflashinfer/include/gpu_iface/exception.h | 54 +- libflashinfer/include/gpu_iface/fragment.hpp | 187 +- .../include/gpu_iface/gpu_runtime_compat.hpp | 64 +- libflashinfer/include/gpu_iface/math_ops.hpp | 15 +- .../include/gpu_iface/memory_ops.hpp | 76 +- libflashinfer/include/gpu_iface/mma_ops.hpp | 50 +- libflashinfer/include/gpu_iface/mma_types.hpp | 22 +- libflashinfer/include/gpu_iface/platform.hpp | 13 +- .../include/gpu_iface/vec_dtypes.hpp | 18 +- .../tests/hip/test_apply_llama_rope.cpp | 624 ++-- libflashinfer/tests/hip/test_batch_decode.cpp | 386 +-- libflashinfer/tests/hip/test_cascade.cpp | 971 +++--- libflashinfer/tests/hip/test_compute_sfm.cpp | 371 +-- .../tests/hip/test_k_smem_read_pattern.cpp | 291 +- .../tests/hip/test_load_q_global_smem.cpp | 19 +- .../tests/hip/test_load_q_global_smem_v1.cpp | 275 +- .../tests/hip/test_load_q_global_smem_v2.cpp | 462 ++- libflashinfer/tests/hip/test_math.cpp | 352 +- .../tests/hip/test_mfma_fp32_16x16x16fp16.cpp | 252 +- libflashinfer/tests/hip/test_page.cpp | 723 ++--- .../tests/hip/test_permuted_smem.cpp | 407 ++- libflashinfer/tests/hip/test_pos_enc.cpp | 646 ++-- libflashinfer/tests/hip/test_produce_kv.cpp | 352 +- .../tests/hip/test_q_smem_read_pattern.cpp | 256 +- libflashinfer/tests/hip/test_rowsum.cpp | 237 +- .../tests/hip/test_single_decode.cpp | 353 +- .../tests/hip/test_single_prefill.cpp | 289 +- .../hip/test_transpose_4x4_half_registers.cpp | 509 ++- libflashinfer/utils/conversion_utils.h | 53 +- libflashinfer/utils/cpu_reference.h | 333 +- libflashinfer/utils/cpu_reference_hip.h | 617 ++-- .../flashinfer_batch_decode_test_ops.hip.h | 361 +- .../utils/flashinfer_prefill_ops.hip.h | 185 +- libflashinfer/utils/utils.h | 329 +- libflashinfer/utils/utils_hip.h | 404 +-- 71 files changed, 12776 insertions(+), 16223 deletions(-) diff --git a/libflashinfer/include/flashinfer/allocator.h 
b/libflashinfer/include/flashinfer/allocator.h index 942906d1b9..7bd1a854ed 100644 --- a/libflashinfer/include/flashinfer/allocator.h +++ b/libflashinfer/include/flashinfer/allocator.h @@ -21,52 +21,42 @@ #include "exception.h" -namespace flashinfer -{ +namespace flashinfer { // create a function that returns T* from base pointer and offset -template T *GetPtrFromBaseOffset(void *base_ptr, int64_t offset) -{ - return reinterpret_cast(reinterpret_cast(base_ptr) + offset); +template +T* GetPtrFromBaseOffset(void* base_ptr, int64_t offset) { + return reinterpret_cast(reinterpret_cast(base_ptr) + offset); } -struct AlignedAllocator -{ - void *base_ptr; - void *cur_ptr; - size_t remaining_space; - AlignedAllocator(void *buf, size_t space) - : base_ptr(buf), cur_ptr(buf), remaining_space(space) - { - } - template - T *aligned_alloc(size_t size, size_t alignment, std::string name) - { - if (std::align(alignment, size, cur_ptr, remaining_space)) { - T *result = reinterpret_cast(cur_ptr); - cur_ptr = (char *)cur_ptr + size; - remaining_space -= size; - return result; - } - else { - std::ostringstream oss; - oss << "Failed to allocate memory for " << name << " with size " - << size << " and alignment " << alignment - << " in AlignedAllocator"; - FLASHINFER_ERROR(oss.str()); - } - return nullptr; +struct AlignedAllocator { + void* base_ptr; + void* cur_ptr; + size_t remaining_space; + AlignedAllocator(void* buf, size_t space) : base_ptr(buf), cur_ptr(buf), remaining_space(space) {} + template + T* aligned_alloc(size_t size, size_t alignment, std::string name) { + if (std::align(alignment, size, cur_ptr, remaining_space)) { + T* result = reinterpret_cast(cur_ptr); + cur_ptr = (char*)cur_ptr + size; + remaining_space -= size; + return result; + } else { + std::ostringstream oss; + oss << "Failed to allocate memory for " << name << " with size " << size << " and alignment " + << alignment << " in AlignedAllocator"; + FLASHINFER_ERROR(oss.str()); } + return nullptr; + } - size_t 
aligned_alloc_offset(size_t size, size_t alignment, std::string name) - { - return (char *)aligned_alloc(size, alignment, name) - - (char *)base_ptr; - } + size_t aligned_alloc_offset(size_t size, size_t alignment, std::string name) { + return (char*)aligned_alloc(size, alignment, name) - (char*)base_ptr; + } - size_t num_allocated_bytes() { return (char *)cur_ptr - (char *)base_ptr; } + size_t num_allocated_bytes() { return (char*)cur_ptr - (char*)base_ptr; } }; -} // namespace flashinfer +} // namespace flashinfer -#endif // FLASHINFER_ALLOCATOR_H_ +#endif // FLASHINFER_ALLOCATOR_H_ diff --git a/libflashinfer/include/flashinfer/attention/generic/allocator.h b/libflashinfer/include/flashinfer/attention/generic/allocator.h index 2efafd536a..d776d5f8da 100644 --- a/libflashinfer/include/flashinfer/attention/generic/allocator.h +++ b/libflashinfer/include/flashinfer/attention/generic/allocator.h @@ -11,52 +11,42 @@ #include "exception.h" -namespace flashinfer -{ +namespace flashinfer { // create a function that returns T* from base pointer and offset -template T *GetPtrFromBaseOffset(void *base_ptr, int64_t offset) -{ - return reinterpret_cast(reinterpret_cast(base_ptr) + offset); +template +T* GetPtrFromBaseOffset(void* base_ptr, int64_t offset) { + return reinterpret_cast(reinterpret_cast(base_ptr) + offset); } -struct AlignedAllocator -{ - void *base_ptr; - void *cur_ptr; - size_t remaining_space; - AlignedAllocator(void *buf, size_t space) - : base_ptr(buf), cur_ptr(buf), remaining_space(space) - { - } - template - T *aligned_alloc(size_t size, size_t alignment, std::string name) - { - if (std::align(alignment, size, cur_ptr, remaining_space)) { - T *result = reinterpret_cast(cur_ptr); - cur_ptr = (char *)cur_ptr + size; - remaining_space -= size; - return result; - } - else { - std::ostringstream oss; - oss << "Failed to allocate memory for " << name << " with size " - << size << " and alignment " << alignment - << " in AlignedAllocator"; - 
FLASHINFER_ERROR(oss.str()); - } - return nullptr; +struct AlignedAllocator { + void* base_ptr; + void* cur_ptr; + size_t remaining_space; + AlignedAllocator(void* buf, size_t space) : base_ptr(buf), cur_ptr(buf), remaining_space(space) {} + template + T* aligned_alloc(size_t size, size_t alignment, std::string name) { + if (std::align(alignment, size, cur_ptr, remaining_space)) { + T* result = reinterpret_cast(cur_ptr); + cur_ptr = (char*)cur_ptr + size; + remaining_space -= size; + return result; + } else { + std::ostringstream oss; + oss << "Failed to allocate memory for " << name << " with size " << size << " and alignment " + << alignment << " in AlignedAllocator"; + FLASHINFER_ERROR(oss.str()); } + return nullptr; + } - size_t aligned_alloc_offset(size_t size, size_t alignment, std::string name) - { - return (char *)aligned_alloc(size, alignment, name) - - (char *)base_ptr; - } + size_t aligned_alloc_offset(size_t size, size_t alignment, std::string name) { + return (char*)aligned_alloc(size, alignment, name) - (char*)base_ptr; + } - size_t num_allocated_bytes() { return (char *)cur_ptr - (char *)base_ptr; } + size_t num_allocated_bytes() { return (char*)cur_ptr - (char*)base_ptr; } }; -} // namespace flashinfer +} // namespace flashinfer -#endif // FLASHINFER_ALLOCATOR_H_ +#endif // FLASHINFER_ALLOCATOR_H_ diff --git a/libflashinfer/include/flashinfer/attention/generic/exception.h b/libflashinfer/include/flashinfer/attention/generic/exception.h index 8abe50b059..1f410c1edb 100644 --- a/libflashinfer/include/flashinfer/attention/generic/exception.h +++ b/libflashinfer/include/flashinfer/attention/generic/exception.h @@ -9,40 +9,30 @@ #include #include -namespace flashinfer -{ - -class Error : public std::exception -{ -private: - std::string message_; - -public: - Error(const std::string &func, - const std::string &file, - int line, - const std::string &message) - { - std::ostringstream oss; - oss << "Error in function '" << func << "' " - << "at " << file << 
":" << line << ": " << message; - message_ = oss.str(); - } - - virtual const char *what() const noexcept override - { - return message_.c_str(); - } +namespace flashinfer { + +class Error : public std::exception { + private: + std::string message_; + + public: + Error(const std::string& func, const std::string& file, int line, const std::string& message) { + std::ostringstream oss; + oss << "Error in function '" << func << "' " + << "at " << file << ":" << line << ": " << message; + message_ = oss.str(); + } + + virtual const char* what() const noexcept override { return message_.c_str(); } }; -#define FLASHINFER_ERROR(message) \ - throw Error(__FUNCTION__, __FILE__, __LINE__, message) +#define FLASHINFER_ERROR(message) throw Error(__FUNCTION__, __FILE__, __LINE__, message) -#define FLASHINFER_CHECK(condition, message) \ - if (!(condition)) { \ - FLASHINFER_ERROR(message); \ - } +#define FLASHINFER_CHECK(condition, message) \ + if (!(condition)) { \ + FLASHINFER_ERROR(message); \ + } -} // namespace flashinfer +} // namespace flashinfer -#endif // FLASHINFER_EXCEPTION_H_ +#endif // FLASHINFER_EXCEPTION_H_ diff --git a/libflashinfer/include/flashinfer/attention/generic/heap.h b/libflashinfer/include/flashinfer/attention/generic/heap.h index 312741d280..e69ac1db88 100644 --- a/libflashinfer/include/flashinfer/attention/generic/heap.h +++ b/libflashinfer/include/flashinfer/attention/generic/heap.h @@ -11,52 +11,46 @@ #include #include -namespace flashinfer -{ +namespace flashinfer { /*! 
* \brief Heap data structure for (index, value) pairs * \note minimal element on top */ -class MinHeap -{ -public: - // first: index, second: cost - using Element = std::pair; - - MinHeap(int capacity) : heap_(capacity) - { - for (int i = 0; i < capacity; ++i) { - heap_[i] = std::make_pair(i, 0.f); - } +class MinHeap { + public: + // first: index, second: cost + using Element = std::pair; + + MinHeap(int capacity) : heap_(capacity) { + for (int i = 0; i < capacity; ++i) { + heap_[i] = std::make_pair(i, 0.f); } + } - void insert(const Element &element) - { - heap_.push_back(element); - std::push_heap(heap_.begin(), heap_.end(), compare); - } + void insert(const Element& element) { + heap_.push_back(element); + std::push_heap(heap_.begin(), heap_.end(), compare); + } - Element pop() - { - std::pop_heap(heap_.begin(), heap_.end(), compare); - Element minElement = heap_.back(); - heap_.pop_back(); - return minElement; - } + Element pop() { + std::pop_heap(heap_.begin(), heap_.end(), compare); + Element minElement = heap_.back(); + heap_.pop_back(); + return minElement; + } - std::vector getHeap() const { return heap_; } + std::vector getHeap() const { return heap_; } -private: - // Custom comparator for the min-heap: compare based on 'val' in the pair - static bool compare(const Element &a, const Element &b) - { - return a.second > b.second; // create a min-heap based on val - } + private: + // Custom comparator for the min-heap: compare based on 'val' in the pair + static bool compare(const Element& a, const Element& b) { + return a.second > b.second; // create a min-heap based on val + } - std::vector heap_; + std::vector heap_; }; -} // namespace flashinfer +} // namespace flashinfer -#endif // FLASHINFER_ATTENTION_HEAP_H +#endif // FLASHINFER_ATTENTION_HEAP_H diff --git a/libflashinfer/include/flashinfer/attention/heap.h b/libflashinfer/include/flashinfer/attention/heap.h index 18bf9366f5..484669fbe2 100644 --- a/libflashinfer/include/flashinfer/attention/heap.h 
+++ b/libflashinfer/include/flashinfer/attention/heap.h @@ -21,52 +21,46 @@ #include #include -namespace flashinfer -{ +namespace flashinfer { /*! * \brief Heap data structure for (index, value) pairs * \note minimal element on top */ -class MinHeap -{ -public: - // first: index, second: cost - using Element = std::pair; +class MinHeap { + public: + // first: index, second: cost + using Element = std::pair; - MinHeap(int capacity) : heap_(capacity) - { - for (int i = 0; i < capacity; ++i) { - heap_[i] = std::make_pair(i, 0.f); - } + MinHeap(int capacity) : heap_(capacity) { + for (int i = 0; i < capacity; ++i) { + heap_[i] = std::make_pair(i, 0.f); } + } - void insert(const Element &element) - { - heap_.push_back(element); - std::push_heap(heap_.begin(), heap_.end(), compare); - } + void insert(const Element& element) { + heap_.push_back(element); + std::push_heap(heap_.begin(), heap_.end(), compare); + } - Element pop() - { - std::pop_heap(heap_.begin(), heap_.end(), compare); - Element minElement = heap_.back(); - heap_.pop_back(); - return minElement; - } + Element pop() { + std::pop_heap(heap_.begin(), heap_.end(), compare); + Element minElement = heap_.back(); + heap_.pop_back(); + return minElement; + } - std::vector getHeap() const { return heap_; } + std::vector getHeap() const { return heap_; } -private: - // Custom comparator for the min-heap: compare based on 'val' in the pair - static bool compare(const Element &a, const Element &b) - { - return a.second > b.second; // create a min-heap based on val - } + private: + // Custom comparator for the min-heap: compare based on 'val' in the pair + static bool compare(const Element& a, const Element& b) { + return a.second > b.second; // create a min-heap based on val + } - std::vector heap_; + std::vector heap_; }; -} // namespace flashinfer +} // namespace flashinfer -#endif // FLASHINFER_ATTENTION_HEAP_H +#endif // FLASHINFER_ATTENTION_HEAP_H diff --git a/libflashinfer/include/flashinfer/exception.h 
b/libflashinfer/include/flashinfer/exception.h index 2630f5e44b..9d4f9d7832 100644 --- a/libflashinfer/include/flashinfer/exception.h +++ b/libflashinfer/include/flashinfer/exception.h @@ -19,40 +19,30 @@ #include #include -namespace flashinfer -{ - -class Error : public std::exception -{ -private: - std::string message_; - -public: - Error(const std::string &func, - const std::string &file, - int line, - const std::string &message) - { - std::ostringstream oss; - oss << "Error in function '" << func << "' " - << "at " << file << ":" << line << ": " << message; - message_ = oss.str(); - } - - virtual const char *what() const noexcept override - { - return message_.c_str(); - } +namespace flashinfer { + +class Error : public std::exception { + private: + std::string message_; + + public: + Error(const std::string& func, const std::string& file, int line, const std::string& message) { + std::ostringstream oss; + oss << "Error in function '" << func << "' " + << "at " << file << ":" << line << ": " << message; + message_ = oss.str(); + } + + virtual const char* what() const noexcept override { return message_.c_str(); } }; -#define FLASHINFER_ERROR(message) \ - throw Error(__FUNCTION__, __FILE__, __LINE__, message) +#define FLASHINFER_ERROR(message) throw Error(__FUNCTION__, __FILE__, __LINE__, message) -#define FLASHINFER_CHECK(condition, message) \ - if (!(condition)) { \ - FLASHINFER_ERROR(message); \ - } +#define FLASHINFER_CHECK(condition, message) \ + if (!(condition)) { \ + FLASHINFER_ERROR(message); \ + } -} // namespace flashinfer +} // namespace flashinfer -#endif // FLASHINFER_EXCEPTION_H_ +#endif // FLASHINFER_EXCEPTION_H_ diff --git a/libflashinfer/include/flashinfer/fp16.h b/libflashinfer/include/flashinfer/fp16.h index c36deab5a2..d45d078172 100644 --- a/libflashinfer/include/flashinfer/fp16.h +++ b/libflashinfer/include/flashinfer/fp16.h @@ -22,163 +22,156 @@ * mode and no operations on denormals) floating-point operations and bitcasts * between 
integer and floating-point variables. */ -static constexpr uint16_t fp16_ieee_from_fp32_value(float f) -{ - const float scale_to_inf = std::bit_cast(UINT32_C(0x77800000)); - const float scale_to_zero = std::bit_cast(UINT32_C(0x08800000)); - const float saturated_f = - boost::math::ccmath::fabs(f) * scale_to_inf; +static constexpr uint16_t fp16_ieee_from_fp32_value(float f) { + const float scale_to_inf = std::bit_cast(UINT32_C(0x77800000)); + const float scale_to_zero = std::bit_cast(UINT32_C(0x08800000)); + const float saturated_f = boost::math::ccmath::fabs(f) * scale_to_inf; - float base = saturated_f * scale_to_zero; + float base = saturated_f * scale_to_zero; - const uint32_t w = std::bit_cast(f); - const uint32_t shl1_w = w + w; - const uint32_t sign = w & UINT32_C(0x80000000); - uint32_t bias = shl1_w & UINT32_C(0xFF000000); - if (bias < UINT32_C(0x71000000)) { - bias = UINT32_C(0x71000000); - } + const uint32_t w = std::bit_cast(f); + const uint32_t shl1_w = w + w; + const uint32_t sign = w & UINT32_C(0x80000000); + uint32_t bias = shl1_w & UINT32_C(0xFF000000); + if (bias < UINT32_C(0x71000000)) { + bias = UINT32_C(0x71000000); + } - base = std::bit_cast((bias >> 1) + UINT32_C(0x07800000)) + base; - const uint32_t bits = std::bit_cast(base); - const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00); - const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF); - const uint32_t nonsign = exp_bits + mantissa_bits; - return (sign >> 16) | - (shl1_w > UINT32_C(0xFF000000) ? UINT16_C(0x7E00) : nonsign); + base = std::bit_cast((bias >> 1) + UINT32_C(0x07800000)) + base; + const uint32_t bits = std::bit_cast(base); + const uint32_t exp_bits = (bits >> 13) & UINT32_C(0x00007C00); + const uint32_t mantissa_bits = bits & UINT32_C(0x00000FFF); + const uint32_t nonsign = exp_bits + mantissa_bits; + return (sign >> 16) | (shl1_w > UINT32_C(0xFF000000) ? 
UINT16_C(0x7E00) : nonsign); } -static constexpr float fp16_ieee_to_fp32_value(uint16_t h) -{ - /* - * Extend the half-precision floating-point number to 32 bits and shift to - * the upper part of the 32-bit word: - * +---+-----+------------+-------------------+ - * | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| - * +---+-----+------------+-------------------+ - * Bits 31 26-30 16-25 0-15 - * - * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, - * 0 - zero bits. - */ - const uint32_t w = (uint32_t)h << 16; - /* - * Extract the sign of the input number into the high bit of the 32-bit - * word: - * - * +---+----------------------------------+ - * | S |0000000 00000000 00000000 00000000| - * +---+----------------------------------+ - * Bits 31 0-31 - */ - const uint32_t sign = w & UINT32_C(0x80000000); - /* - * Extract mantissa and biased exponent of the input number into the high - * bits of the 32-bit word: - * - * +-----+------------+---------------------+ - * |EEEEE|MM MMMM MMMM|0 0000 0000 0000 0000| - * +-----+------------+---------------------+ - * Bits 27-31 17-26 0-16 - */ - const uint32_t two_w = w + w; +static constexpr float fp16_ieee_to_fp32_value(uint16_t h) { + /* + * Extend the half-precision floating-point number to 32 bits and shift to + * the upper part of the 32-bit word: + * +---+-----+------------+-------------------+ + * | S |EEEEE|MM MMMM MMMM|0000 0000 0000 0000| + * +---+-----+------------+-------------------+ + * Bits 31 26-30 16-25 0-15 + * + * S - sign bit, E - bits of the biased exponent, M - bits of the mantissa, + * 0 - zero bits. 
+ */ + const uint32_t w = (uint32_t)h << 16; + /* + * Extract the sign of the input number into the high bit of the 32-bit + * word: + * + * +---+----------------------------------+ + * | S |0000000 00000000 00000000 00000000| + * +---+----------------------------------+ + * Bits 31 0-31 + */ + const uint32_t sign = w & UINT32_C(0x80000000); + /* + * Extract mantissa and biased exponent of the input number into the high + * bits of the 32-bit word: + * + * +-----+------------+---------------------+ + * |EEEEE|MM MMMM MMMM|0 0000 0000 0000 0000| + * +-----+------------+---------------------+ + * Bits 27-31 17-26 0-16 + */ + const uint32_t two_w = w + w; - /* - * Shift mantissa and exponent into bits 23-28 and bits 13-22 so they become - * mantissa and exponent of a single-precision floating-point number: - * - * S|Exponent | Mantissa - * +-+---+-----+------------+----------------+ - * |0|000|EEEEE|MM MMMM MMMM|0 0000 0000 0000| - * +-+---+-----+------------+----------------+ - * Bits | 23-31 | 0-22 - * - * Next, there are some adjustments to the exponent: - * - The exponent needs to be corrected by the difference in exponent bias - * between single-precision and half-precision - * formats (0x7F - 0xF = 0x70) - * - Inf and NaN values in the inputs should become Inf and NaN values after - * conversion to the single-precision number. - * Therefore, if the biased exponent of the half-precision input was 0x1F - * (max possible value), the biased exponent - * of the single-precision output must be 0xFF (max possible value). We do - * this correction in two steps: - * - First, we adjust the exponent by (0xFF - 0x1F) = 0xE0 (see exp_offset - * below) rather than by 0x70 suggested - * by the difference in the exponent bias (see above). - * - Then we multiply the single-precision result of exponent adjustment - * by 2**(-112) to reverse the effect of - * exponent adjustment by 0xE0 less the necessary exponent adjustment by - * 0x70 due to difference in exponent bias. 
- * The floating-point multiplication hardware would ensure than Inf and - * NaN would retain their value on at least - * partially IEEE754-compliant implementations. - * - * Note that the above operations do not handle denormal inputs (where - * biased exponent == 0). However, they also do not operate on denormal - * inputs, and do not produce denormal results. - */ - const uint32_t exp_offset = UINT32_C(0xE0) << 23; - const float exp_scale = std::bit_cast(UINT32_C(0x7800000)); - const float normalized_value = - std::bit_cast((two_w >> 4) + exp_offset) * exp_scale; + /* + * Shift mantissa and exponent into bits 23-28 and bits 13-22 so they become + * mantissa and exponent of a single-precision floating-point number: + * + * S|Exponent | Mantissa + * +-+---+-----+------------+----------------+ + * |0|000|EEEEE|MM MMMM MMMM|0 0000 0000 0000| + * +-+---+-----+------------+----------------+ + * Bits | 23-31 | 0-22 + * + * Next, there are some adjustments to the exponent: + * - The exponent needs to be corrected by the difference in exponent bias + * between single-precision and half-precision + * formats (0x7F - 0xF = 0x70) + * - Inf and NaN values in the inputs should become Inf and NaN values after + * conversion to the single-precision number. + * Therefore, if the biased exponent of the half-precision input was 0x1F + * (max possible value), the biased exponent + * of the single-precision output must be 0xFF (max possible value). We do + * this correction in two steps: + * - First, we adjust the exponent by (0xFF - 0x1F) = 0xE0 (see exp_offset + * below) rather than by 0x70 suggested + * by the difference in the exponent bias (see above). + * - Then we multiply the single-precision result of exponent adjustment + * by 2**(-112) to reverse the effect of + * exponent adjustment by 0xE0 less the necessary exponent adjustment by + * 0x70 due to difference in exponent bias. 
+ * The floating-point multiplication hardware would ensure than Inf and + * NaN would retain their value on at least + * partially IEEE754-compliant implementations. + * + * Note that the above operations do not handle denormal inputs (where + * biased exponent == 0). However, they also do not operate on denormal + * inputs, and do not produce denormal results. + */ + const uint32_t exp_offset = UINT32_C(0xE0) << 23; + const float exp_scale = std::bit_cast(UINT32_C(0x7800000)); + const float normalized_value = std::bit_cast((two_w >> 4) + exp_offset) * exp_scale; - /* - * Convert denormalized half-precision inputs into single-precision results - * (always normalized). - * Zero inputs are also handled here. - * - * In a denormalized number the biased exponent is zero, and mantissa has - * on-zero bits. - * First, we shift mantissa into bits 0-9 of the 32-bit word. - * - * zeros | mantissa - * +---------------------------+------------+ - * |0000 0000 0000 0000 0000 00|MM MMMM MMMM| - * +---------------------------+------------+ - * Bits 10-31 0-9 - * - * Now, remember that denormalized half-precision numbers are represented - * as: - * FP16 = mantissa * 2**(-24). - * The trick is to construct a normalized single-precision number with the - * same mantissa and thehalf-precision input - * and with an exponent which would scale the corresponding mantissa bits - * to 2**(-24). - * A normalized single-precision floating-point number is represented as: - * FP32 = (1 + mantissa * 2**(-23)) * 2**(exponent - 127) - * Therefore, when the biased exponent is 126, a unit change in the mantissa - * of the input denormalized half-precision - * number causes a change of the constructud single-precision number by - * 2**(-24), i.e. the same ammount. - * - * The last step is to adjust the bias of the constructed single-precision - * number. 
When the input half-precision number - * is zero, the constructed single-precision number has the value of - * FP32 = 1 * 2**(126 - 127) = 2**(-1) = 0.5 - * Therefore, we need to subtract 0.5 from the constructed single-precision - * number to get the numerical equivalent of - * the input half-precision number. - */ - const uint32_t magic_mask = UINT32_C(126) << 23; - const float magic_bias = 0.5f; - const float denormalized_value = - std::bit_cast((two_w >> 17) | magic_mask) - magic_bias; + /* + * Convert denormalized half-precision inputs into single-precision results + * (always normalized). + * Zero inputs are also handled here. + * + * In a denormalized number the biased exponent is zero, and mantissa has + * on-zero bits. + * First, we shift mantissa into bits 0-9 of the 32-bit word. + * + * zeros | mantissa + * +---------------------------+------------+ + * |0000 0000 0000 0000 0000 00|MM MMMM MMMM| + * +---------------------------+------------+ + * Bits 10-31 0-9 + * + * Now, remember that denormalized half-precision numbers are represented + * as: + * FP16 = mantissa * 2**(-24). + * The trick is to construct a normalized single-precision number with the + * same mantissa and thehalf-precision input + * and with an exponent which would scale the corresponding mantissa bits + * to 2**(-24). + * A normalized single-precision floating-point number is represented as: + * FP32 = (1 + mantissa * 2**(-23)) * 2**(exponent - 127) + * Therefore, when the biased exponent is 126, a unit change in the mantissa + * of the input denormalized half-precision + * number causes a change of the constructud single-precision number by + * 2**(-24), i.e. the same ammount. + * + * The last step is to adjust the bias of the constructed single-precision + * number. 
When the input half-precision number + * is zero, the constructed single-precision number has the value of + * FP32 = 1 * 2**(126 - 127) = 2**(-1) = 0.5 + * Therefore, we need to subtract 0.5 from the constructed single-precision + * number to get the numerical equivalent of + * the input half-precision number. + */ + const uint32_t magic_mask = UINT32_C(126) << 23; + const float magic_bias = 0.5f; + const float denormalized_value = std::bit_cast((two_w >> 17) | magic_mask) - magic_bias; - /* - * - Choose either results of conversion of input as a normalized number, or - * as a denormalized number, depending on the - * input exponent. The variable two_w contains input exponent in bits - * 27-31, therefore if its smaller than 2**27, the - * input is either a denormal number, or zero. - * - Combine the result of conversion of exponent and mantissa with the sign - * of the input number. - */ - const uint32_t denormalized_cutoff = UINT32_C(1) << 27; - const uint32_t result = - sign | (two_w < denormalized_cutoff - ? std::bit_cast(denormalized_value) - : std::bit_cast(normalized_value)); - return std::bit_cast(result); + /* + * - Choose either results of conversion of input as a normalized number, or + * as a denormalized number, depending on the + * input exponent. The variable two_w contains input exponent in bits + * 27-31, therefore if its smaller than 2**27, the + * input is either a denormal number, or zero. + * - Combine the result of conversion of exponent and mantissa with the sign + * of the input number. + */ + const uint32_t denormalized_cutoff = UINT32_C(1) << 27; + const uint32_t result = + sign | (two_w < denormalized_cutoff ? 
std::bit_cast(denormalized_value) + : std::bit_cast(normalized_value)); + return std::bit_cast(result); #endif } diff --git a/libflashinfer/include/flashinfer/hip/activation.hip.h b/libflashinfer/include/flashinfer/hip/activation.hip.h index 4b1e988f06..96edc98c03 100644 --- a/libflashinfer/include/flashinfer/hip/activation.hip.h +++ b/libflashinfer/include/flashinfer/hip/activation.hip.h @@ -11,47 +11,41 @@ #include "utils.hip.h" #include "vec_dtypes.hip.h" -namespace flashinfer -{ - -namespace activation -{ - -template -__global__ void act_and_mul_kernel(T *__restrict__ out, - const T *__restrict__ input, - const int d) -{ - constexpr uint32_t vec_size = 16 / sizeof(T); - const int64_t token_idx = blockIdx.x; - const int64_t thread_idx = threadIdx.x; - const int64_t stride = blockDim.x; - const int64_t offset = token_idx * 2 * d; +namespace flashinfer { + +namespace activation { + +template +__global__ void act_and_mul_kernel(T* __restrict__ out, const T* __restrict__ input, const int d) { + constexpr uint32_t vec_size = 16 / sizeof(T); + const int64_t token_idx = blockIdx.x; + const int64_t thread_idx = threadIdx.x; + const int64_t stride = blockDim.x; + const int64_t offset = token_idx * 2 * d; #pragma unroll 1 - for (uint32_t idx = thread_idx; idx < d / vec_size; idx += stride) { - vec_t x_vec, y_vec, out_vec; - x_vec.cast_load(input + offset + idx * vec_size); - y_vec.cast_load(input + offset + d + idx * vec_size); + for (uint32_t idx = thread_idx; idx < d / vec_size; idx += stride) { + vec_t x_vec, y_vec, out_vec; + x_vec.cast_load(input + offset + idx * vec_size); + y_vec.cast_load(input + offset + d + idx * vec_size); #pragma unroll - for (uint32_t i = 0; i < vec_size; ++i) { - out_vec[i] = Activation(x_vec[i]) * y_vec[i]; - } - out_vec.cast_store(out + token_idx * d + idx * vec_size); + for (uint32_t i = 0; i < vec_size; ++i) { + out_vec[i] = Activation(x_vec[i]) * y_vec[i]; } + out_vec.cast_store(out + token_idx * d + idx * vec_size); + } - const int64_t 
remaining_offset = d - d % (stride * vec_size); - // process the remaining elements + const int64_t remaining_offset = d - d % (stride * vec_size); + // process the remaining elements #pragma unroll 1 - for (int64_t idx = thread_idx; idx < d % (stride * vec_size); idx += stride) - { - float x = input[offset + remaining_offset + idx], - y = input[offset + remaining_offset + d + idx]; - out[token_idx * d + remaining_offset + idx] = Activation(x) * y; - } + for (int64_t idx = thread_idx; idx < d % (stride * vec_size); idx += stride) { + float x = input[offset + remaining_offset + idx], + y = input[offset + remaining_offset + d + idx]; + out[token_idx * d + remaining_offset + idx] = Activation(x) * y; + } } -} // namespace activation -} // namespace flashinfer +} // namespace activation +} // namespace flashinfer -#endif // FLASHINFER_ACTIVATION_CUH_ +#endif // FLASHINFER_ACTIVATION_CUH_ diff --git a/libflashinfer/include/flashinfer/hip/attention/cascade.hip.h b/libflashinfer/include/flashinfer/hip/attention/cascade.hip.h index 621c191ff9..e74983290b 100644 --- a/libflashinfer/include/flashinfer/hip/attention/cascade.hip.h +++ b/libflashinfer/include/flashinfer/hip/attention/cascade.hip.h @@ -17,65 +17,51 @@ #include #include -namespace fi::con -{ +namespace fi::con { template -__host__ __device__ __inline__ DTypeOut explicit_casting(DTypeIn value) -{ - return DTypeOut(value); +__host__ __device__ __inline__ DTypeOut explicit_casting(DTypeIn value) { + return DTypeOut(value); } template <> -__host__ __device__ __inline__ float -explicit_casting<__half, float>(__half value) -{ - return __half2float(value); +__host__ __device__ __inline__ float explicit_casting<__half, float>(__half value) { + return __half2float(value); } template <> -__host__ __device__ __inline__ float -explicit_casting<__hip_bfloat16, float>(__hip_bfloat16 value) -{ - return __bfloat162float(value); +__host__ __device__ __inline__ float explicit_casting<__hip_bfloat16, float>(__hip_bfloat16 value) { 
+ return __bfloat162float(value); } template <> -__host__ __device__ __inline__ __half -explicit_casting(float value) -{ - return __float2half(value); +__host__ __device__ __inline__ __half explicit_casting(float value) { + return __float2half(value); } template <> -__host__ __device__ __inline__ __hip_bfloat16 -explicit_casting<__half, __hip_bfloat16>(__half value) -{ - return __float2bfloat16(__half2float(value)); +__host__ __device__ __inline__ __hip_bfloat16 explicit_casting<__half, __hip_bfloat16>( + __half value) { + return __float2bfloat16(__half2float(value)); } template <> -__host__ __device__ __inline__ float explicit_casting(float value) -{ - return value; +__host__ __device__ __inline__ float explicit_casting(float value) { + return value; } template <> -__host__ __device__ __inline__ __half -explicit_casting<__half, __half>(__half value) -{ - return value; +__host__ __device__ __inline__ __half explicit_casting<__half, __half>(__half value) { + return value; } template <> -__host__ __device__ __inline__ __hip_bfloat16 -explicit_casting<__hip_bfloat16, __hip_bfloat16>(__hip_bfloat16 value) -{ - return value; +__host__ __device__ __inline__ __hip_bfloat16 explicit_casting<__hip_bfloat16, __hip_bfloat16>( + __hip_bfloat16 value) { + return value; } -} // namespace fi::con +} // namespace fi::con -namespace flashinfer -{ +namespace flashinfer { using cp_async::PrefetchMode; using cp_async::SharedMemFillMode; @@ -95,41 +81,32 @@ using cp_async::SharedMemFillMode; ///@param head_dim The dimension of each head. ///@note Both s_a and s_b are logsumexp values with base 2. 
template -__global__ void MergeStateKernel(DTypeIn *__restrict__ v_a, - float *__restrict__ s_a, - DTypeIn *__restrict__ v_b, - float *__restrict__ s_b, - DTypeO *__restrict__ v_merged, - float *__restrict__ s_merged, - uint32_t num_heads, - uint32_t head_dim) -{ - uint32_t tx = threadIdx.x, ty = threadIdx.y; - uint32_t pos = blockIdx.x; - uint32_t head_idx = ty; - - float s_a_val = s_a[pos * num_heads + head_idx]; - float s_b_val = s_b[pos * num_heads + head_idx]; - float s_max = max(s_a_val, s_b_val); - s_a_val = math::ptx_exp2(s_a_val - s_max); - s_b_val = math::ptx_exp2(s_b_val - s_max); - float a_scale = s_a_val / (s_a_val + s_b_val); - float b_scale = s_b_val / (s_a_val + s_b_val); - vec_t v_a_vec, v_b_vec, v_merged_vec; - v_a_vec.cast_load(v_a + (pos * num_heads + head_idx) * head_dim + - tx * vec_size); - v_b_vec.cast_load(v_b + (pos * num_heads + head_idx) * head_dim + - tx * vec_size); +__global__ void MergeStateKernel(DTypeIn* __restrict__ v_a, float* __restrict__ s_a, + DTypeIn* __restrict__ v_b, float* __restrict__ s_b, + DTypeO* __restrict__ v_merged, float* __restrict__ s_merged, + uint32_t num_heads, uint32_t head_dim) { + uint32_t tx = threadIdx.x, ty = threadIdx.y; + uint32_t pos = blockIdx.x; + uint32_t head_idx = ty; + + float s_a_val = s_a[pos * num_heads + head_idx]; + float s_b_val = s_b[pos * num_heads + head_idx]; + float s_max = max(s_a_val, s_b_val); + s_a_val = math::ptx_exp2(s_a_val - s_max); + s_b_val = math::ptx_exp2(s_b_val - s_max); + float a_scale = s_a_val / (s_a_val + s_b_val); + float b_scale = s_b_val / (s_a_val + s_b_val); + vec_t v_a_vec, v_b_vec, v_merged_vec; + v_a_vec.cast_load(v_a + (pos * num_heads + head_idx) * head_dim + tx * vec_size); + v_b_vec.cast_load(v_b + (pos * num_heads + head_idx) * head_dim + tx * vec_size); #pragma unroll - for (uint32_t i = 0; i < vec_size; ++i) { - v_merged_vec[i] = a_scale * v_a_vec[i] + b_scale * v_b_vec[i]; - } - v_merged_vec.cast_store(v_merged + (pos * num_heads + head_idx) * 
head_dim + - tx * vec_size); - if (s_merged != nullptr) { - s_merged[pos * num_heads + head_idx] = - math::ptx_log2(s_a_val + s_b_val) + s_max; - } + for (uint32_t i = 0; i < vec_size; ++i) { + v_merged_vec[i] = a_scale * v_a_vec[i] + b_scale * v_b_vec[i]; + } + v_merged_vec.cast_store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size); + if (s_merged != nullptr) { + s_merged[pos * num_heads + head_idx] = math::ptx_log2(s_a_val + s_b_val) + s_max; + } } ///@brief The CUDA kernel that merges the self-attention state with another @@ -145,189 +122,155 @@ __global__ void MergeStateKernel(DTypeIn *__restrict__ v_a, ///@param head_dim The dimension of each head. ///@note Both s and s_other are logsumexp values with base 2. template -__global__ void MergeStateInPlaceKernel(DType *__restrict__ v, - float *__restrict__ s, - DType *__restrict__ v_other, - float *__restrict__ s_other, - uint8_t *__restrict__ mask, - uint32_t num_heads, - uint32_t head_dim) -{ - uint32_t pos = blockIdx.x; - - if (mask != nullptr && mask[pos] == 0) - return; - - uint32_t tx = threadIdx.x, ty = threadIdx.y; - uint32_t head_idx = ty; - - float s_val = s[pos * num_heads + head_idx]; - float s_other_val = s_other[pos * num_heads + head_idx]; - float s_max = max(s_val, s_other_val); - s_val = math::ptx_exp2(s_val - s_max); - s_other_val = math::ptx_exp2(s_other_val - s_max); - float scale = s_val / (s_val + s_other_val); - float other_scale = s_other_val / (s_val + s_other_val); - vec_t v_vec, v_other_vec; - v_vec.cast_load(v + (pos * num_heads + head_idx) * head_dim + - tx * vec_size); - v_other_vec.cast_load(v_other + (pos * num_heads + head_idx) * head_dim + - tx * vec_size); +__global__ void MergeStateInPlaceKernel(DType* __restrict__ v, float* __restrict__ s, + DType* __restrict__ v_other, float* __restrict__ s_other, + uint8_t* __restrict__ mask, uint32_t num_heads, + uint32_t head_dim) { + uint32_t pos = blockIdx.x; + + if (mask != nullptr && mask[pos] == 0) return; + + 
uint32_t tx = threadIdx.x, ty = threadIdx.y; + uint32_t head_idx = ty; + + float s_val = s[pos * num_heads + head_idx]; + float s_other_val = s_other[pos * num_heads + head_idx]; + float s_max = max(s_val, s_other_val); + s_val = math::ptx_exp2(s_val - s_max); + s_other_val = math::ptx_exp2(s_other_val - s_max); + float scale = s_val / (s_val + s_other_val); + float other_scale = s_other_val / (s_val + s_other_val); + vec_t v_vec, v_other_vec; + v_vec.cast_load(v + (pos * num_heads + head_idx) * head_dim + tx * vec_size); + v_other_vec.cast_load(v_other + (pos * num_heads + head_idx) * head_dim + tx * vec_size); #pragma unroll - for (uint32_t i = 0; i < vec_size; ++i) { - v_vec[i] = scale * v_vec[i] + other_scale * v_other_vec[i]; - } - v_vec.cast_store(v + (pos * num_heads + head_idx) * head_dim + - tx * vec_size); - if (s != nullptr) { - s[pos * num_heads + head_idx] = - math::ptx_log2(s_val + s_other_val) + s_max; - } + for (uint32_t i = 0; i < vec_size; ++i) { + v_vec[i] = scale * v_vec[i] + other_scale * v_other_vec[i]; + } + v_vec.cast_store(v + (pos * num_heads + head_idx) * head_dim + tx * vec_size); + if (s != nullptr) { + s[pos * num_heads + head_idx] = math::ptx_log2(s_val + s_other_val) + s_max; + } } template -__device__ __forceinline__ void -threadblock_sync_state(state_t &st, DTypeIn *v_smem, float *s_smem) -{ - const uint32_t tx = threadIdx.x, ty = threadIdx.y; - constexpr uint32_t head_dim = vec_size * bdx; - st.o.cast_store(v_smem + ty * head_dim + tx * vec_size); - s_smem[ty] = st.get_lse(); - st.init(); - __syncthreads(); +__device__ __forceinline__ void threadblock_sync_state(state_t& st, DTypeIn* v_smem, + float* s_smem) { + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + constexpr uint32_t head_dim = vec_size * bdx; + st.o.cast_store(v_smem + ty * head_dim + tx * vec_size); + s_smem[ty] = st.get_lse(); + st.init(); + __syncthreads(); #pragma unroll - for (uint32_t iter = 0; iter < bdy; ++iter) { - float s = s_smem[iter]; - vec_t v; - 
v.cast_load(v_smem + iter * head_dim + tx * vec_size); - st.merge(v, s, 1); - } + for (uint32_t iter = 0; iter < bdy; ++iter) { + float s = s_smem[iter]; + vec_t v; + v.cast_load(v_smem + iter * head_dim + tx * vec_size); + st.merge(v, s, 1); + } } template -__device__ __forceinline__ void threadblock_sum(vec_t &v, - DTypeIn *v_smem) -{ - const uint32_t tx = threadIdx.x, ty = threadIdx.y; - constexpr uint32_t head_dim = vec_size * bdx; - v.cast_store(v_smem + ty * head_dim + tx * vec_size); - v.fill(DTypeIn(0.f)); - __syncthreads(); +__device__ __forceinline__ void threadblock_sum(vec_t& v, DTypeIn* v_smem) { + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + constexpr uint32_t head_dim = vec_size * bdx; + v.cast_store(v_smem + ty * head_dim + tx * vec_size); + v.fill(DTypeIn(0.f)); + __syncthreads(); #pragma unroll - for (uint32_t iter = 0; iter < bdy; ++iter) { - vec_t v_iter; - v_iter.cast_load(v_smem + iter * head_dim + tx * vec_size); + for (uint32_t iter = 0; iter < bdy; ++iter) { + vec_t v_iter; + v_iter.cast_load(v_smem + iter * head_dim + tx * vec_size); #pragma unroll - for (uint32_t i = 0; i < vec_size; ++i) { - v[i] += v_iter[i]; - } + for (uint32_t i = 0; i < vec_size; ++i) { + v[i] += v_iter[i]; } + } } template -__global__ void AttentionSumKernel(DTypeIn *__restrict__ V, - DTypeO *__restrict__ v_sum, - uint32_t num_index_sets, - uint32_t num_heads, - uint32_t head_dim) -{ - uint32_t tx = threadIdx.x, ty = threadIdx.y; - uint32_t pos = blockIdx.x; - uint32_t head_idx = ty; - - if (num_index_sets == 0) { - vec_t v; - v.fill(DTypeO(0.f)); - v.store(v_sum + (pos * num_heads + head_idx) * head_dim + - tx * vec_size); - return; - } - - if (num_index_sets == 1) { - vec_t v; - v.cast_load(V + (pos * num_heads + head_idx) * head_dim + - tx * vec_size); - v.store(v_sum + (pos * num_heads + head_idx) * head_dim + - tx * vec_size); - return; - } - - vec_t v_sum_vec; - v_sum_vec.fill(0.f); +__global__ void AttentionSumKernel(DTypeIn* __restrict__ V, DTypeO* 
__restrict__ v_sum, + uint32_t num_index_sets, uint32_t num_heads, uint32_t head_dim) { + uint32_t tx = threadIdx.x, ty = threadIdx.y; + uint32_t pos = blockIdx.x; + uint32_t head_idx = ty; + + if (num_index_sets == 0) { + vec_t v; + v.fill(DTypeO(0.f)); + v.store(v_sum + (pos * num_heads + head_idx) * head_dim + tx * vec_size); + return; + } + + if (num_index_sets == 1) { + vec_t v; + v.cast_load(V + (pos * num_heads + head_idx) * head_dim + tx * vec_size); + v.store(v_sum + (pos * num_heads + head_idx) * head_dim + tx * vec_size); + return; + } + + vec_t v_sum_vec; + v_sum_vec.fill(0.f); #pragma unroll 2 - for (uint32_t iter = 0; iter < num_index_sets; ++iter) { - vec_t v; - v.cast_load(V + - ((pos * num_index_sets + iter) * num_heads + head_idx) * - head_dim + - tx * vec_size); + for (uint32_t iter = 0; iter < num_index_sets; ++iter) { + vec_t v; + v.cast_load(V + ((pos * num_index_sets + iter) * num_heads + head_idx) * head_dim + + tx * vec_size); #pragma unroll - for (uint32_t i = 0; i < vec_size; ++i) { - v_sum_vec[i] += v[i]; - } + for (uint32_t i = 0; i < vec_size; ++i) { + v_sum_vec[i] += v[i]; } + } - v_sum_vec.cast_store(v_sum + (pos * num_heads + head_idx) * head_dim + - tx * vec_size); + v_sum_vec.cast_store(v_sum + (pos * num_heads + head_idx) * head_dim + tx * vec_size); } template -__global__ void MergeStatesKernel(DTypeIn *__restrict__ V, - float *__restrict__ S, - DTypeO *__restrict__ v_merged, - float *__restrict__ s_merged, - uint32_t num_index_sets, - uint32_t num_heads, - uint32_t head_dim) -{ - uint32_t tx = threadIdx.x, ty = threadIdx.y; - uint32_t pos = blockIdx.x; - uint32_t head_idx = ty; - - if (num_index_sets == 0) { - vec_t v; - v.fill(fi::con::explicit_casting(0.0f)); - v.store(v_merged + (pos * num_heads + head_idx) * head_dim + - tx * vec_size); - if (s_merged != nullptr) { - s_merged[pos * num_heads + head_idx] = -math::inf; - } - return; +__global__ void MergeStatesKernel(DTypeIn* __restrict__ V, float* __restrict__ S, + DTypeO* 
__restrict__ v_merged, float* __restrict__ s_merged, + uint32_t num_index_sets, uint32_t num_heads, uint32_t head_dim) { + uint32_t tx = threadIdx.x, ty = threadIdx.y; + uint32_t pos = blockIdx.x; + uint32_t head_idx = ty; + + if (num_index_sets == 0) { + vec_t v; + v.fill(fi::con::explicit_casting(0.0f)); + v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size); + if (s_merged != nullptr) { + s_merged[pos * num_heads + head_idx] = -math::inf; } + return; + } - if (num_index_sets == 1) { - vec_t v; - v.cast_load(V + (pos * num_heads + head_idx) * head_dim + - tx * vec_size); - v.store(v_merged + (pos * num_heads + head_idx) * head_dim + - tx * vec_size); - if (s_merged != nullptr) { - s_merged[pos * num_heads + head_idx] = - S[pos * num_heads + head_idx]; - } - return; + if (num_index_sets == 1) { + vec_t v; + v.cast_load(V + (pos * num_heads + head_idx) * head_dim + tx * vec_size); + v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size); + if (s_merged != nullptr) { + s_merged[pos * num_heads + head_idx] = S[pos * num_heads + head_idx]; } + return; + } - state_t st; + state_t st; #pragma unroll 2 - for (uint32_t iter = 0; iter < num_index_sets; ++iter) { - float s = S[(pos * num_index_sets + iter) * num_heads + head_idx]; - vec_t v; - v.cast_load(V + - ((pos * num_index_sets + iter) * num_heads + head_idx) * - head_dim + - tx * vec_size); - st.merge(v, s, 1); - } - - st.normalize(); - st.o.cast_store(v_merged + (pos * num_heads + head_idx) * head_dim + - tx * vec_size); - if (s_merged != nullptr) { - s_merged[pos * num_heads + head_idx] = st.get_lse(); - } + for (uint32_t iter = 0; iter < num_index_sets; ++iter) { + float s = S[(pos * num_index_sets + iter) * num_heads + head_idx]; + vec_t v; + v.cast_load(V + ((pos * num_index_sets + iter) * num_heads + head_idx) * head_dim + + tx * vec_size); + st.merge(v, s, 1); + } + + st.normalize(); + st.o.cast_store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * 
vec_size); + if (s_merged != nullptr) { + s_merged[pos * num_heads + head_idx] = st.get_lse(); + } } ///@brief The CUDA kernel that merges self-attention states of a list of index @@ -346,92 +289,71 @@ __global__ void MergeStatesKernel(DTypeIn *__restrict__ V, ///@param num_heads The number of heads of v. ///@param head_dim The dimension of each head. ///@note s are logsumexp values with base 2. -template -__global__ void -MergeStatesLargeNumIndexSetsKernel(DTypeIn *__restrict__ V, - float *__restrict__ S, - DTypeO *__restrict__ v_merged, - float *__restrict__ s_merged, - uint32_t num_index_sets, - uint32_t num_heads) -{ - uint32_t tx = threadIdx.x, ty = threadIdx.y; - uint32_t pos = blockIdx.x; - uint32_t head_idx = blockIdx.y; - state_t st; - constexpr uint32_t vec_bits = sizeof(DTypeIn) * vec_size * 8; - constexpr uint32_t head_dim = vec_size * bdx; - - extern __shared__ uint8_t smem[]; - DTypeIn *v_smem = (DTypeIn *)smem; - float *s_smem = - (float *)(smem + num_smem_stages * bdy * head_dim * sizeof(DTypeIn)); +__global__ void MergeStatesLargeNumIndexSetsKernel(DTypeIn* __restrict__ V, float* __restrict__ S, + DTypeO* __restrict__ v_merged, + float* __restrict__ s_merged, + uint32_t num_index_sets, uint32_t num_heads) { + uint32_t tx = threadIdx.x, ty = threadIdx.y; + uint32_t pos = blockIdx.x; + uint32_t head_idx = blockIdx.y; + state_t st; + constexpr uint32_t vec_bits = sizeof(DTypeIn) * vec_size * 8; + constexpr uint32_t head_dim = vec_size * bdx; + + extern __shared__ uint8_t smem[]; + DTypeIn* v_smem = (DTypeIn*)smem; + float* s_smem = (float*)(smem + num_smem_stages * bdy * head_dim * sizeof(DTypeIn)); #pragma unroll - for (uint32_t iter = 0; iter < num_smem_stages; ++iter) { - cp_async::pred_load( - v_smem + (iter * bdy + ty) * head_dim + tx * vec_size, - V + - ((pos * num_index_sets + (iter * bdy + ty)) * num_heads + - head_idx) * - head_dim + - tx * vec_size, - (iter * bdy + ty) < num_index_sets); - cp_async::commit_group(); - } + for (uint32_t iter = 
0; iter < num_smem_stages; ++iter) { + cp_async::pred_load( + v_smem + (iter * bdy + ty) * head_dim + tx * vec_size, + V + ((pos * num_index_sets + (iter * bdy + ty)) * num_heads + head_idx) * head_dim + + tx * vec_size, + (iter * bdy + ty) < num_index_sets); + cp_async::commit_group(); + } #pragma unroll 4 - for (uint32_t iter = 0; iter < ceil_div(num_index_sets, bdy); ++iter) { - if (iter % bdx == 0) { - s_smem[ty * bdx + tx] = - iter * bdy + (ty * bdx + tx) < num_index_sets - ? S[(pos * num_index_sets + (iter * bdy + ty * bdx + tx)) * - num_heads + - head_idx] - : 0.f; - __syncthreads(); - } - cp_async::wait_group(); - __syncthreads(); - vec_t v; - v.cast_load(v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + - tx * vec_size); - if (iter * bdy + ty < num_index_sets) { - float s = s_smem[(iter % bdx) * bdy + ty]; - st.merge(v, s, 1); - } - __syncthreads(); - cp_async::pred_load( - v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + - tx * vec_size, - V + - ((pos * num_index_sets + - ((iter + num_smem_stages) * bdy + ty)) * - num_heads + - head_idx) * - head_dim + - tx * vec_size, - (iter + num_smem_stages) * bdy + ty < num_index_sets); - cp_async::commit_group(); + for (uint32_t iter = 0; iter < ceil_div(num_index_sets, bdy); ++iter) { + if (iter % bdx == 0) { + s_smem[ty * bdx + tx] = + iter * bdy + (ty * bdx + tx) < num_index_sets + ? 
S[(pos * num_index_sets + (iter * bdy + ty * bdx + tx)) * num_heads + head_idx] + : 0.f; + __syncthreads(); } - cp_async::wait_group<0>(); + cp_async::wait_group(); __syncthreads(); - - st.normalize(); - threadblock_sync_state(st, v_smem, s_smem); - st.normalize(); - - st.o.cast_store(v_merged + (pos * num_heads + head_idx) * head_dim + - tx * vec_size); - if (s_merged != nullptr) { - s_merged[pos * num_heads + head_idx] = st.get_lse(); + vec_t v; + v.cast_load(v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size); + if (iter * bdy + ty < num_index_sets) { + float s = s_smem[(iter % bdx) * bdy + ty]; + st.merge(v, s, 1); } + __syncthreads(); + cp_async::pred_load( + v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size, + V + + ((pos * num_index_sets + ((iter + num_smem_stages) * bdy + ty)) * num_heads + + head_idx) * + head_dim + + tx * vec_size, + (iter + num_smem_stages) * bdy + ty < num_index_sets); + cp_async::commit_group(); + } + cp_async::wait_group<0>(); + __syncthreads(); + + st.normalize(); + threadblock_sync_state(st, v_smem, s_smem); + st.normalize(); + + st.o.cast_store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size); + if (s_merged != nullptr) { + s_merged[pos * num_heads + head_idx] = st.get_lse(); + } } ///@brief The CUDA kernel to merge self-attention states of multiple index sets, @@ -461,226 +383,175 @@ MergeStatesLargeNumIndexSetsKernel(DTypeIn *__restrict__ V, ///@param num_heads The number of heads of v. ///@param head_dim The dimension of each head. ///@note s are logsumexp values with base 2. 
-template -__global__ void -PersistentVariableLengthMergeStatesKernel(DTypeIn *__restrict__ V, - float *__restrict__ S, - IdType *indptr, - DTypeO *__restrict__ v_merged, - float *__restrict__ s_merged, - uint32_t max_seq_len, - uint32_t *__restrict__ seq_len_ptr, - uint32_t num_heads) -{ - uint32_t tx = threadIdx.x, ty = threadIdx.y; - uint32_t cta_id = blockIdx.x; - uint32_t num_ctas = gridDim.x; - const uint32_t seq_len = seq_len_ptr ? *seq_len_ptr : max_seq_len; - uint32_t num_iters = ceil_div(seq_len * num_heads, num_ctas); - constexpr uint32_t vec_bits = sizeof(DTypeIn) * vec_size * 8; - constexpr uint32_t head_dim = vec_size * bdx; - extern __shared__ uint8_t smem[]; - DTypeIn *v_smem = (DTypeIn *)smem; - float *s_smem = - (float *)(smem + num_smem_stages * bdy * head_dim * sizeof(DTypeIn)); +template +__global__ void PersistentVariableLengthMergeStatesKernel( + DTypeIn* __restrict__ V, float* __restrict__ S, IdType* indptr, DTypeO* __restrict__ v_merged, + float* __restrict__ s_merged, uint32_t max_seq_len, uint32_t* __restrict__ seq_len_ptr, + uint32_t num_heads) { + uint32_t tx = threadIdx.x, ty = threadIdx.y; + uint32_t cta_id = blockIdx.x; + uint32_t num_ctas = gridDim.x; + const uint32_t seq_len = seq_len_ptr ? 
*seq_len_ptr : max_seq_len; + uint32_t num_iters = ceil_div(seq_len * num_heads, num_ctas); + constexpr uint32_t vec_bits = sizeof(DTypeIn) * vec_size * 8; + constexpr uint32_t head_dim = vec_size * bdx; + extern __shared__ uint8_t smem[]; + DTypeIn* v_smem = (DTypeIn*)smem; + float* s_smem = (float*)(smem + num_smem_stages * bdy * head_dim * sizeof(DTypeIn)); #pragma unroll 1 - for (uint32_t i = cta_id; i < seq_len * num_heads; i += num_ctas) { - uint32_t pos = i / num_heads; - uint32_t head_idx = i % num_heads; - state_t st; - const uint32_t num_index_sets = indptr[pos + 1] - indptr[pos]; - - if (num_index_sets == 0) { - vec_t v; - v.fill(fi::con::explicit_casting(0.0f)); - v.store(v_merged + (pos * num_heads + head_idx) * head_dim + - tx * vec_size); - if (s_merged != nullptr) { - s_merged[pos * num_heads + head_idx] = -math::inf; - } - continue; - } + for (uint32_t i = cta_id; i < seq_len * num_heads; i += num_ctas) { + uint32_t pos = i / num_heads; + uint32_t head_idx = i % num_heads; + state_t st; + const uint32_t num_index_sets = indptr[pos + 1] - indptr[pos]; - if (num_index_sets == 1) { - vec_t v; - v.cast_load(V + (indptr[pos] * num_heads + head_idx) * head_dim + - tx * vec_size); - v.store(v_merged + (pos * num_heads + head_idx) * head_dim + - tx * vec_size); - if (s_merged != nullptr) { - s_merged[pos * num_heads + head_idx] = - S[indptr[pos] * num_heads + head_idx]; - } - continue; - } + if (num_index_sets == 0) { + vec_t v; + v.fill(fi::con::explicit_casting(0.0f)); + v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size); + if (s_merged != nullptr) { + s_merged[pos * num_heads + head_idx] = -math::inf; + } + continue; + } + + if (num_index_sets == 1) { + vec_t v; + v.cast_load(V + (indptr[pos] * num_heads + head_idx) * head_dim + tx * vec_size); + v.store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size); + if (s_merged != nullptr) { + s_merged[pos * num_heads + head_idx] = S[indptr[pos] * num_heads + head_idx]; 
+ } + continue; + } #pragma unroll - for (uint32_t iter = 0; iter < num_smem_stages; ++iter) { - cp_async::pred_load( - v_smem + (iter * bdy + ty) * head_dim + tx * vec_size, - V + - ((indptr[pos] + (iter * bdy + ty)) * num_heads + head_idx) * - head_dim + - tx * vec_size, - (iter * bdy + ty) < num_index_sets); - cp_async::commit_group(); - } + for (uint32_t iter = 0; iter < num_smem_stages; ++iter) { + cp_async::pred_load( + v_smem + (iter * bdy + ty) * head_dim + tx * vec_size, + V + ((indptr[pos] + (iter * bdy + ty)) * num_heads + head_idx) * head_dim + tx * vec_size, + (iter * bdy + ty) < num_index_sets); + cp_async::commit_group(); + } #pragma unroll 4 - for (uint32_t iter = 0; iter < ceil_div(num_index_sets, bdy); ++iter) { - if (iter % bdx == 0) { - s_smem[ty * bdx + tx] = - iter * bdy + (ty * bdx + tx) < num_index_sets - ? S[(indptr[pos] + (iter * bdy + ty * bdx + tx)) * - num_heads + - head_idx] - : 0.f; - __syncthreads(); - } - cp_async::wait_group(); - __syncthreads(); - vec_t v; - v.cast_load(v_smem + - ((iter % num_smem_stages) * bdy + ty) * head_dim + - tx * vec_size); - if (iter * bdy + ty < num_index_sets) { - float s = s_smem[(iter % bdx) * bdy + ty]; - st.merge(v, s, 1); - } - __syncthreads(); - cp_async::pred_load( - v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + - tx * vec_size, - V + - ((indptr[pos] + ((iter + num_smem_stages) * bdy + ty)) * - num_heads + - head_idx) * - head_dim + - tx * vec_size, - (iter + num_smem_stages) * bdy + ty < num_index_sets); - cp_async::commit_group(); - } - cp_async::wait_group<0>(); + for (uint32_t iter = 0; iter < ceil_div(num_index_sets, bdy); ++iter) { + if (iter % bdx == 0) { + s_smem[ty * bdx + tx] = + iter * bdy + (ty * bdx + tx) < num_index_sets + ? 
S[(indptr[pos] + (iter * bdy + ty * bdx + tx)) * num_heads + head_idx] + : 0.f; __syncthreads(); + } + cp_async::wait_group(); + __syncthreads(); + vec_t v; + v.cast_load(v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size); + if (iter * bdy + ty < num_index_sets) { + float s = s_smem[(iter % bdx) * bdy + ty]; + st.merge(v, s, 1); + } + __syncthreads(); + cp_async::pred_load( + v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size, + V + + ((indptr[pos] + ((iter + num_smem_stages) * bdy + ty)) * num_heads + head_idx) * + head_dim + + tx * vec_size, + (iter + num_smem_stages) * bdy + ty < num_index_sets); + cp_async::commit_group(); + } + cp_async::wait_group<0>(); + __syncthreads(); - st.normalize(); - threadblock_sync_state(st, v_smem, s_smem); - st.normalize(); + st.normalize(); + threadblock_sync_state(st, v_smem, s_smem); + st.normalize(); - st.o.cast_store(v_merged + (pos * num_heads + head_idx) * head_dim + - tx * vec_size); - if (s_merged != nullptr) { - s_merged[pos * num_heads + head_idx] = st.get_lse(); - } + st.o.cast_store(v_merged + (pos * num_heads + head_idx) * head_dim + tx * vec_size); + if (s_merged != nullptr) { + s_merged[pos * num_heads + head_idx] = st.get_lse(); } + } } -template -__global__ void -PersistentVariableLengthAttentionSumKernel(DTypeIn *__restrict__ V, - IdType *indptr, - DTypeO *__restrict__ v_sum, - uint32_t max_seq_len, - uint32_t *__restrict__ seq_len_ptr, - uint32_t num_heads) -{ - uint32_t tx = threadIdx.x, ty = threadIdx.y; - uint32_t cta_id = blockIdx.x; - uint32_t num_ctas = gridDim.x; - const uint32_t seq_len = seq_len_ptr ? 
*seq_len_ptr : max_seq_len; - uint32_t num_iters = ceil_div(seq_len * num_heads, num_ctas); - constexpr uint32_t vec_bits = sizeof(DTypeIn) * vec_size * 8; - constexpr uint32_t head_dim = vec_size * bdx; - extern __shared__ uint8_t smem[]; - DTypeIn *v_smem = (DTypeIn *)smem; - - vec_t v_sum_vec; +template +__global__ void PersistentVariableLengthAttentionSumKernel(DTypeIn* __restrict__ V, IdType* indptr, + DTypeO* __restrict__ v_sum, + uint32_t max_seq_len, + uint32_t* __restrict__ seq_len_ptr, + uint32_t num_heads) { + uint32_t tx = threadIdx.x, ty = threadIdx.y; + uint32_t cta_id = blockIdx.x; + uint32_t num_ctas = gridDim.x; + const uint32_t seq_len = seq_len_ptr ? *seq_len_ptr : max_seq_len; + uint32_t num_iters = ceil_div(seq_len * num_heads, num_ctas); + constexpr uint32_t vec_bits = sizeof(DTypeIn) * vec_size * 8; + constexpr uint32_t head_dim = vec_size * bdx; + extern __shared__ uint8_t smem[]; + DTypeIn* v_smem = (DTypeIn*)smem; + + vec_t v_sum_vec; #pragma unroll 1 - for (uint32_t i = cta_id; i < seq_len * num_heads; i += num_ctas) { - uint32_t pos = i / num_heads; - uint32_t head_idx = i % num_heads; - const uint32_t num_index_sets = indptr[pos + 1] - indptr[pos]; - - if (num_index_sets == 0) { - vec_t v; - v.fill(DTypeO(0.f)); - v.store(v_sum + (pos * num_heads + head_idx) * head_dim + - tx * vec_size); - continue; - } + for (uint32_t i = cta_id; i < seq_len * num_heads; i += num_ctas) { + uint32_t pos = i / num_heads; + uint32_t head_idx = i % num_heads; + const uint32_t num_index_sets = indptr[pos + 1] - indptr[pos]; - if (num_index_sets == 1) { - vec_t v; - v.cast_load(V + (indptr[pos] * num_heads + head_idx) * head_dim + - tx * vec_size); - v.store(v_sum + (pos * num_heads + head_idx) * head_dim + - tx * vec_size); - continue; - } + if (num_index_sets == 0) { + vec_t v; + v.fill(DTypeO(0.f)); + v.store(v_sum + (pos * num_heads + head_idx) * head_dim + tx * vec_size); + continue; + } + + if (num_index_sets == 1) { + vec_t v; + v.cast_load(V + 
(indptr[pos] * num_heads + head_idx) * head_dim + tx * vec_size); + v.store(v_sum + (pos * num_heads + head_idx) * head_dim + tx * vec_size); + continue; + } #pragma unroll - for (uint32_t iter = 0; iter < num_smem_stages; ++iter) { - cp_async::pred_load( - v_smem + (iter * bdy + ty) * head_dim + tx * vec_size, - V + - ((indptr[pos] + (iter * bdy + ty)) * num_heads + head_idx) * - head_dim + - tx * vec_size, - (iter * bdy + ty) < num_index_sets); - cp_async::commit_group(); - } + for (uint32_t iter = 0; iter < num_smem_stages; ++iter) { + cp_async::pred_load( + v_smem + (iter * bdy + ty) * head_dim + tx * vec_size, + V + ((indptr[pos] + (iter * bdy + ty)) * num_heads + head_idx) * head_dim + tx * vec_size, + (iter * bdy + ty) < num_index_sets); + cp_async::commit_group(); + } #pragma unroll 4 - for (uint32_t iter = 0; iter < ceil_div(num_index_sets, bdy); ++iter) { - cp_async::wait_group(); - __syncthreads(); - vec_t v; - v.cast_load(v_smem + - ((iter % num_smem_stages) * bdy + ty) * head_dim + - tx * vec_size); - if (iter * bdy + ty < num_index_sets) { + for (uint32_t iter = 0; iter < ceil_div(num_index_sets, bdy); ++iter) { + cp_async::wait_group(); + __syncthreads(); + vec_t v; + v.cast_load(v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size); + if (iter * bdy + ty < num_index_sets) { #pragma unroll - for (uint32_t i = 0; i < vec_size; ++i) { - v_sum_vec[i] += v[i]; - } - } - __syncthreads(); - cp_async::pred_load( - v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + - tx * vec_size, - V + - ((indptr[pos] + ((iter + num_smem_stages) * bdy + ty)) * - num_heads + - head_idx) * - head_dim + - tx * vec_size, - (iter + num_smem_stages) * bdy + ty < num_index_sets); - cp_async::commit_group(); + for (uint32_t i = 0; i < vec_size; ++i) { + v_sum_vec[i] += v[i]; } - cp_async::wait_group<0>(); - __syncthreads(); + } + __syncthreads(); + cp_async::pred_load( + v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim + tx * vec_size, + V + 
+ ((indptr[pos] + ((iter + num_smem_stages) * bdy + ty)) * num_heads + head_idx) * + head_dim + + tx * vec_size, + (iter + num_smem_stages) * bdy + ty < num_index_sets); + cp_async::commit_group(); + } + cp_async::wait_group<0>(); + __syncthreads(); - threadblock_sum(v_sum_vec, v_smem); + threadblock_sum(v_sum_vec, v_smem); - v_sum_vec.cast_store(v_sum + (pos * num_heads + head_idx) * head_dim + - tx * vec_size); - } + v_sum_vec.cast_store(v_sum + (pos * num_heads + head_idx) * head_dim + tx * vec_size); + } } ///@brief Merge the self-attention state of two index sets A and B. @@ -699,29 +570,19 @@ PersistentVariableLengthAttentionSumKernel(DTypeIn *__restrict__ V, ///@return status Indicates whether CUDA calls are successful ///@note Both s_a and s_b are logsumexp values with base 2. template -hipError_t MergeState(DTypeIn *v_a, - float *s_a, - DTypeIn *v_b, - float *s_b, - DTypeO *v_merged, - float *s_merged, - uint32_t seq_len, - uint32_t num_heads, - uint32_t head_dim, - hipStream_t stream = nullptr) -{ - DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { - constexpr uint32_t vec_size = - std::max(16U / sizeof(DTypeIn), HEAD_DIM / 32U); - uint32_t bdx = HEAD_DIM / vec_size; - uint32_t bdy = num_heads; - dim3 nblks(seq_len); - dim3 nthrs(bdx, bdy); - MergeStateKernel - <<>>(v_a, s_a, v_b, s_b, v_merged, - s_merged, num_heads, head_dim); - }); - return hipSuccess; +hipError_t MergeState(DTypeIn* v_a, float* s_a, DTypeIn* v_b, float* s_b, DTypeO* v_merged, + float* s_merged, uint32_t seq_len, uint32_t num_heads, uint32_t head_dim, + hipStream_t stream = nullptr) { + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + constexpr uint32_t vec_size = std::max(16U / sizeof(DTypeIn), HEAD_DIM / 32U); + uint32_t bdx = HEAD_DIM / vec_size; + uint32_t bdy = num_heads; + dim3 nblks(seq_len); + dim3 nthrs(bdx, bdy); + MergeStateKernel + <<>>(v_a, s_a, v_b, s_b, v_merged, s_merged, num_heads, head_dim); + }); + return hipSuccess; } ///@brief Merge the self-attention state with another state in 
place. @@ -738,28 +599,20 @@ hipError_t MergeState(DTypeIn *v_a, ///@return status Indicates whether CUDA calls are successful ///@note Both s and s_other are logsumexp values with base 2. template -hipError_t MergeStateInPlace(DType *v, - float *s, - DType *v_other, - float *s_other, - uint32_t seq_len, - uint32_t num_heads, - uint32_t head_dim, - uint8_t *mask = nullptr, - hipStream_t stream = nullptr) -{ - DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { - constexpr uint32_t vec_size = - std::max(16U / sizeof(DType), HEAD_DIM / 32U); - uint32_t bdx = HEAD_DIM / vec_size; - uint32_t bdy = num_heads; - dim3 nblks(seq_len); - dim3 nthrs(bdx, bdy); - auto kernel = MergeStateInPlaceKernel; - MergeStateInPlaceKernel<<>>( - v, s, v_other, s_other, mask, num_heads, head_dim); - }); - return hipSuccess; +hipError_t MergeStateInPlace(DType* v, float* s, DType* v_other, float* s_other, uint32_t seq_len, + uint32_t num_heads, uint32_t head_dim, uint8_t* mask = nullptr, + hipStream_t stream = nullptr) { + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + constexpr uint32_t vec_size = std::max(16U / sizeof(DType), HEAD_DIM / 32U); + uint32_t bdx = HEAD_DIM / vec_size; + uint32_t bdy = num_heads; + dim3 nblks(seq_len); + dim3 nthrs(bdx, bdy); + auto kernel = MergeStateInPlaceKernel; + MergeStateInPlaceKernel + <<>>(v, s, v_other, s_other, mask, num_heads, head_dim); + }); + return hipSuccess; } ///@brief Merge self-attention states of a list of index sets. @@ -777,171 +630,125 @@ hipError_t MergeStateInPlace(DType *v, ///@return status Indicates whether CUDA calls are successful ///@note s are logsumexp values with base 2. 
template -hipError_t MergeStates(DTypeIn *v, - float *s, - DTypeO *v_merged, - float *s_merged, - uint32_t num_index_sets, - uint32_t seq_len, - uint32_t num_heads, - uint32_t head_dim, - hipStream_t stream = nullptr) -{ - DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { - constexpr uint32_t vec_size = - std::max(16U / sizeof(DTypeIn), HEAD_DIM / 32U); - constexpr uint32_t bdx = HEAD_DIM / vec_size; - if (num_index_sets >= seq_len) { - constexpr uint32_t num_threads = 128; - constexpr uint32_t bdy = num_threads / bdx; - dim3 nblks(seq_len, num_heads); - dim3 nthrs(bdx, bdy); - constexpr uint32_t num_smem_stages = 4; - uint32_t smem_size = - num_smem_stages * bdy * head_dim * sizeof(DTypeIn) + - num_threads * sizeof(float); - auto kernel = MergeStatesLargeNumIndexSetsKernel< - vec_size, bdx, bdy, num_smem_stages, DTypeIn, DTypeO>; - CHECK_HIP_ERROR(hipFuncSetAttribute( - (void *)kernel, hipFuncAttributeMaxDynamicSharedMemorySize, - smem_size)); - MergeStatesLargeNumIndexSetsKernel - <<>>( - v, s, v_merged, s_merged, num_index_sets, num_heads); - } +hipError_t MergeStates(DTypeIn* v, float* s, DTypeO* v_merged, float* s_merged, + uint32_t num_index_sets, uint32_t seq_len, uint32_t num_heads, + uint32_t head_dim, hipStream_t stream = nullptr) { + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + constexpr uint32_t vec_size = std::max(16U / sizeof(DTypeIn), HEAD_DIM / 32U); + constexpr uint32_t bdx = HEAD_DIM / vec_size; + if (num_index_sets >= seq_len) { + constexpr uint32_t num_threads = 128; + constexpr uint32_t bdy = num_threads / bdx; + dim3 nblks(seq_len, num_heads); + dim3 nthrs(bdx, bdy); + constexpr uint32_t num_smem_stages = 4; + uint32_t smem_size = + num_smem_stages * bdy * head_dim * sizeof(DTypeIn) + num_threads * sizeof(float); + auto kernel = + MergeStatesLargeNumIndexSetsKernel; + CHECK_HIP_ERROR(hipFuncSetAttribute((void*)kernel, hipFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); + MergeStatesLargeNumIndexSetsKernel + <<>>(v, s, v_merged, s_merged, 
num_index_sets, + num_heads); + } - else { - uint32_t bdy = num_heads; - dim3 nblks(seq_len); - dim3 nthrs(bdx, bdy); - MergeStatesKernel - <<>>(v, s, v_merged, s_merged, - num_index_sets, num_heads, - head_dim); - } - }); - return hipSuccess; + else { + uint32_t bdy = num_heads; + dim3 nblks(seq_len); + dim3 nthrs(bdx, bdy); + MergeStatesKernel<<>>( + v, s, v_merged, s_merged, num_index_sets, num_heads, head_dim); + } + }); + return hipSuccess; } template -hipError_t AttentionSum(DTypeIn *v, - DTypeO *v_sum, - uint32_t num_index_sets, - uint32_t seq_len, - uint32_t num_heads, - uint32_t head_dim, - hipStream_t stream = nullptr) -{ - DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { - constexpr uint32_t vec_size = - std::max(16U / sizeof(DTypeIn), HEAD_DIM / 32U); - constexpr uint32_t bdx = HEAD_DIM / vec_size; - uint32_t bdy = num_heads; - dim3 nblks(seq_len); - dim3 nthrs(bdx, bdy); - AttentionSumKernel - <<>>(v, v_sum, num_index_sets, num_heads, - head_dim); - }); - return hipSuccess; +hipError_t AttentionSum(DTypeIn* v, DTypeO* v_sum, uint32_t num_index_sets, uint32_t seq_len, + uint32_t num_heads, uint32_t head_dim, hipStream_t stream = nullptr) { + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + constexpr uint32_t vec_size = std::max(16U / sizeof(DTypeIn), HEAD_DIM / 32U); + constexpr uint32_t bdx = HEAD_DIM / vec_size; + uint32_t bdy = num_heads; + dim3 nblks(seq_len); + dim3 nthrs(bdx, bdy); + AttentionSumKernel + <<>>(v, v_sum, num_index_sets, num_heads, head_dim); + }); + return hipSuccess; } template -hipError_t VariableLengthMergeStates(DTypeIn *v, - float *s, - IdType *indptr, - DTypeO *v_merged, - float *s_merged, - uint32_t max_seq_len, - uint32_t *seq_len, - uint32_t num_heads, - uint32_t head_dim, - hipStream_t stream = nullptr) -{ - int dev_id = 0; - int num_sms = 0; - int num_blocks_per_sm = 0; - CHECK_HIP_ERROR(hipGetDevice(&dev_id)); - CHECK_HIP_ERROR(hipDeviceGetAttribute( - &num_sms, hipDeviceAttributeMultiprocessorCount, dev_id)); - - 
DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { - constexpr uint32_t vec_size = - std::max(16U / sizeof(DTypeIn), HEAD_DIM / 32U); - constexpr uint32_t bdx = HEAD_DIM / vec_size; - constexpr uint32_t num_threads = 128; - constexpr uint32_t bdy = num_threads / bdx; - constexpr uint32_t num_smem_stages = 4; - uint32_t smem_size = - num_smem_stages * bdy * head_dim * sizeof(DTypeIn) + - num_threads * sizeof(float); - auto kernel = PersistentVariableLengthMergeStatesKernel< - vec_size, bdx, bdy, num_smem_stages, DTypeIn, DTypeO, IdType>; - CHECK_HIP_ERROR(hipOccupancyMaxActiveBlocksPerMultiprocessor( - &num_blocks_per_sm, kernel, num_threads, smem_size)); - num_blocks_per_sm = - min(num_blocks_per_sm, ceil_div(max_seq_len * num_heads, num_sms)); - - dim3 nblks(num_sms * num_blocks_per_sm); - dim3 nthrs(bdx, bdy); - CHECK_HIP_ERROR(hipFuncSetAttribute( - (void *)kernel, hipFuncAttributeMaxDynamicSharedMemorySize, - smem_size)); - PersistentVariableLengthMergeStatesKernel< - vec_size, bdx, bdy, num_smem_stages, DTypeIn, DTypeO, IdType> - <<>>(v, s, indptr, v_merged, - s_merged, max_seq_len, - seq_len, num_heads); - }); - return hipSuccess; +hipError_t VariableLengthMergeStates(DTypeIn* v, float* s, IdType* indptr, DTypeO* v_merged, + float* s_merged, uint32_t max_seq_len, uint32_t* seq_len, + uint32_t num_heads, uint32_t head_dim, + hipStream_t stream = nullptr) { + int dev_id = 0; + int num_sms = 0; + int num_blocks_per_sm = 0; + CHECK_HIP_ERROR(hipGetDevice(&dev_id)); + CHECK_HIP_ERROR(hipDeviceGetAttribute(&num_sms, hipDeviceAttributeMultiprocessorCount, dev_id)); + + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + constexpr uint32_t vec_size = std::max(16U / sizeof(DTypeIn), HEAD_DIM / 32U); + constexpr uint32_t bdx = HEAD_DIM / vec_size; + constexpr uint32_t num_threads = 128; + constexpr uint32_t bdy = num_threads / bdx; + constexpr uint32_t num_smem_stages = 4; + uint32_t smem_size = + num_smem_stages * bdy * head_dim * sizeof(DTypeIn) + num_threads * sizeof(float); + auto 
kernel = PersistentVariableLengthMergeStatesKernel; + CHECK_HIP_ERROR(hipOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel, + num_threads, smem_size)); + num_blocks_per_sm = min(num_blocks_per_sm, ceil_div(max_seq_len * num_heads, num_sms)); + + dim3 nblks(num_sms * num_blocks_per_sm); + dim3 nthrs(bdx, bdy); + CHECK_HIP_ERROR( + hipFuncSetAttribute((void*)kernel, hipFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + PersistentVariableLengthMergeStatesKernel<<>>( + v, s, indptr, v_merged, s_merged, max_seq_len, seq_len, num_heads); + }); + return hipSuccess; } template -hipError_t VariableLengthAttentionSum(DTypeIn *v, - IdType *indptr, - DTypeO *v_sum, - uint32_t max_seq_len, - uint32_t *seq_len, - uint32_t num_heads, - uint32_t head_dim, - hipStream_t stream = nullptr) -{ - int dev_id = 0; - int num_sms = 0; - int num_blocks_per_sm = 0; - CHECK_HIP_ERROR(hipGetDevice(&dev_id)); - CHECK_HIP_ERROR(hipDeviceGetAttribute( - &num_sms, hipDeviceAttributeMultiprocessorCount, dev_id)); - - DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { - constexpr uint32_t vec_size = - std::max(16U / sizeof(DTypeIn), HEAD_DIM / 32U); - constexpr uint32_t bdx = HEAD_DIM / vec_size; - constexpr uint32_t num_threads = 128; - constexpr uint32_t bdy = num_threads / bdx; - constexpr uint32_t num_smem_stages = 4; - uint32_t smem_size = num_smem_stages * bdy * head_dim * sizeof(DTypeIn); - auto kernel = PersistentVariableLengthAttentionSumKernel< - vec_size, bdx, bdy, num_smem_stages, DTypeIn, DTypeO, IdType>; - CHECK_HIP_ERROR(hipOccupancyMaxActiveBlocksPerMultiprocessor( - &num_blocks_per_sm, kernel, num_threads, smem_size)); - num_blocks_per_sm = - min(num_blocks_per_sm, ceil_div(max_seq_len * num_heads, num_sms)); - - dim3 nblks(num_sms * num_blocks_per_sm); - dim3 nthrs(bdx, bdy); - CHECK_HIP_ERROR(hipFuncSetAttribute( - (void *)kernel, hipFuncAttributeMaxDynamicSharedMemorySize, - smem_size)); - PersistentVariableLengthAttentionSumKernel< - vec_size, bdx, bdy, 
num_smem_stages, DTypeIn, DTypeO, IdType> - <<>>(v, indptr, v_sum, max_seq_len, - seq_len, num_heads); - }); - return hipSuccess; +hipError_t VariableLengthAttentionSum(DTypeIn* v, IdType* indptr, DTypeO* v_sum, + uint32_t max_seq_len, uint32_t* seq_len, uint32_t num_heads, + uint32_t head_dim, hipStream_t stream = nullptr) { + int dev_id = 0; + int num_sms = 0; + int num_blocks_per_sm = 0; + CHECK_HIP_ERROR(hipGetDevice(&dev_id)); + CHECK_HIP_ERROR(hipDeviceGetAttribute(&num_sms, hipDeviceAttributeMultiprocessorCount, dev_id)); + + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + constexpr uint32_t vec_size = std::max(16U / sizeof(DTypeIn), HEAD_DIM / 32U); + constexpr uint32_t bdx = HEAD_DIM / vec_size; + constexpr uint32_t num_threads = 128; + constexpr uint32_t bdy = num_threads / bdx; + constexpr uint32_t num_smem_stages = 4; + uint32_t smem_size = num_smem_stages * bdy * head_dim * sizeof(DTypeIn); + auto kernel = PersistentVariableLengthAttentionSumKernel; + CHECK_HIP_ERROR(hipOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel, + num_threads, smem_size)); + num_blocks_per_sm = min(num_blocks_per_sm, ceil_div(max_seq_len * num_heads, num_sms)); + + dim3 nblks(num_sms * num_blocks_per_sm); + dim3 nthrs(bdx, bdy); + CHECK_HIP_ERROR( + hipFuncSetAttribute((void*)kernel, hipFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + PersistentVariableLengthAttentionSumKernel + <<>>(v, indptr, v_sum, max_seq_len, seq_len, num_heads); + }); + return hipSuccess; } -} // namespace flashinfer +} // namespace flashinfer -#endif // FLASHINFER_CASCADE_CUH_ +#endif // FLASHINFER_CASCADE_CUH_ diff --git a/libflashinfer/include/flashinfer/hip/attention/decode.hip.h b/libflashinfer/include/flashinfer/hip/attention/decode.hip.h index 22e7fb5b09..b99324d136 100644 --- a/libflashinfer/include/flashinfer/hip/attention/decode.hip.h +++ b/libflashinfer/include/flashinfer/hip/attention/decode.hip.h @@ -9,16 +9,6 @@ #define HIP_ENABLE_WARP_SYNC_BUILTINS 1 -#include 
"../cp_async.hip.h" -#include "../pos_enc.hip.h" -#include "../utils.hip.h" -#include "../vec_dtypes.hip.h" -#include "cascade.hip.h" -#include "state.hip.h" - -#include "default_decode_params.hip.h" -#include "variants.hip.h" - #include #include #include @@ -27,22 +17,29 @@ #include -namespace flashinfer -{ +#include "../cp_async.hip.h" +#include "../pos_enc.hip.h" +#include "../utils.hip.h" +#include "../vec_dtypes.hip.h" +#include "cascade.hip.h" +#include "default_decode_params.hip.h" +#include "state.hip.h" +#include "variants.hip.h" + +namespace flashinfer { DEFINE_HAS_MEMBER(decode_maybe_q_rope_offset) -#define FLASHINFER_CHECK_STATUS(status) \ - if (status != hipSuccess) { \ - return status; \ - } +#define FLASHINFER_CHECK_STATUS(status) \ + if (status != hipSuccess) { \ + return status; \ + } namespace cg = cooperative_groups; using cp_async::PrefetchMode; using cp_async::SharedMemFillMode; -namespace -{ +namespace { /*! * \brief Load k tile from smem and compute qk @@ -64,81 +61,63 @@ namespace * \param s A float indicates the thread-local result of qk * \param st The self-attention state to be updated */ -template -__device__ __forceinline__ void compute_qk(const Params ¶ms, - AttentionVariant variant, - const uint32_t batch_idx, - const T *smem, - const vec_t &q_vec, - const vec_t &freq, - uint32_t kv_idx_base, - uint32_t iter_base, - uint32_t iter_bound, - uint32_t qo_head_idx, - uint32_t kv_head_idx, - float *s, - state_t &st) -{ - uint32_t tx = threadIdx.x, tz = threadIdx.z; - float m_prev = st.m; +template +__device__ __forceinline__ void compute_qk(const Params& params, AttentionVariant variant, + const uint32_t batch_idx, const T* smem, + const vec_t& q_vec, + const vec_t& freq, uint32_t kv_idx_base, + uint32_t iter_base, uint32_t iter_bound, + uint32_t qo_head_idx, uint32_t kv_head_idx, float* s, + state_t& st) { + uint32_t tx = threadIdx.x, tz = threadIdx.z; + float m_prev = st.m; #pragma unroll - for (uint32_t j = 0; j < tile_size; ++j) { - 
vec_t k_vec; - if constexpr (pos_encoding_mode == PosEncodingMode::kRoPELlama) { - // apply rotary embedding for all rows in k matrix of kv-cache - k_vec = vec_apply_llama_rope( - smem + j * bdx * vec_size, freq, - kv_idx_base + tz * tile_size + j); - } - else { - // do not apply rotary embedding - k_vec.cast_load(smem + (j * bdx + tx) * vec_size); - } - s[j] = 0.f; + for (uint32_t j = 0; j < tile_size; ++j) { + vec_t k_vec; + if constexpr (pos_encoding_mode == PosEncodingMode::kRoPELlama) { + // apply rotary embedding for all rows in k matrix of kv-cache + k_vec = vec_apply_llama_rope(smem + j * bdx * vec_size, freq, + kv_idx_base + tz * tile_size + j); + } else { + // do not apply rotary embedding + k_vec.cast_load(smem + (j * bdx + tx) * vec_size); + } + s[j] = 0.f; #pragma unroll - for (uint32_t i = 0; i < vec_size; ++i) { - s[j] += q_vec[i] * k_vec[i]; - } + for (uint32_t i = 0; i < vec_size; ++i) { + s[j] += q_vec[i] * k_vec[i]; + } #pragma unroll - for (uint32_t offset = bdx / 2; offset > 0; offset /= 2) { - s[j] += math::shfl_xor_sync(s[j], offset); - } - const uint32_t pos = kv_idx_base + tz * tile_size + j; - s[j] = - variant.LogitsTransform(params, s[j], batch_idx, /*qo_idx=*/0, - /*kv_idx=*/pos, qo_head_idx, kv_head_idx); - if constexpr (variant.use_softmax) { - s[j] *= variant.sm_scale_log2; - } - - bool mask = - variant.LogitsMask(params, batch_idx, /*qo_idx=*/0, /*kv_idx=*/pos, - qo_head_idx, kv_head_idx); - s[j] = (iter_base + tz * tile_size + j < iter_bound && mask) - ? 
s[j] - : -math::inf; - st.m = max(st.m, s[j]); + for (uint32_t offset = bdx / 2; offset > 0; offset /= 2) { + s[j] += math::shfl_xor_sync(s[j], offset); } - + const uint32_t pos = kv_idx_base + tz * tile_size + j; + s[j] = variant.LogitsTransform(params, s[j], batch_idx, /*qo_idx=*/0, + /*kv_idx=*/pos, qo_head_idx, kv_head_idx); if constexpr (variant.use_softmax) { - float o_scale = math::ptx_exp2(m_prev - st.m); - st.d *= o_scale; + s[j] *= variant.sm_scale_log2; + } + + bool mask = variant.LogitsMask(params, batch_idx, /*qo_idx=*/0, /*kv_idx=*/pos, qo_head_idx, + kv_head_idx); + s[j] = (iter_base + tz * tile_size + j < iter_bound && mask) ? s[j] : -math::inf; + st.m = max(st.m, s[j]); + } + + if constexpr (variant.use_softmax) { + float o_scale = math::ptx_exp2(m_prev - st.m); + st.d *= o_scale; #pragma unroll - for (uint32_t j = 0; j < tile_size; ++j) { - s[j] = math::ptx_exp2(s[j] - st.m); - st.d += s[j]; - } + for (uint32_t j = 0; j < tile_size; ++j) { + s[j] = math::ptx_exp2(s[j] - st.m); + st.d += s[j]; + } #pragma unroll - for (uint32_t i = 0; i < vec_size; ++i) { - st.o[i] = st.o[i] * o_scale; - } + for (uint32_t i = 0; i < vec_size; ++i) { + st.o[i] = st.o[i] * o_scale; } + } } /*! 
@@ -157,21 +136,19 @@ __device__ __forceinline__ void compute_qk(const Params ¶ms, * \param st The flashattention state to be updated */ template -__device__ __forceinline__ void update_local_state(const T *smem, - const float *s, +__device__ __forceinline__ void update_local_state(const T* smem, const float* s, uint32_t compute_stage_idx, - state_t &st) -{ - uint32_t tx = threadIdx.x; + state_t& st) { + uint32_t tx = threadIdx.x; #pragma unroll - for (uint32_t j = 0; j < tile_size; ++j) { - vec_t v_vec; - v_vec.cast_load(smem + (j * bdx + tx) * vec_size); + for (uint32_t j = 0; j < tile_size; ++j) { + vec_t v_vec; + v_vec.cast_load(smem + (j * bdx + tx) * vec_size); #pragma unroll - for (uint32_t i = 0; i < vec_size; ++i) { - st.o[i] = st.o[i] + s[j] * v_vec[i]; - } + for (uint32_t i = 0; i < vec_size; ++i) { + st.o[i] = st.o[i] + s[j] * v_vec[i]; } + } } /*! @@ -183,52 +160,43 @@ __device__ __forceinline__ void update_local_state(const T *smem, * \param smem The pointer to shared memory buffer for o * \param smem_md The pointer to shared memory buffer for m/d */ -template -__device__ __forceinline__ void sync_state(AttentionVariant variant, - state_t &st, - float *smem, - float *smem_md) -{ - if constexpr (bdz > 1) { - constexpr uint32_t head_dim = bdx * vec_size; - auto block = cg::this_thread_block(); - uint32_t tx = threadIdx.x, ty = threadIdx.y, tz = threadIdx.z; - st.o.store(smem + (tz * bdy + ty) * head_dim + tx * vec_size); - if constexpr (variant.use_softmax) { - smem_md[(tz * bdy + ty) * 2] = st.m; - smem_md[(tz * bdy + ty) * 2 + 1] = st.d; - block.sync(); - st.init(); +template +__device__ __forceinline__ void sync_state(AttentionVariant variant, state_t& st, + float* smem, float* smem_md) { + if constexpr (bdz > 1) { + constexpr uint32_t head_dim = bdx * vec_size; + auto block = cg::this_thread_block(); + uint32_t tx = threadIdx.x, ty = threadIdx.y, tz = threadIdx.z; + st.o.store(smem + (tz * bdy + ty) * head_dim + tx * vec_size); + if constexpr 
(variant.use_softmax) { + smem_md[(tz * bdy + ty) * 2] = st.m; + smem_md[(tz * bdy + ty) * 2 + 1] = st.d; + block.sync(); + st.init(); #pragma unroll - for (uint32_t j = 0; j < bdz; ++j) { - float mz = smem_md[(j * bdy + ty) * 2], - dz = smem_md[(j * bdy + ty) * 2 + 1]; - vec_t oz; - oz.load(smem + (j * bdy + ty) * head_dim + tx * vec_size); - st.merge(oz, mz, dz); - } - } - else { - block.sync(); - st.init(); + for (uint32_t j = 0; j < bdz; ++j) { + float mz = smem_md[(j * bdy + ty) * 2], dz = smem_md[(j * bdy + ty) * 2 + 1]; + vec_t oz; + oz.load(smem + (j * bdy + ty) * head_dim + tx * vec_size); + st.merge(oz, mz, dz); + } + } else { + block.sync(); + st.init(); #pragma unroll - for (uint32_t j = 0; j < bdz; ++j) { - vec_t oz; - oz.load(smem + (j * bdy + ty) * head_dim + tx * vec_size); + for (uint32_t j = 0; j < bdz; ++j) { + vec_t oz; + oz.load(smem + (j * bdy + ty) * head_dim + tx * vec_size); #pragma unroll - for (uint32_t i = 0; i < vec_size; ++i) { - st.o[i] += oz[i]; - } - } + for (uint32_t i = 0; i < vec_size; ++i) { + st.o[i] += oz[i]; } + } } + } } -} // namespace +} // namespace /*! 
* \brief FlashAttention decoding cuda kernel with kv-cache for a single request @@ -251,205 +219,155 @@ __device__ __forceinline__ void sync_state(AttentionVariant variant, * of "theta" used in RoPE (Rotary Positional Embeddings) * \param kv_chunk_size A integer indicates the kv-chunk size */ -template -__global__ void SingleDecodeWithKVCacheKernel(const Params params) -{ - using DTypeQ = typename Params::DTypeQ; - using DTypeKV = typename Params::DTypeKV; - using DTypeO = typename Params::DTypeO; - const DTypeQ *q = params.q; - const DTypeKV *k = params.k; - const DTypeKV *v = params.v; - const uint32_t q_stride_n = params.q_stride_n; - const uint32_t q_stride_h = params.q_stride_h; - const uint32_t kv_stride_n = params.kv_stride_n; - const uint32_t kv_stride_h = params.kv_stride_h; - DTypeO *o = params.o; - float *lse = params.lse; - uint32_t kv_chunk_size = params.kv_chunk_size; - - auto block = cg::this_thread_block(); - auto grid = cg::this_grid(); - - constexpr uint32_t head_dim = bdx * vec_size; - uint32_t kv_head_idx = blockIdx.y; - uint32_t qo_head_idx = kv_head_idx * bdy + threadIdx.y; - uint32_t kv_chunk_idx = blockIdx.x; - uint32_t num_qo_heads = params.num_qo_heads; - - extern __shared__ uint8_t smem[]; - AttentionVariant variant(params, /*batch_idx=*/0, smem); - const uint32_t seq_len = variant.kv_len; - DTypeKV *k_smem = (DTypeKV *)smem; - DTypeKV *v_smem = - (DTypeKV *)(smem + num_stages_smem * bdy * tile_size_per_bdx * bdz * - head_dim * sizeof(DTypeKV)); - float *smem_md = - (float *)(smem + 2 * num_stages_smem * bdy * tile_size_per_bdx * bdz * - head_dim * sizeof(DTypeKV)); - - uint32_t tx = threadIdx.x, ty = threadIdx.y, tz = threadIdx.z; - vec_t q_vec; - vec_t freq; - if constexpr (pos_encoding_mode == PosEncodingMode::kRoPELlama) { - const float rope_rcp_scale = params.rope_rcp_scale; - const float rope_rcp_theta = params.rope_rcp_theta; +__global__ void SingleDecodeWithKVCacheKernel(const Params params) { + using DTypeQ = typename 
Params::DTypeQ; + using DTypeKV = typename Params::DTypeKV; + using DTypeO = typename Params::DTypeO; + const DTypeQ* q = params.q; + const DTypeKV* k = params.k; + const DTypeKV* v = params.v; + const uint32_t q_stride_n = params.q_stride_n; + const uint32_t q_stride_h = params.q_stride_h; + const uint32_t kv_stride_n = params.kv_stride_n; + const uint32_t kv_stride_h = params.kv_stride_h; + DTypeO* o = params.o; + float* lse = params.lse; + uint32_t kv_chunk_size = params.kv_chunk_size; + + auto block = cg::this_thread_block(); + auto grid = cg::this_grid(); + + constexpr uint32_t head_dim = bdx * vec_size; + uint32_t kv_head_idx = blockIdx.y; + uint32_t qo_head_idx = kv_head_idx * bdy + threadIdx.y; + uint32_t kv_chunk_idx = blockIdx.x; + uint32_t num_qo_heads = params.num_qo_heads; + + extern __shared__ uint8_t smem[]; + AttentionVariant variant(params, /*batch_idx=*/0, smem); + const uint32_t seq_len = variant.kv_len; + DTypeKV* k_smem = (DTypeKV*)smem; + DTypeKV* v_smem = (DTypeKV*)(smem + num_stages_smem * bdy * tile_size_per_bdx * bdz * head_dim * + sizeof(DTypeKV)); + float* smem_md = (float*)(smem + 2 * num_stages_smem * bdy * tile_size_per_bdx * bdz * head_dim * + sizeof(DTypeKV)); + + uint32_t tx = threadIdx.x, ty = threadIdx.y, tz = threadIdx.z; + vec_t q_vec; + vec_t freq; + if constexpr (pos_encoding_mode == PosEncodingMode::kRoPELlama) { + const float rope_rcp_scale = params.rope_rcp_scale; + const float rope_rcp_theta = params.rope_rcp_theta; #pragma unroll - for (uint32_t i = 0; i < vec_size; ++i) { - freq[i] = rope_rcp_scale * - __powf(rope_rcp_theta, - float(2 * ((tx * vec_size + i) % (head_dim / 2))) / - float(head_dim)); - } - // apply rotary embedding to q matrix - q_vec = vec_apply_llama_rope( - q + qo_head_idx * q_stride_h, freq, seq_len - 1); - } - else { - // do not apply rotary embedding to q matrix - q_vec.cast_load(q + qo_head_idx * q_stride_h + tx * vec_size); + for (uint32_t i = 0; i < vec_size; ++i) { + freq[i] = rope_rcp_scale * + 
__powf(rope_rcp_theta, + float(2 * ((tx * vec_size + i) % (head_dim / 2))) / float(head_dim)); } - block.sync(); - - uint32_t chunk_start = kv_chunk_idx * kv_chunk_size; - kv_chunk_size = min(kv_chunk_size, seq_len - chunk_start); - uint32_t chunk_end = chunk_start + kv_chunk_size; - - // preload k tiles and v tiles - uint32_t producer_kv_idx_base = chunk_start; - constexpr uint32_t vec_bits = sizeof(DTypeKV) * vec_size * 8; + // apply rotary embedding to q matrix + q_vec = vec_apply_llama_rope(q + qo_head_idx * q_stride_h, freq, seq_len - 1); + } else { + // do not apply rotary embedding to q matrix + q_vec.cast_load(q + qo_head_idx * q_stride_h + tx * vec_size); + } + block.sync(); + + uint32_t chunk_start = kv_chunk_idx * kv_chunk_size; + kv_chunk_size = min(kv_chunk_size, seq_len - chunk_start); + uint32_t chunk_end = chunk_start + kv_chunk_size; + + // preload k tiles and v tiles + uint32_t producer_kv_idx_base = chunk_start; + constexpr uint32_t vec_bits = sizeof(DTypeKV) * vec_size * 8; #pragma unroll - for (uint32_t iter = 0; iter < num_stages_smem; ++iter) { - for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { - cp_async::pred_load( - k_smem + - (((iter * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * - head_dim + - tx * vec_size, - k + - (producer_kv_idx_base + - (tz * bdy + ty) * tile_size_per_bdx + j) * - kv_stride_n + - kv_head_idx * kv_stride_h + tx * vec_size, - producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j < - chunk_end); - } - cp_async::commit_group(); - for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { - cp_async::pred_load( - v_smem + - (((iter * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * - head_dim + - tx * vec_size, - v + - (producer_kv_idx_base + - (tz * bdy + ty) * tile_size_per_bdx + j) * - kv_stride_n + - kv_head_idx * kv_stride_h + tx * vec_size, - producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j < - chunk_end); - } - cp_async::commit_group(); - producer_kv_idx_base += bdy * bdz * tile_size_per_bdx; 
+ for (uint32_t iter = 0; iter < num_stages_smem; ++iter) { + for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { + cp_async::pred_load( + k_smem + (((iter * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim + + tx * vec_size, + k + (producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j) * kv_stride_n + + kv_head_idx * kv_stride_h + tx * vec_size, + producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j < chunk_end); } + cp_async::commit_group(); + for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { + cp_async::pred_load( + v_smem + (((iter * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim + + tx * vec_size, + v + (producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j) * kv_stride_n + + kv_head_idx * kv_stride_h + tx * vec_size, + producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j < chunk_end); + } + cp_async::commit_group(); + producer_kv_idx_base += bdy * bdz * tile_size_per_bdx; + } - // pipelining k/v tiles loading and state updating - uint32_t consumer_kv_idx_base = chunk_start, stage_idx = 0; - state_t st_local; - float s[bdy * tile_size_per_bdx]; + // pipelining k/v tiles loading and state updating + uint32_t consumer_kv_idx_base = chunk_start, stage_idx = 0; + state_t st_local; + float s[bdy * tile_size_per_bdx]; #pragma unroll 2 - for (uint32_t iter = 0; - iter < ceil_div(kv_chunk_size, tile_size_per_bdx * bdy * bdz); ++iter) - { - // compute qk - cp_async::wait_group<2 * num_stages_smem - 1>(); - block.sync(); - compute_qk( - params, variant, /*batch_idx=*/0, - k_smem + - (stage_idx * bdz + tz) * bdy * tile_size_per_bdx * head_dim, - q_vec, freq, consumer_kv_idx_base, - iter * bdy * tile_size_per_bdx * bdz, kv_chunk_size, qo_head_idx, - kv_head_idx, s, st_local); - block.sync(); - // load k - for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { - cp_async::pred_load( - k_smem + - (((stage_idx * bdz + tz) * bdy + ty) * tile_size_per_bdx + - j) * - head_dim + - tx * vec_size, - k + - (producer_kv_idx_base 
+ - (tz * bdy + ty) * tile_size_per_bdx + j) * - kv_stride_n + - kv_head_idx * kv_stride_h + tx * vec_size, - producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j < - chunk_end); - } - cp_async::commit_group(); - - // update m/d/o state - cp_async::wait_group<2 * num_stages_smem - 1>(); - block.sync(); - update_local_state( - v_smem + - (stage_idx * bdz + tz) * bdy * tile_size_per_bdx * head_dim, - s, stage_idx, st_local); - block.sync(); - - // load v - for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { - cp_async::pred_load( - v_smem + - (((stage_idx * bdz + tz) * bdy + ty) * tile_size_per_bdx + - j) * - head_dim + - tx * vec_size, - v + - (producer_kv_idx_base + - (tz * bdy + ty) * tile_size_per_bdx + j) * - kv_stride_n + - kv_head_idx * kv_stride_h + tx * vec_size, - producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j < - chunk_end); - } - cp_async::commit_group(); - - stage_idx = (stage_idx + 1) % num_stages_smem; - producer_kv_idx_base += tile_size_per_bdx * bdy * bdz; - consumer_kv_idx_base += tile_size_per_bdx * bdy * bdz; - } - cp_async::wait_group<0>(); + for (uint32_t iter = 0; iter < ceil_div(kv_chunk_size, tile_size_per_bdx * bdy * bdz); ++iter) { + // compute qk + cp_async::wait_group<2 * num_stages_smem - 1>(); block.sync(); - - // sync local state of all warps inside a threadblock - sync_state( - variant, st_local, reinterpret_cast(smem), smem_md); - if constexpr (variant.use_softmax) { - st_local.normalize(); + compute_qk( + params, variant, /*batch_idx=*/0, + k_smem + (stage_idx * bdz + tz) * bdy * tile_size_per_bdx * head_dim, q_vec, freq, + consumer_kv_idx_base, iter * bdy * tile_size_per_bdx * bdz, kv_chunk_size, qo_head_idx, + kv_head_idx, s, st_local); + block.sync(); + // load k + for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { + cp_async::pred_load( + k_smem + (((stage_idx * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim + + tx * vec_size, + k + (producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + 
j) * kv_stride_n + + kv_head_idx * kv_stride_h + tx * vec_size, + producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j < chunk_end); } + cp_async::commit_group(); - st_local.o.cast_store( - o + (kv_chunk_idx * num_qo_heads + qo_head_idx) * head_dim + - tx * vec_size); - if (lse != nullptr) { - lse[kv_chunk_idx * num_qo_heads + qo_head_idx] = st_local.get_lse(); + // update m/d/o state + cp_async::wait_group<2 * num_stages_smem - 1>(); + block.sync(); + update_local_state( + v_smem + (stage_idx * bdz + tz) * bdy * tile_size_per_bdx * head_dim, s, stage_idx, + st_local); + block.sync(); + + // load v + for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { + cp_async::pred_load( + v_smem + (((stage_idx * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim + + tx * vec_size, + v + (producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j) * kv_stride_n + + kv_head_idx * kv_stride_h + tx * vec_size, + producer_kv_idx_base + (tz * bdy + ty) * tile_size_per_bdx + j < chunk_end); } + cp_async::commit_group(); + + stage_idx = (stage_idx + 1) % num_stages_smem; + producer_kv_idx_base += tile_size_per_bdx * bdy * bdz; + consumer_kv_idx_base += tile_size_per_bdx * bdy * bdz; + } + cp_async::wait_group<0>(); + block.sync(); + + // sync local state of all warps inside a threadblock + sync_state(variant, st_local, reinterpret_cast(smem), smem_md); + if constexpr (variant.use_softmax) { + st_local.normalize(); + } + + st_local.o.cast_store(o + (kv_chunk_idx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size); + if (lse != nullptr) { + lse[kv_chunk_idx * num_qo_heads + qo_head_idx] = st_local.get_lse(); + } } /*! 
@@ -476,269 +394,205 @@ __global__ void SingleDecodeWithKVCacheKernel(const Params params) * \param rope_rcp_theta A floating number indicate the reciprocal * of "theta" used in RoPE (Rotary Positional Embeddings) */ -template -__global__ void BatchDecodeWithPagedKVCacheKernel(const Params params) -{ - auto block = cg::this_thread_block(); - using DTypeQ = typename Params::DTypeQ; - using DTypeKV = typename Params::DTypeKV; - using DTypeO = typename Params::DTypeO; - using IdType = typename Params::IdType; - const DTypeQ *q = params.q; - DTypeO *o = params.o; - float *lse = params.lse; - const auto paged_kv = params.paged_kv; - const bool *block_valid_mask = params.block_valid_mask; - const uint32_t padded_batch_size = params.padded_batch_size; - const uint32_t num_qo_heads = params.num_qo_heads; - const bool partition_kv = params.partition_kv; - - constexpr uint32_t head_dim = bdx * vec_size; - const uint32_t bx = blockIdx.x, by = blockIdx.y; - const uint32_t batch_idx = params.request_indices[bx]; - const uint32_t kv_tile_idx = params.kv_tile_indices[bx]; - const uint32_t kv_head_idx = by; - const uint32_t qo_head_idx = kv_head_idx * bdy + threadIdx.y; - // NOTE(Zihao): when CUDAGraph is enabled, we will launch more blocks than - // the actual batch size, so we need to check if the current batch is valid - if (block_valid_mask && !block_valid_mask[bx]) - return; - const uint32_t kv_chunk_size = *(params.kv_chunk_size_ptr); - const uint32_t kv_len = paged_kv.get_length(batch_idx); - const uint32_t max_chunk_size = partition_kv ? kv_chunk_size : kv_len; - const uint32_t chunk_start = - partition_kv ? kv_tile_idx * max_chunk_size : 0; - const uint32_t chunk_end = - partition_kv ? 
min((kv_tile_idx + 1) * max_chunk_size, kv_len) : kv_len; - const uint32_t chunk_size = chunk_end - chunk_start; - - extern __shared__ uint8_t smem[]; - AttentionVariant variant(params, batch_idx, smem); - DTypeKV *k_smem = (DTypeKV *)smem; - DTypeKV *v_smem = - (DTypeKV *)(smem + num_stages_smem * tile_size_per_bdx * bdy * bdz * - head_dim * sizeof(DTypeKV)); - size_t *kv_offset_smem = - (size_t *)(smem + 2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * - head_dim * sizeof(DTypeKV)); - float *smem_md = - (float *)(smem + 2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * - head_dim * sizeof(DTypeKV)); - - const uint32_t tx = threadIdx.x, ty = threadIdx.y, tz = threadIdx.z; - vec_t q_vec; - vec_t freq; - const uint32_t q_stride_n = params.q_stride_n; - const uint32_t q_stride_h = params.q_stride_h; - if constexpr (POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { - const IdType *q_rope_offset = nullptr; - if constexpr (has_decode_maybe_q_rope_offset_v) { - q_rope_offset = params.decode_maybe_q_rope_offset; - } - int32_t q_rope_offset_val = - q_rope_offset == nullptr ? 
(kv_len - 1) : q_rope_offset[batch_idx]; - const float rope_rcp_scale = params.rope_rcp_scale; - const float rope_rcp_theta = params.rope_rcp_theta; -#pragma unroll - for (uint32_t i = 0; i < vec_size; ++i) { - freq[i] = rope_rcp_scale * - __powf(rope_rcp_theta, - float(2 * ((tx * vec_size + i) % (head_dim / 2))) / - float(head_dim)); - } - // apply rotary embedding to q matrix - q_vec = vec_apply_llama_rope( - q + batch_idx * q_stride_n + qo_head_idx * q_stride_h, freq, - q_rope_offset_val); - } - else { - // do not apply rotary embedding to q matrix - q_vec.cast_load(q + batch_idx * q_stride_n + qo_head_idx * q_stride_h + - tx * vec_size); +__global__ void BatchDecodeWithPagedKVCacheKernel(const Params params) { + auto block = cg::this_thread_block(); + using DTypeQ = typename Params::DTypeQ; + using DTypeKV = typename Params::DTypeKV; + using DTypeO = typename Params::DTypeO; + using IdType = typename Params::IdType; + const DTypeQ* q = params.q; + DTypeO* o = params.o; + float* lse = params.lse; + const auto paged_kv = params.paged_kv; + const bool* block_valid_mask = params.block_valid_mask; + const uint32_t padded_batch_size = params.padded_batch_size; + const uint32_t num_qo_heads = params.num_qo_heads; + const bool partition_kv = params.partition_kv; + + constexpr uint32_t head_dim = bdx * vec_size; + const uint32_t bx = blockIdx.x, by = blockIdx.y; + const uint32_t batch_idx = params.request_indices[bx]; + const uint32_t kv_tile_idx = params.kv_tile_indices[bx]; + const uint32_t kv_head_idx = by; + const uint32_t qo_head_idx = kv_head_idx * bdy + threadIdx.y; + // NOTE(Zihao): when CUDAGraph is enabled, we will launch more blocks than + // the actual batch size, so we need to check if the current batch is valid + if (block_valid_mask && !block_valid_mask[bx]) return; + const uint32_t kv_chunk_size = *(params.kv_chunk_size_ptr); + const uint32_t kv_len = paged_kv.get_length(batch_idx); + const uint32_t max_chunk_size = partition_kv ? 
kv_chunk_size : kv_len; + const uint32_t chunk_start = partition_kv ? kv_tile_idx * max_chunk_size : 0; + const uint32_t chunk_end = + partition_kv ? min((kv_tile_idx + 1) * max_chunk_size, kv_len) : kv_len; + const uint32_t chunk_size = chunk_end - chunk_start; + + extern __shared__ uint8_t smem[]; + AttentionVariant variant(params, batch_idx, smem); + DTypeKV* k_smem = (DTypeKV*)smem; + DTypeKV* v_smem = (DTypeKV*)(smem + num_stages_smem * tile_size_per_bdx * bdy * bdz * head_dim * + sizeof(DTypeKV)); + size_t* kv_offset_smem = (size_t*)(smem + 2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * + head_dim * sizeof(DTypeKV)); + float* smem_md = (float*)(smem + 2 * num_stages_smem * tile_size_per_bdx * bdy * bdz * head_dim * + sizeof(DTypeKV)); + + const uint32_t tx = threadIdx.x, ty = threadIdx.y, tz = threadIdx.z; + vec_t q_vec; + vec_t freq; + const uint32_t q_stride_n = params.q_stride_n; + const uint32_t q_stride_h = params.q_stride_h; + if constexpr (POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { + const IdType* q_rope_offset = nullptr; + if constexpr (has_decode_maybe_q_rope_offset_v) { + q_rope_offset = params.decode_maybe_q_rope_offset; } - - // preload k/v tiles - uint32_t stage_idx = 0; - constexpr uint32_t vec_bits = sizeof(DTypeKV) * vec_size * 8; - const IdType last_indptr = paged_kv.indptr[paged_kv.batch_size]; - - static_assert(num_stages_smem <= bdx); - uint32_t packed_page_iter_base = - paged_kv.indptr[batch_idx] * paged_kv.page_size + chunk_start; + int32_t q_rope_offset_val = q_rope_offset == nullptr ? 
(kv_len - 1) : q_rope_offset[batch_idx]; + const float rope_rcp_scale = params.rope_rcp_scale; + const float rope_rcp_theta = params.rope_rcp_theta; #pragma unroll - for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { - uint32_t q, r; - paged_kv.page_size.divmod(packed_page_iter_base + - ((j * bdz + tz) * bdy + ty) * bdx + tx, - q, r); - kv_offset_smem[((j * bdz + tz) * bdy + ty) * bdx + tx] = - paged_kv.protective_get_kv_offset(q, kv_head_idx, r, 0, - last_indptr); + for (uint32_t i = 0; i < vec_size; ++i) { + freq[i] = rope_rcp_scale * + __powf(rope_rcp_theta, + float(2 * ((tx * vec_size + i) % (head_dim / 2))) / float(head_dim)); } - block.sync(); + // apply rotary embedding to q matrix + q_vec = vec_apply_llama_rope( + q + batch_idx * q_stride_n + qo_head_idx * q_stride_h, freq, q_rope_offset_val); + } else { + // do not apply rotary embedding to q matrix + q_vec.cast_load(q + batch_idx * q_stride_n + qo_head_idx * q_stride_h + tx * vec_size); + } + + // preload k/v tiles + uint32_t stage_idx = 0; + constexpr uint32_t vec_bits = sizeof(DTypeKV) * vec_size * 8; + const IdType last_indptr = paged_kv.indptr[paged_kv.batch_size]; + + static_assert(num_stages_smem <= bdx); + uint32_t packed_page_iter_base = paged_kv.indptr[batch_idx] * paged_kv.page_size + chunk_start; +#pragma unroll + for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { + uint32_t q, r; + paged_kv.page_size.divmod(packed_page_iter_base + ((j * bdz + tz) * bdy + ty) * bdx + tx, q, r); + kv_offset_smem[((j * bdz + tz) * bdy + ty) * bdx + tx] = + paged_kv.protective_get_kv_offset(q, kv_head_idx, r, 0, last_indptr); + } + block.sync(); - size_t kv_offset[tile_size_per_bdx]; + size_t kv_offset[tile_size_per_bdx]; #pragma unroll - for (uint32_t iter = 0; iter < num_stages_smem; ++iter) { + for (uint32_t iter = 0; iter < num_stages_smem; ++iter) { #pragma unroll - for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { - kv_offset[j] = kv_offset_smem[((iter * bdz + tz) * bdy + ty) * - tile_size_per_bdx + - j] + 
- tx * vec_size; - } + for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { + kv_offset[j] = + kv_offset_smem[((iter * bdz + tz) * bdy + ty) * tile_size_per_bdx + j] + tx * vec_size; + } #pragma unroll - for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { - cp_async::pred_load( - k_smem + - (((stage_idx * bdz + tz) * bdy + ty) * tile_size_per_bdx + - j) * - head_dim + - tx * vec_size, - paged_kv.k_data + kv_offset[j], - ((iter * bdz + tz) * bdy + ty) * tile_size_per_bdx + j < - chunk_size); - } - cp_async::commit_group(); + for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { + cp_async::pred_load( + k_smem + (((stage_idx * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim + + tx * vec_size, + paged_kv.k_data + kv_offset[j], + ((iter * bdz + tz) * bdy + ty) * tile_size_per_bdx + j < chunk_size); + } + cp_async::commit_group(); #pragma unroll - for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { - cp_async::pred_load( - v_smem + - (((stage_idx * bdz + tz) * bdy + ty) * tile_size_per_bdx + - j) * - head_dim + - tx * vec_size, - paged_kv.v_data + kv_offset[j], - ((iter * bdz + tz) * bdy + ty) * tile_size_per_bdx + j < - chunk_size); - } - cp_async::commit_group(); - stage_idx = (stage_idx + 1) % num_stages_smem; + for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { + cp_async::pred_load( + v_smem + (((stage_idx * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim + + tx * vec_size, + paged_kv.v_data + kv_offset[j], + ((iter * bdz + tz) * bdy + ty) * tile_size_per_bdx + j < chunk_size); } + cp_async::commit_group(); + stage_idx = (stage_idx + 1) % num_stages_smem; + } - state_t st; - float s[bdy * tile_size_per_bdx]; + state_t st; + float s[bdy * tile_size_per_bdx]; #pragma unroll 2 - for (uint32_t iter = 0; - iter < ceil_div(chunk_size, tile_size_per_bdx * bdy * bdz); ++iter) - { - if ((iter + num_stages_smem) % bdx == 0) { + for (uint32_t iter = 0; iter < ceil_div(chunk_size, tile_size_per_bdx * bdy * bdz); ++iter) { + if ((iter + num_stages_smem) % bdx == 0) 
{ #pragma unroll - for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { - uint32_t q, r; - paged_kv.page_size.divmod( - packed_page_iter_base + - ((iter + num_stages_smem) * tile_size_per_bdx * bdy * - bdz + - ((j * bdz + tz) * bdy + ty) * bdx + tx), - q, r); - kv_offset_smem[((j * bdz + tz) * bdy + ty) * bdx + tx] = - paged_kv.protective_get_kv_offset(q, kv_head_idx, r, 0, - last_indptr); - } - } - // compute qk - cp_async::wait_group<2 * num_stages_smem - 1>(); - block.sync(); - compute_qk( - params, variant, batch_idx, - k_smem + - (stage_idx * bdz + tz) * bdy * tile_size_per_bdx * head_dim, - q_vec, freq, - (paged_kv.rope_pos_offset == nullptr - ? 0 - : paged_kv.rope_pos_offset[batch_idx]) + - chunk_start + iter * tile_size_per_bdx * bdy * bdz, - iter * tile_size_per_bdx * bdy * bdz, chunk_size, qo_head_idx, - kv_head_idx, s, st); - block.sync(); + for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { + uint32_t q, r; + paged_kv.page_size.divmod( + packed_page_iter_base + ((iter + num_stages_smem) * tile_size_per_bdx * bdy * bdz + + ((j * bdz + tz) * bdy + ty) * bdx + tx), + q, r); + kv_offset_smem[((j * bdz + tz) * bdy + ty) * bdx + tx] = + paged_kv.protective_get_kv_offset(q, kv_head_idx, r, 0, last_indptr); + } + } + // compute qk + cp_async::wait_group<2 * num_stages_smem - 1>(); + block.sync(); + compute_qk( + params, variant, batch_idx, + k_smem + (stage_idx * bdz + tz) * bdy * tile_size_per_bdx * head_dim, q_vec, freq, + (paged_kv.rope_pos_offset == nullptr ? 
0 : paged_kv.rope_pos_offset[batch_idx]) + + chunk_start + iter * tile_size_per_bdx * bdy * bdz, + iter * tile_size_per_bdx * bdy * bdz, chunk_size, qo_head_idx, kv_head_idx, s, st); + block.sync(); #pragma unroll - for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { - kv_offset[j] = - kv_offset_smem[((((iter + num_stages_smem) % bdx) * bdz + tz) * - bdy + - ty) * - tile_size_per_bdx + - j] + - tx * vec_size; - } + for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { + kv_offset[j] = kv_offset_smem[((((iter + num_stages_smem) % bdx) * bdz + tz) * bdy + ty) * + tile_size_per_bdx + + j] + + tx * vec_size; + } - // load k tiles -#pragma unroll - for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { - cp_async::pred_load( - k_smem + - (((stage_idx * bdz + tz) * bdy + ty) * tile_size_per_bdx + - j) * - head_dim + - tx * vec_size, - paged_kv.k_data + kv_offset[j], - (((iter + num_stages_smem) * bdz + tz) * bdy + ty) * - tile_size_per_bdx + - j < - chunk_size); - } - cp_async::commit_group(); - - // update m/d/o states - cp_async::wait_group<2 * num_stages_smem - 1>(); - block.sync(); - update_local_state( - v_smem + - (stage_idx * bdz + tz) * bdy * tile_size_per_bdx * head_dim, - s, stage_idx, st); - block.sync(); - - // load v tiles + // load k tiles #pragma unroll - for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { - cp_async::pred_load( - v_smem + - (((stage_idx * bdz + tz) * bdy + ty) * tile_size_per_bdx + - j) * - head_dim + - tx * vec_size, - paged_kv.v_data + kv_offset[j], - (((iter + num_stages_smem) * bdz + tz) * bdy + ty) * - tile_size_per_bdx + - j < - chunk_size); - } - cp_async::commit_group(); - stage_idx = (stage_idx + 1) % num_stages_smem; + for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { + cp_async::pred_load( + k_smem + (((stage_idx * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim + + tx * vec_size, + paged_kv.k_data + kv_offset[j], + (((iter + num_stages_smem) * bdz + tz) * bdy + ty) * tile_size_per_bdx + j < chunk_size); } - 
cp_async::wait_group<0>(); + cp_async::commit_group(); + + // update m/d/o states + cp_async::wait_group<2 * num_stages_smem - 1>(); + block.sync(); + update_local_state( + v_smem + (stage_idx * bdz + tz) * bdy * tile_size_per_bdx * head_dim, s, stage_idx, st); block.sync(); - // sync local state of all warps inside a threadblock - sync_state( - variant, st, reinterpret_cast(smem), smem_md); - if constexpr (variant.use_softmax) { - st.normalize(); + // load v tiles +#pragma unroll + for (uint32_t j = 0; j < tile_size_per_bdx; ++j) { + cp_async::pred_load( + v_smem + (((stage_idx * bdz + tz) * bdy + ty) * tile_size_per_bdx + j) * head_dim + + tx * vec_size, + paged_kv.v_data + kv_offset[j], + (((iter + num_stages_smem) * bdz + tz) * bdy + ty) * tile_size_per_bdx + j < chunk_size); } - - if (tz == 0) { - st.o.cast_store(o + (bx * num_qo_heads + qo_head_idx) * head_dim + - tx * vec_size); - // write lse - if (lse != nullptr) { - lse[bx * num_qo_heads + qo_head_idx] = st.get_lse(); - } + cp_async::commit_group(); + stage_idx = (stage_idx + 1) % num_stages_smem; + } + cp_async::wait_group<0>(); + block.sync(); + + // sync local state of all warps inside a threadblock + sync_state(variant, st, reinterpret_cast(smem), smem_md); + if constexpr (variant.use_softmax) { + st.normalize(); + } + + if (tz == 0) { + st.o.cast_store(o + (bx * num_qo_heads + qo_head_idx) * head_dim + tx * vec_size); + // write lse + if (lse != nullptr) { + lse[bx * num_qo_heads + qo_head_idx] = st.get_lse(); } + } } /*! @@ -747,20 +601,16 @@ __global__ void BatchDecodeWithPagedKVCacheKernel(const Params params) * GQA. 
* \param sizeof_dtype The size (in terms of bytes) of the input data type */ -constexpr uint32_t get_heuristic_num_threads(uint32_t group_size, - uint32_t sizeof_dtype) -{ - if (group_size == 8U) { - if (sizeof_dtype == 1U) { - return 256U; // not enough registers for 512 threads - } - else { - return 512U; - } - } - else { - return 128U; +constexpr uint32_t get_heuristic_num_threads(uint32_t group_size, uint32_t sizeof_dtype) { + if (group_size == 8U) { + if (sizeof_dtype == 1U) { + return 256U; // not enough registers for 512 threads + } else { + return 512U; } + } else { + return 128U; + } } /*! @@ -786,609 +636,484 @@ constexpr uint32_t get_heuristic_num_threads(uint32_t group_size, * \param stream The cuda stream to launch the kernel * \return status Indicates whether CUDA calls are successful */ -template -hipError_t SingleDecodeWithKVCacheDispatched(Params params, - typename Params::DTypeO *tmp, - hipStream_t stream) -{ - using DTypeQ = typename Params::DTypeQ; - using DTypeKV = typename Params::DTypeKV; - using DTypeO = typename Params::DTypeO; - const uint32_t num_qo_heads = params.num_qo_heads; - const uint32_t num_kv_heads = params.num_kv_heads; - const uint32_t seq_len = params.kv_len; - constexpr uint32_t vec_size = - std::max(16UL / sizeof(DTypeKV), HEAD_DIM / 32UL); - constexpr uint32_t bdx = HEAD_DIM / vec_size; - auto compute_capacity = GetCudaComputeCapability(); - static_assert(bdx <= 32U); - DISPATCH_GQA_GROUP_SIZE(num_qo_heads / num_kv_heads, GROUP_SIZE, { - constexpr uint32_t bdy = GROUP_SIZE; - constexpr uint32_t num_threads = std::max( - get_heuristic_num_threads(GROUP_SIZE, sizeof(DTypeKV)), bdx * bdy); - constexpr uint32_t bdz = num_threads / (bdx * bdy); - constexpr uint32_t tile_size_per_bdx = - GROUP_SIZE == 1 ? (sizeof(DTypeKV) == 1 ? 
2U : 8U) : 1U; - DISPATCH_COMPUTE_CAP_DECODE_NUM_STAGES_SMEM( - compute_capacity, NUM_STAGES_SMEM, { - const uint32_t smem_size = 2U * NUM_STAGES_SMEM * bdy * - tile_size_per_bdx * bdz * - HEAD_DIM * sizeof(DTypeKV) + - 2U * bdy * bdz * sizeof(float); - auto kernel = SingleDecodeWithKVCacheKernel< - POS_ENCODING_MODE, NUM_STAGES_SMEM, tile_size_per_bdx, - vec_size, bdx, bdy, bdz, AttentionVariant, Params>; - FLASHINFER_CUDA_CALL(hipFuncSetAttribute( - (void *)kernel, hipFuncAttributeMaxDynamicSharedMemorySize, - smem_size)); - if (seq_len <= 256 || tmp == nullptr) { - // no need to use partition-kv kernel - dim3 nblks = dim3(1, num_kv_heads); - dim3 nthrs = dim3(bdx, bdy, bdz); - params.kv_chunk_size = seq_len; - void *args[] = {(void *)¶ms}; - SingleDecodeWithKVCacheKernel< - POS_ENCODING_MODE, NUM_STAGES_SMEM, tile_size_per_bdx, - vec_size, bdx, bdy, bdz, AttentionVariant, Params> - <<>>(params); - } - else { - // use partition-kv kernel - int num_blocks_per_sm = 0; - int num_sm = 0; - int dev_id = 0; - FLASHINFER_CUDA_CALL(hipGetDevice(&dev_id)); - FLASHINFER_CUDA_CALL(hipDeviceGetAttribute( - &num_sm, hipDeviceAttributeMultiprocessorCount, - dev_id)); - FLASHINFER_CUDA_CALL( - hipOccupancyMaxActiveBlocksPerMultiprocessor( - &num_blocks_per_sm, kernel, num_threads, - smem_size)); - // FIXME: The hipOccupancyMaxActiveBlocksPerMultiprocessor - // function does not return accurate results always causing - // a potential of a division by zero error for certain - // configurations. For now, instead of using the HIP - // function to derive num_blocks_per_sm we are hard coding - // it to eight based on limited tuning runs. A better - // heuristics to derive the maximum number of blocks per - // SM/CU is needed to not have to hardcode the value. 
- uint32_t max_grid_size = uint32_t(8) * uint32_t(num_sm); - uint32_t max_num_kv_chunks = max_grid_size / num_kv_heads; - uint32_t kv_chunk_size = - max(ceil_div(seq_len, max_num_kv_chunks), 256); - uint32_t num_chunks = ceil_div(seq_len, kv_chunk_size); - dim3 nblks = dim3(num_chunks, num_kv_heads); - if (nblks.x == 0 || nblks.y == 0) { - std::ostringstream err_msg; - err_msg << "Invalid kernel configuration: nblks=(" - << nblks.x << "," << nblks.y << ")"; - FLASHINFER_ERROR(err_msg.str()); - } - dim3 nthrs = dim3(bdx, bdy, bdz); - float *tmp_lse = - (float *)(tmp + num_chunks * num_qo_heads * HEAD_DIM); - auto o = params.o; - params.o = tmp; - params.lse = tmp_lse; - params.kv_chunk_size = kv_chunk_size; - - SingleDecodeWithKVCacheKernel< - POS_ENCODING_MODE, NUM_STAGES_SMEM, tile_size_per_bdx, - vec_size, bdx, bdy, bdz, AttentionVariant, Params> - <<>>(params); - - if constexpr (AttentionVariant::use_softmax) { - CHECK_HIP_ERROR(MergeStates(tmp, tmp_lse, o, nullptr, - num_chunks, 1, num_qo_heads, - HEAD_DIM, stream)); - } - else { - CHECK_HIP_ERROR(AttentionSum(tmp, o, num_chunks, 1, - num_qo_heads, HEAD_DIM, - stream)); - } - } - }); +hipError_t SingleDecodeWithKVCacheDispatched(Params params, typename Params::DTypeO* tmp, + hipStream_t stream) { + using DTypeQ = typename Params::DTypeQ; + using DTypeKV = typename Params::DTypeKV; + using DTypeO = typename Params::DTypeO; + const uint32_t num_qo_heads = params.num_qo_heads; + const uint32_t num_kv_heads = params.num_kv_heads; + const uint32_t seq_len = params.kv_len; + constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeKV), HEAD_DIM / 32UL); + constexpr uint32_t bdx = HEAD_DIM / vec_size; + auto compute_capacity = GetCudaComputeCapability(); + static_assert(bdx <= 32U); + DISPATCH_GQA_GROUP_SIZE(num_qo_heads / num_kv_heads, GROUP_SIZE, { + constexpr uint32_t bdy = GROUP_SIZE; + constexpr uint32_t num_threads = + std::max(get_heuristic_num_threads(GROUP_SIZE, sizeof(DTypeKV)), bdx * bdy); + constexpr 
uint32_t bdz = num_threads / (bdx * bdy); + constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeKV) == 1 ? 2U : 8U) : 1U; + DISPATCH_COMPUTE_CAP_DECODE_NUM_STAGES_SMEM(compute_capacity, NUM_STAGES_SMEM, { + const uint32_t smem_size = + 2U * NUM_STAGES_SMEM * bdy * tile_size_per_bdx * bdz * HEAD_DIM * sizeof(DTypeKV) + + 2U * bdy * bdz * sizeof(float); + auto kernel = + SingleDecodeWithKVCacheKernel; + FLASHINFER_CUDA_CALL(hipFuncSetAttribute( + (void*)kernel, hipFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + if (seq_len <= 256 || tmp == nullptr) { + // no need to use partition-kv kernel + dim3 nblks = dim3(1, num_kv_heads); + dim3 nthrs = dim3(bdx, bdy, bdz); + params.kv_chunk_size = seq_len; + void* args[] = {(void*)¶ms}; + SingleDecodeWithKVCacheKernel + <<>>(params); + } else { + // use partition-kv kernel + int num_blocks_per_sm = 0; + int num_sm = 0; + int dev_id = 0; + FLASHINFER_CUDA_CALL(hipGetDevice(&dev_id)); + FLASHINFER_CUDA_CALL( + hipDeviceGetAttribute(&num_sm, hipDeviceAttributeMultiprocessorCount, dev_id)); + FLASHINFER_CUDA_CALL(hipOccupancyMaxActiveBlocksPerMultiprocessor( + &num_blocks_per_sm, kernel, num_threads, smem_size)); + // FIXME: The hipOccupancyMaxActiveBlocksPerMultiprocessor + // function does not return accurate results always causing + // a potential of a division by zero error for certain + // configurations. For now, instead of using the HIP + // function to derive num_blocks_per_sm we are hard coding + // it to eight based on limited tuning runs. A better + // heuristics to derive the maximum number of blocks per + // SM/CU is needed to not have to hardcode the value. 
+ uint32_t max_grid_size = uint32_t(8) * uint32_t(num_sm); + uint32_t max_num_kv_chunks = max_grid_size / num_kv_heads; + uint32_t kv_chunk_size = max(ceil_div(seq_len, max_num_kv_chunks), 256); + uint32_t num_chunks = ceil_div(seq_len, kv_chunk_size); + dim3 nblks = dim3(num_chunks, num_kv_heads); + if (nblks.x == 0 || nblks.y == 0) { + std::ostringstream err_msg; + err_msg << "Invalid kernel configuration: nblks=(" << nblks.x << "," << nblks.y << ")"; + FLASHINFER_ERROR(err_msg.str()); + } + dim3 nthrs = dim3(bdx, bdy, bdz); + float* tmp_lse = (float*)(tmp + num_chunks * num_qo_heads * HEAD_DIM); + auto o = params.o; + params.o = tmp; + params.lse = tmp_lse; + params.kv_chunk_size = kv_chunk_size; + + SingleDecodeWithKVCacheKernel + <<>>(params); + + if constexpr (AttentionVariant::use_softmax) { + CHECK_HIP_ERROR( + MergeStates(tmp, tmp_lse, o, nullptr, num_chunks, 1, num_qo_heads, HEAD_DIM, stream)); + } else { + CHECK_HIP_ERROR(AttentionSum(tmp, o, num_chunks, 1, num_qo_heads, HEAD_DIM, stream)); + } + } }); - return hipSuccess; + }); + return hipSuccess; } -template -hipError_t BatchDecodeWithPagedKVCacheDispatched(Params params, - typename Params::DTypeO *tmp_v, - float *tmp_s, - hipStream_t stream) -{ - using DTypeQ = typename Params::DTypeQ; - using DTypeKV = typename Params::DTypeKV; - using DTypeO = typename Params::DTypeO; - using IdType = typename Params::IdType; - const uint32_t num_qo_heads = params.num_qo_heads; - const uint32_t num_kv_heads = params.paged_kv.num_heads; - const uint32_t padded_batch_size = params.padded_batch_size; - - constexpr uint32_t vec_size = - std::max(16UL / sizeof(DTypeKV), HEAD_DIM / 32UL); - auto compute_capacity = GetCudaComputeCapability(); - constexpr uint32_t bdx = HEAD_DIM / vec_size; - static_assert(bdx <= 32); - DISPATCH_GQA_GROUP_SIZE(num_qo_heads / num_kv_heads, GROUP_SIZE, { - constexpr uint32_t bdy = GROUP_SIZE; - constexpr uint32_t num_threads = std::max(128U, bdx * bdy); - constexpr uint32_t bdz = num_threads 
/ (bdx * bdy); - constexpr uint32_t tile_size_per_bdx = - GROUP_SIZE == 1 ? (sizeof(DTypeKV) == 1 ? 2U : 4U) : 1U; - DISPATCH_COMPUTE_CAP_DECODE_NUM_STAGES_SMEM( - compute_capacity, NUM_STAGES_SMEM, { - const uint32_t smem_size = - 2 * NUM_STAGES_SMEM * tile_size_per_bdx * bdy * bdz * - HEAD_DIM * sizeof(DTypeKV) + - std::max(tile_size_per_bdx * num_threads * - sizeof(DTypeKV *), - 2 * bdy * bdz * sizeof(float)); - auto kernel = BatchDecodeWithPagedKVCacheKernel< - POS_ENCODING_MODE, NUM_STAGES_SMEM, tile_size_per_bdx, - vec_size, bdx, bdy, bdz, AttentionVariant, Params>; - FLASHINFER_CUDA_CALL(hipFuncSetAttribute( - (void *)kernel, hipFuncAttributeMaxDynamicSharedMemorySize, - smem_size)); - - if (tmp_v == nullptr) { - // do not use partition-kv kernel - dim3 nblks(padded_batch_size, num_kv_heads); - dim3 nthrs(bdx, bdy, bdz); - params.partition_kv = false; - BatchDecodeWithPagedKVCacheKernel< - POS_ENCODING_MODE, NUM_STAGES_SMEM, tile_size_per_bdx, - vec_size, bdx, bdy, bdz, AttentionVariant, Params> - <<>>(params); - } - else { - // use partition-kv kernel - params.partition_kv = true; - auto o = params.o; - auto lse = params.lse; - params.o = tmp_v; - params.lse = tmp_s; - void *args[] = {(void *)¶ms}; - dim3 nblks(padded_batch_size, num_kv_heads); - dim3 nthrs(bdx, bdy, bdz); - BatchDecodeWithPagedKVCacheKernel< - POS_ENCODING_MODE, NUM_STAGES_SMEM, tile_size_per_bdx, - vec_size, bdx, bdy, bdz, AttentionVariant, Params> - <<>>(params); - if constexpr (AttentionVariant::use_softmax) { - CHECK_HIP_ERROR(VariableLengthMergeStates( - tmp_v, tmp_s, params.o_indptr, o, lse, - params.paged_kv.batch_size, nullptr, num_qo_heads, - HEAD_DIM, stream)); - } - else { - CHECK_HIP_ERROR(VariableLengthAttentionSum( - tmp_v, params.o_indptr, o, - params.paged_kv.batch_size, nullptr, num_qo_heads, - HEAD_DIM, stream)); - } - } - }); +hipError_t BatchDecodeWithPagedKVCacheDispatched(Params params, typename Params::DTypeO* tmp_v, + float* tmp_s, hipStream_t stream) { + using 
DTypeQ = typename Params::DTypeQ; + using DTypeKV = typename Params::DTypeKV; + using DTypeO = typename Params::DTypeO; + using IdType = typename Params::IdType; + const uint32_t num_qo_heads = params.num_qo_heads; + const uint32_t num_kv_heads = params.paged_kv.num_heads; + const uint32_t padded_batch_size = params.padded_batch_size; + + constexpr uint32_t vec_size = std::max(16UL / sizeof(DTypeKV), HEAD_DIM / 32UL); + auto compute_capacity = GetCudaComputeCapability(); + constexpr uint32_t bdx = HEAD_DIM / vec_size; + static_assert(bdx <= 32); + DISPATCH_GQA_GROUP_SIZE(num_qo_heads / num_kv_heads, GROUP_SIZE, { + constexpr uint32_t bdy = GROUP_SIZE; + constexpr uint32_t num_threads = std::max(128U, bdx * bdy); + constexpr uint32_t bdz = num_threads / (bdx * bdy); + constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeKV) == 1 ? 2U : 4U) : 1U; + DISPATCH_COMPUTE_CAP_DECODE_NUM_STAGES_SMEM(compute_capacity, NUM_STAGES_SMEM, { + const uint32_t smem_size = + 2 * NUM_STAGES_SMEM * tile_size_per_bdx * bdy * bdz * HEAD_DIM * sizeof(DTypeKV) + + std::max(tile_size_per_bdx * num_threads * sizeof(DTypeKV*), + 2 * bdy * bdz * sizeof(float)); + auto kernel = + BatchDecodeWithPagedKVCacheKernel; + FLASHINFER_CUDA_CALL(hipFuncSetAttribute( + (void*)kernel, hipFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + if (tmp_v == nullptr) { + // do not use partition-kv kernel + dim3 nblks(padded_batch_size, num_kv_heads); + dim3 nthrs(bdx, bdy, bdz); + params.partition_kv = false; + BatchDecodeWithPagedKVCacheKernel + <<>>(params); + } else { + // use partition-kv kernel + params.partition_kv = true; + auto o = params.o; + auto lse = params.lse; + params.o = tmp_v; + params.lse = tmp_s; + void* args[] = {(void*)¶ms}; + dim3 nblks(padded_batch_size, num_kv_heads); + dim3 nthrs(bdx, bdy, bdz); + BatchDecodeWithPagedKVCacheKernel + <<>>(params); + if constexpr (AttentionVariant::use_softmax) { + CHECK_HIP_ERROR(VariableLengthMergeStates(tmp_v, tmp_s, 
params.o_indptr, o, lse, + params.paged_kv.batch_size, nullptr, + num_qo_heads, HEAD_DIM, stream)); + } else { + CHECK_HIP_ERROR(VariableLengthAttentionSum(tmp_v, params.o_indptr, o, + params.paged_kv.batch_size, nullptr, + num_qo_heads, HEAD_DIM, stream)); + } + } }); - return hipSuccess; + }); + return hipSuccess; } -template +template __device__ __forceinline__ void compute_qk_and_update_local_stat_mla( - const Params ¶ms, - AttentionVariant variant, - const uint32_t batch_idx, - const T *ckv_smem, - const vec_t &q_nope_vec, - const T *kpe_smem, - const vec_t &q_pe_vec, - const vec_t &freq, - uint32_t kv_idx_base, - uint32_t iter_base, - uint32_t iter_bound, - state_t &st) -{ - uint32_t tx = threadIdx.x, tz = threadIdx.z; - constexpr uint32_t head_dim_ckv = bdx * vec_size_ckv; - constexpr uint32_t head_dim_kpe = bdx * vec_size_kpe; - float s[tile_size]; - float m_prev = st.m; + const Params& params, AttentionVariant variant, const uint32_t batch_idx, const T* ckv_smem, + const vec_t& q_nope_vec, const T* kpe_smem, + const vec_t& q_pe_vec, const vec_t& freq, + uint32_t kv_idx_base, uint32_t iter_base, uint32_t iter_bound, state_t& st) { + uint32_t tx = threadIdx.x, tz = threadIdx.z; + constexpr uint32_t head_dim_ckv = bdx * vec_size_ckv; + constexpr uint32_t head_dim_kpe = bdx * vec_size_kpe; + float s[tile_size]; + float m_prev = st.m; #pragma unroll - for (uint32_t j = 0; j < tile_size; ++j) { - vec_t ckv_vec; - ckv_vec.cast_load(ckv_smem + j * head_dim_ckv + tx * vec_size_ckv); + for (uint32_t j = 0; j < tile_size; ++j) { + vec_t ckv_vec; + ckv_vec.cast_load(ckv_smem + j * head_dim_ckv + tx * vec_size_ckv); - vec_t kpe_vec; - kpe_vec.cast_load(kpe_smem + j * head_dim_kpe + tx * vec_size_kpe); + vec_t kpe_vec; + kpe_vec.cast_load(kpe_smem + j * head_dim_kpe + tx * vec_size_kpe); - s[j] = 0.f; + s[j] = 0.f; #pragma unroll - for (uint32_t i = 0; i < vec_size_ckv; ++i) { - s[j] += q_nope_vec[i] * ckv_vec[i]; - } + for (uint32_t i = 0; i < vec_size_ckv; ++i) { + 
s[j] += q_nope_vec[i] * ckv_vec[i]; + } #pragma unroll - for (uint32_t i = 0; i < vec_size_kpe; ++i) { - s[j] += q_pe_vec[i] * kpe_vec[i]; - } - s[j] *= params.sm_scale; + for (uint32_t i = 0; i < vec_size_kpe; ++i) { + s[j] += q_pe_vec[i] * kpe_vec[i]; + } + s[j] *= params.sm_scale; #pragma unroll - for (uint32_t offset = bdx / 2; offset > 0; offset /= 2) { - s[j] += math::shfl_xor_sync(s[j], offset); - } - s[j] = - (iter_base + tz * tile_size + j < iter_bound) ? s[j] : -math::inf; - st.m = max(st.m, s[j]); + for (uint32_t offset = bdx / 2; offset > 0; offset /= 2) { + s[j] += math::shfl_xor_sync(s[j], offset); } + s[j] = (iter_base + tz * tile_size + j < iter_bound) ? s[j] : -math::inf; + st.m = max(st.m, s[j]); + } - float o_scale = math::ptx_exp2(m_prev - st.m); - st.d *= o_scale; + float o_scale = math::ptx_exp2(m_prev - st.m); + st.d *= o_scale; #pragma unroll - for (uint32_t j = 0; j < tile_size; ++j) { - s[j] = math::ptx_exp2(s[j] - st.m); - st.d += s[j]; - } + for (uint32_t j = 0; j < tile_size; ++j) { + s[j] = math::ptx_exp2(s[j] - st.m); + st.d += s[j]; + } #pragma unroll - for (uint32_t i = 0; i < vec_size_ckv; ++i) { - st.o[i] = st.o[i] * o_scale; - } + for (uint32_t i = 0; i < vec_size_ckv; ++i) { + st.o[i] = st.o[i] * o_scale; + } #pragma unroll - for (uint32_t j = 0; j < tile_size; ++j) { - vec_t v_vec; - v_vec.cast_load(ckv_smem + j * head_dim_ckv + tx * vec_size_ckv); + for (uint32_t j = 0; j < tile_size; ++j) { + vec_t v_vec; + v_vec.cast_load(ckv_smem + j * head_dim_ckv + tx * vec_size_ckv); #pragma unroll - for (uint32_t i = 0; i < vec_size_ckv; ++i) { - st.o[i] = st.o[i] + s[j] * v_vec[i]; - } + for (uint32_t i = 0; i < vec_size_ckv; ++i) { + st.o[i] = st.o[i] + s[j] * v_vec[i]; } + } } -template -__global__ void BatchDecodeWithPagedKVCacheKernelMLA(Params params) -{ - auto block = cg::this_thread_block(); - using DTypeQ = typename Params::DTypeQ; - using DTypeKV = typename Params::DTypeKV; - using DTypeO = typename Params::DTypeO; - using 
IdType = typename Params::IdType; - const DTypeQ *q_nope = params.q_nope; - const DTypeQ *q_pe = params.q_pe; - DTypeO *o = params.o; - float *lse = params.lse; - const auto &paged_kv = params.paged_kv; - const IdType *q_rope_offset = params.q_rope_offset; - const bool *block_valid_mask = params.block_valid_mask; - const uint32_t num_qo_heads = params.num_qo_heads; - const float rope_rcp_scale = params.rope_rcp_scale; - const float rope_rcp_theta = params.rope_rcp_theta; - const bool partition_kv = params.partition_kv; - params.sm_scale *= math::log2e; - - constexpr uint32_t head_dim_ckv = bdx * vec_size_ckv; - constexpr uint32_t head_dim_kpe = bdx * vec_size_kpe; - const uint32_t batch_idx = blockIdx.x; - const uint32_t tx = threadIdx.x, ty = threadIdx.y, tz = threadIdx.z; - const uint32_t t_offset = dim3_offset(bdy, bdx, tz, ty, tx); - - // NOTE(Zihao): when CUDAGraph is enabled, we will launch more blocks than - // the actual batch size, so we need to check if the current batch is valid - if (block_valid_mask && !block_valid_mask[batch_idx]) - return; - const uint32_t mapped_batch_idx = params.request_indices[batch_idx]; - - const uint32_t orig_seq_len = paged_kv.get_length(mapped_batch_idx); - int32_t q_rope_offset_val = q_rope_offset == nullptr - ? (orig_seq_len - 1) - : q_rope_offset[mapped_batch_idx]; - - const uint32_t kv_chunk_idx_in_orig_mapped_batch = - params.kv_tile_indices[batch_idx]; - const uint32_t kv_chunk_size = *(params.kv_chunk_size_ptr); - const uint32_t cur_chunk_start = - partition_kv ? kv_chunk_idx_in_orig_mapped_batch * kv_chunk_size : 0; - const uint32_t cur_chunk_end = - partition_kv - ? 
min((kv_chunk_idx_in_orig_mapped_batch + 1) * kv_chunk_size, - orig_seq_len) - : orig_seq_len; - const uint32_t cur_chunk_len = cur_chunk_end - cur_chunk_start; - - uint32_t packed_page_iter_base = - paged_kv.indptr[mapped_batch_idx] * paged_kv.page_size + - cur_chunk_start; - const IdType last_indptr = paged_kv.indptr[paged_kv.batch_size]; - - constexpr uint32_t kv_iter_len = bdy * bdz; - constexpr uint32_t compute_qk_tile = bdy; - - extern __attribute__((shared)) uint8_t smem[]; - DTypeKV *ckv_smem = (DTypeKV *)smem; - DTypeKV *kpe_smem = - (DTypeKV *)((uint8_t *)ckv_smem + num_stages_smem * kv_iter_len * - head_dim_ckv * sizeof(DTypeKV)); - size_t *ckv_offset_smem = - (size_t *)((uint8_t *)kpe_smem + num_stages_smem * kv_iter_len * - head_dim_kpe * sizeof(DTypeKV)); - size_t *kpe_offset_smem = (size_t *)((uint8_t *)ckv_offset_smem + - bdx * bdy * bdz * sizeof(size_t)); - float *smem_md = (float *)ckv_offset_smem; - - AttentionVariant variant(params, batch_idx, smem); - - vec_t q_nope_vec[tile_size_qo_heads]; - vec_t q_pe_vec[tile_size_qo_heads]; - state_t st[tile_size_qo_heads]; - uint32_t qo_head_idx[tile_size_qo_heads]; - - vec_t freq; +__global__ void BatchDecodeWithPagedKVCacheKernelMLA(Params params) { + auto block = cg::this_thread_block(); + using DTypeQ = typename Params::DTypeQ; + using DTypeKV = typename Params::DTypeKV; + using DTypeO = typename Params::DTypeO; + using IdType = typename Params::IdType; + const DTypeQ* q_nope = params.q_nope; + const DTypeQ* q_pe = params.q_pe; + DTypeO* o = params.o; + float* lse = params.lse; + const auto& paged_kv = params.paged_kv; + const IdType* q_rope_offset = params.q_rope_offset; + const bool* block_valid_mask = params.block_valid_mask; + const uint32_t num_qo_heads = params.num_qo_heads; + const float rope_rcp_scale = params.rope_rcp_scale; + const float rope_rcp_theta = params.rope_rcp_theta; + const bool partition_kv = params.partition_kv; + params.sm_scale *= math::log2e; + + constexpr uint32_t 
head_dim_ckv = bdx * vec_size_ckv; + constexpr uint32_t head_dim_kpe = bdx * vec_size_kpe; + const uint32_t batch_idx = blockIdx.x; + const uint32_t tx = threadIdx.x, ty = threadIdx.y, tz = threadIdx.z; + const uint32_t t_offset = dim3_offset(bdy, bdx, tz, ty, tx); + + // NOTE(Zihao): when CUDAGraph is enabled, we will launch more blocks than + // the actual batch size, so we need to check if the current batch is valid + if (block_valid_mask && !block_valid_mask[batch_idx]) return; + const uint32_t mapped_batch_idx = params.request_indices[batch_idx]; + + const uint32_t orig_seq_len = paged_kv.get_length(mapped_batch_idx); + int32_t q_rope_offset_val = + q_rope_offset == nullptr ? (orig_seq_len - 1) : q_rope_offset[mapped_batch_idx]; + + const uint32_t kv_chunk_idx_in_orig_mapped_batch = params.kv_tile_indices[batch_idx]; + const uint32_t kv_chunk_size = *(params.kv_chunk_size_ptr); + const uint32_t cur_chunk_start = + partition_kv ? kv_chunk_idx_in_orig_mapped_batch * kv_chunk_size : 0; + const uint32_t cur_chunk_end = + partition_kv ? 
min((kv_chunk_idx_in_orig_mapped_batch + 1) * kv_chunk_size, orig_seq_len) + : orig_seq_len; + const uint32_t cur_chunk_len = cur_chunk_end - cur_chunk_start; + + uint32_t packed_page_iter_base = + paged_kv.indptr[mapped_batch_idx] * paged_kv.page_size + cur_chunk_start; + const IdType last_indptr = paged_kv.indptr[paged_kv.batch_size]; + + constexpr uint32_t kv_iter_len = bdy * bdz; + constexpr uint32_t compute_qk_tile = bdy; + + extern __attribute__((shared)) uint8_t smem[]; + DTypeKV* ckv_smem = (DTypeKV*)smem; + DTypeKV* kpe_smem = (DTypeKV*)((uint8_t*)ckv_smem + + num_stages_smem * kv_iter_len * head_dim_ckv * sizeof(DTypeKV)); + size_t* ckv_offset_smem = (size_t*)((uint8_t*)kpe_smem + num_stages_smem * kv_iter_len * + head_dim_kpe * sizeof(DTypeKV)); + size_t* kpe_offset_smem = (size_t*)((uint8_t*)ckv_offset_smem + bdx * bdy * bdz * sizeof(size_t)); + float* smem_md = (float*)ckv_offset_smem; + + AttentionVariant variant(params, batch_idx, smem); + + vec_t q_nope_vec[tile_size_qo_heads]; + vec_t q_pe_vec[tile_size_qo_heads]; + state_t st[tile_size_qo_heads]; + uint32_t qo_head_idx[tile_size_qo_heads]; + + vec_t freq; #pragma unroll - for (uint32_t i = 0; i < vec_size_kpe; ++i) { - freq[i] = - rope_rcp_scale * - __powf(rope_rcp_theta, float(2 * ((tx * vec_size_kpe + i) / 2)) / - float(head_dim_kpe)); - } - // load q_nope and q_pe tile + for (uint32_t i = 0; i < vec_size_kpe; ++i) { + freq[i] = rope_rcp_scale * __powf(rope_rcp_theta, float(2 * ((tx * vec_size_kpe + i) / 2)) / + float(head_dim_kpe)); + } + // load q_nope and q_pe tile #pragma unroll - for (int i = 0; i < tile_size_qo_heads; ++i) { - qo_head_idx[i] = - dim3_offset(bdy, tile_size_qo_heads, blockIdx.y, threadIdx.y, i); - if (qo_head_idx[i] < num_qo_heads) { - q_nope_vec[i].cast_load( - q_nope + - (mapped_batch_idx * num_qo_heads + qo_head_idx[i]) * - head_dim_ckv + - tx * vec_size_ckv); - q_pe_vec[i].cast_load( - q_pe + - (mapped_batch_idx * num_qo_heads + qo_head_idx[i]) * - head_dim_kpe + - tx * 
vec_size_kpe); - } + for (int i = 0; i < tile_size_qo_heads; ++i) { + qo_head_idx[i] = dim3_offset(bdy, tile_size_qo_heads, blockIdx.y, threadIdx.y, i); + if (qo_head_idx[i] < num_qo_heads) { + q_nope_vec[i].cast_load(q_nope + + (mapped_batch_idx * num_qo_heads + qo_head_idx[i]) * head_dim_ckv + + tx * vec_size_ckv); + q_pe_vec[i].cast_load(q_pe + + (mapped_batch_idx * num_qo_heads + qo_head_idx[i]) * head_dim_kpe + + tx * vec_size_kpe); } - - // init paged-cache read offset to be used - uint32_t q, r; - paged_kv.page_size.divmod(packed_page_iter_base + t_offset, q, r); - ckv_offset_smem[t_offset] = - paged_kv.protective_get_offset_ckv(q, r, /*feat_idx*/ 0, last_indptr); - kpe_offset_smem[t_offset] = - paged_kv.protective_get_offset_kpe(q, r, /*feat_idx*/ 0, last_indptr); - block.sync(); - - uint32_t stage_idx = 0; - constexpr uint32_t vec_bits = sizeof(DTypeKV) * vec_size_ckv * 8; - constexpr uint32_t tx_fold = vec_size_ckv / vec_size_kpe; - static_assert(num_stages_smem <= bdx); - size_t offset_bytes; - bool is_valid_range; + } + + // init paged-cache read offset to be used + uint32_t q, r; + paged_kv.page_size.divmod(packed_page_iter_base + t_offset, q, r); + ckv_offset_smem[t_offset] = paged_kv.protective_get_offset_ckv(q, r, /*feat_idx*/ 0, last_indptr); + kpe_offset_smem[t_offset] = paged_kv.protective_get_offset_kpe(q, r, /*feat_idx*/ 0, last_indptr); + block.sync(); + + uint32_t stage_idx = 0; + constexpr uint32_t vec_bits = sizeof(DTypeKV) * vec_size_ckv * 8; + constexpr uint32_t tx_fold = vec_size_ckv / vec_size_kpe; + static_assert(num_stages_smem <= bdx); + size_t offset_bytes; + bool is_valid_range; #pragma unroll - for (uint32_t iter = 0; iter < num_stages_smem; ++iter) { - is_valid_range = - (iter * kv_iter_len + dim2_offset(bdy, tz, ty)) < cur_chunk_len; - - offset_bytes = ckv_offset_smem[dim3_offset(bdz, bdy, iter, tz, ty)] + - tx * vec_size_ckv; - cp_async::pred_load( - ckv_smem + - (stage_idx * kv_iter_len + dim2_offset(bdy, tz, ty)) * - 
head_dim_ckv + - tx * vec_size_ckv, - paged_kv.ckv_data + offset_bytes, is_valid_range); - - offset_bytes = kpe_offset_smem[dim3_offset(bdz, bdy, iter, tz, ty)] + - tx / tx_fold * vec_size_ckv; - cp_async::pred_load( - kpe_smem + - (stage_idx * kv_iter_len + dim2_offset(bdy, tz, ty)) * - head_dim_kpe + - tx / tx_fold * vec_size_ckv, - paged_kv.kpe_data + offset_bytes, is_valid_range); - - cp_async::commit_group(); - stage_idx = (stage_idx + 1) % num_stages_smem; - } + for (uint32_t iter = 0; iter < num_stages_smem; ++iter) { + is_valid_range = (iter * kv_iter_len + dim2_offset(bdy, tz, ty)) < cur_chunk_len; + + offset_bytes = ckv_offset_smem[dim3_offset(bdz, bdy, iter, tz, ty)] + tx * vec_size_ckv; + cp_async::pred_load( + ckv_smem + (stage_idx * kv_iter_len + dim2_offset(bdy, tz, ty)) * head_dim_ckv + + tx * vec_size_ckv, + paged_kv.ckv_data + offset_bytes, is_valid_range); + + offset_bytes = + kpe_offset_smem[dim3_offset(bdz, bdy, iter, tz, ty)] + tx / tx_fold * vec_size_ckv; + cp_async::pred_load( + kpe_smem + (stage_idx * kv_iter_len + dim2_offset(bdy, tz, ty)) * head_dim_kpe + + tx / tx_fold * vec_size_ckv, + paged_kv.kpe_data + offset_bytes, is_valid_range); + + cp_async::commit_group(); + stage_idx = (stage_idx + 1) % num_stages_smem; + } #pragma unroll - for (uint32_t iter = 0; iter < ceil_div(cur_chunk_len, kv_iter_len); ++iter) - { - cp_async::wait_group<1 * num_stages_smem - 1>(); - block.sync(); - const int32_t kv_idx_base = - (paged_kv.rope_pos_offset == nullptr - ? 0 - : paged_kv.rope_pos_offset[mapped_batch_idx]) + - cur_chunk_start + iter * kv_iter_len; + for (uint32_t iter = 0; iter < ceil_div(cur_chunk_len, kv_iter_len); ++iter) { + cp_async::wait_group<1 * num_stages_smem - 1>(); + block.sync(); + const int32_t kv_idx_base = + (paged_kv.rope_pos_offset == nullptr ? 
0 : paged_kv.rope_pos_offset[mapped_batch_idx]) + + cur_chunk_start + iter * kv_iter_len; #pragma unroll - for (int i = 0; i < tile_size_qo_heads; ++i) { - compute_qk_and_update_local_stat_mla( - params, variant, mapped_batch_idx, - ckv_smem + (stage_idx * kv_iter_len + tz * compute_qk_tile) * - head_dim_ckv, - q_nope_vec[i], - kpe_smem + (stage_idx * kv_iter_len + tz * compute_qk_tile) * - head_dim_kpe, - q_pe_vec[i], freq, kv_idx_base, - /*iter_base*/ iter * kv_iter_len, /*iter_bound*/ cur_chunk_len, - st[i]); - } + for (int i = 0; i < tile_size_qo_heads; ++i) { + compute_qk_and_update_local_stat_mla( + params, variant, mapped_batch_idx, + ckv_smem + (stage_idx * kv_iter_len + tz * compute_qk_tile) * head_dim_ckv, q_nope_vec[i], + kpe_smem + (stage_idx * kv_iter_len + tz * compute_qk_tile) * head_dim_kpe, q_pe_vec[i], + freq, kv_idx_base, + /*iter_base*/ iter * kv_iter_len, /*iter_bound*/ cur_chunk_len, st[i]); + } - if ((iter + num_stages_smem) % bdx == 0) { - uint32_t q, r; - paged_kv.page_size.divmod( - packed_page_iter_base + (iter + num_stages_smem) * kv_iter_len + - t_offset, - q, r); - ckv_offset_smem[t_offset] = paged_kv.protective_get_offset_ckv( - q, r, /*feat_idx*/ 0, last_indptr); - kpe_offset_smem[t_offset] = paged_kv.protective_get_offset_kpe( - q, r, /*feat_idx*/ 0, last_indptr); - } - block.sync(); - - is_valid_range = ((iter + num_stages_smem) * kv_iter_len + - dim2_offset(bdy, tz, ty)) < cur_chunk_len; - offset_bytes = ckv_offset_smem[dim3_offset( - bdz, bdy, (iter + num_stages_smem) % bdx, tz, ty)] + - tx * vec_size_ckv; - cp_async::pred_load( - ckv_smem + - (stage_idx * kv_iter_len + dim2_offset(bdy, tz, ty)) * - head_dim_ckv + - tx * vec_size_ckv, - paged_kv.ckv_data + offset_bytes, is_valid_range); - - offset_bytes = kpe_offset_smem[dim3_offset( - bdz, bdy, (iter + num_stages_smem) % bdx, tz, ty)] + - tx / tx_fold * vec_size_ckv; - cp_async::pred_load( - kpe_smem + - (stage_idx * kv_iter_len + dim2_offset(bdy, tz, ty)) * - head_dim_kpe + - 
tx / tx_fold * vec_size_ckv, - paged_kv.kpe_data + offset_bytes, is_valid_range); - cp_async::commit_group(); - - stage_idx = (stage_idx + 1) % num_stages_smem; + if ((iter + num_stages_smem) % bdx == 0) { + uint32_t q, r; + paged_kv.page_size.divmod( + packed_page_iter_base + (iter + num_stages_smem) * kv_iter_len + t_offset, q, r); + ckv_offset_smem[t_offset] = + paged_kv.protective_get_offset_ckv(q, r, /*feat_idx*/ 0, last_indptr); + kpe_offset_smem[t_offset] = + paged_kv.protective_get_offset_kpe(q, r, /*feat_idx*/ 0, last_indptr); } - cp_async::wait_group<0>(); block.sync(); - if (bdz != 1) { + is_valid_range = + ((iter + num_stages_smem) * kv_iter_len + dim2_offset(bdy, tz, ty)) < cur_chunk_len; + offset_bytes = ckv_offset_smem[dim3_offset(bdz, bdy, (iter + num_stages_smem) % bdx, tz, ty)] + + tx * vec_size_ckv; + cp_async::pred_load( + ckv_smem + (stage_idx * kv_iter_len + dim2_offset(bdy, tz, ty)) * head_dim_ckv + + tx * vec_size_ckv, + paged_kv.ckv_data + offset_bytes, is_valid_range); + + offset_bytes = kpe_offset_smem[dim3_offset(bdz, bdy, (iter + num_stages_smem) % bdx, tz, ty)] + + tx / tx_fold * vec_size_ckv; + cp_async::pred_load( + kpe_smem + (stage_idx * kv_iter_len + dim2_offset(bdy, tz, ty)) * head_dim_kpe + + tx / tx_fold * vec_size_ckv, + paged_kv.kpe_data + offset_bytes, is_valid_range); + cp_async::commit_group(); + + stage_idx = (stage_idx + 1) % num_stages_smem; + } + cp_async::wait_group<0>(); + block.sync(); + + if (bdz != 1) { #pragma unroll - for (int i = 0; i < tile_size_qo_heads; ++i) { - if (qo_head_idx[i] < num_qo_heads) - sync_state(variant, st[i], - (float *)smem, smem_md); - } + for (int i = 0; i < tile_size_qo_heads; ++i) { + if (qo_head_idx[i] < num_qo_heads) + sync_state(variant, st[i], (float*)smem, smem_md); } + } - if (tz == 0) { + if (tz == 0) { #pragma unroll - for (int i = 0; i < tile_size_qo_heads; ++i) { - if (qo_head_idx[i] < num_qo_heads) { - st[i].normalize(); - st[i].o.cast_store(o + - (batch_idx * num_qo_heads + 
qo_head_idx[i]) * - head_dim_ckv + - tx * vec_size_ckv); - - if (lse != nullptr) { - lse[batch_idx * num_qo_heads + qo_head_idx[i]] = - st[i].get_lse(); - } - } + for (int i = 0; i < tile_size_qo_heads; ++i) { + if (qo_head_idx[i] < num_qo_heads) { + st[i].normalize(); + st[i].o.cast_store(o + (batch_idx * num_qo_heads + qo_head_idx[i]) * head_dim_ckv + + tx * vec_size_ckv); + + if (lse != nullptr) { + lse[batch_idx * num_qo_heads + qo_head_idx[i]] = st[i].get_lse(); } + } } + } } -template -hipError_t -BatchDecodeWithPagedKVCacheDispatchedMLA(Params params, - typename Params::DTypeO *tmp_v, - float *tmp_s, - hipStream_t stream) -{ - using DTypeQ = typename Params::DTypeQ; - using DTypeKV = typename Params::DTypeKV; - using DTypeO = typename Params::DTypeO; - using IdType = typename Params::IdType; - const uint32_t num_qo_heads = params.num_qo_heads; - const uint32_t padded_batch_size = params.padded_batch_size; - - constexpr uint32_t vec_size_ckv = - std::max(16UL / sizeof(DTypeKV), HEAD_DIM_CKV / 32UL); - constexpr uint32_t bdx = HEAD_DIM_CKV / vec_size_ckv; - constexpr uint32_t vec_size_kpe = HEAD_DIM_KPE / bdx; - - constexpr uint32_t bdy = 8; - constexpr uint32_t tile_size_qo_heads = 2; - constexpr uint32_t qo_heads_per_block = bdy * tile_size_qo_heads; - constexpr uint32_t num_threads = std::max(128U, bdx * bdy); - constexpr uint32_t bdz = num_threads / (bdx * bdy); - const uint32_t gdy = ceil_div(num_qo_heads, qo_heads_per_block); - - auto compute_capacity = GetCudaComputeCapability(); - DISPATCH_COMPUTE_CAP_DECODE_NUM_STAGES_SMEM( - compute_capacity, NUM_STAGES_SMEM, { - const uint32_t smem_size = - NUM_STAGES_SMEM * bdy * bdz * (HEAD_DIM_CKV + HEAD_DIM_KPE) * - sizeof(DTypeKV) + - std::max(num_threads * sizeof(size_t) * 2, - 2 * bdy * bdz * sizeof(float)); - - auto kernel = BatchDecodeWithPagedKVCacheKernelMLA< - NUM_STAGES_SMEM, vec_size_ckv, vec_size_kpe, bdx, bdy, bdz, - tile_size_qo_heads, AttentionVariant, Params>; - 
FLASHINFER_CUDA_CALL(hipFuncSetAttribute( - (void *)kernel, hipFuncAttributeMaxDynamicSharedMemorySize, - smem_size)); - - if (tmp_v == nullptr) { - // do not use partition-kv kernel - dim3 nblks(padded_batch_size, gdy); - dim3 nthrs(bdx, bdy, bdz); - params.partition_kv = false; - BatchDecodeWithPagedKVCacheKernelMLA< - NUM_STAGES_SMEM, vec_size_ckv, vec_size_kpe, bdx, bdy, bdz, - tile_size_qo_heads, AttentionVariant, Params> - <<>>(params); - } - else { - // use partition-kv kernel - params.partition_kv = true; - auto o = params.o; - auto lse = params.lse; - params.o = tmp_v; - params.lse = tmp_s; - void *args[] = {(void *)¶ms}; - dim3 nblks(padded_batch_size, gdy); - dim3 nthrs(bdx, bdy, bdz); - BatchDecodeWithPagedKVCacheKernelMLA< - NUM_STAGES_SMEM, vec_size_ckv, vec_size_kpe, bdx, bdy, bdz, - tile_size_qo_heads, AttentionVariant, Params> - <<>>(params); - - CHECK_HIP_ERROR(VariableLengthMergeStates( - tmp_v, tmp_s, params.o_indptr, o, lse, - params.paged_kv.batch_size, nullptr, num_qo_heads, - HEAD_DIM_CKV, stream)); - } - }); - return hipSuccess; +template +hipError_t BatchDecodeWithPagedKVCacheDispatchedMLA(Params params, typename Params::DTypeO* tmp_v, + float* tmp_s, hipStream_t stream) { + using DTypeQ = typename Params::DTypeQ; + using DTypeKV = typename Params::DTypeKV; + using DTypeO = typename Params::DTypeO; + using IdType = typename Params::IdType; + const uint32_t num_qo_heads = params.num_qo_heads; + const uint32_t padded_batch_size = params.padded_batch_size; + + constexpr uint32_t vec_size_ckv = std::max(16UL / sizeof(DTypeKV), HEAD_DIM_CKV / 32UL); + constexpr uint32_t bdx = HEAD_DIM_CKV / vec_size_ckv; + constexpr uint32_t vec_size_kpe = HEAD_DIM_KPE / bdx; + + constexpr uint32_t bdy = 8; + constexpr uint32_t tile_size_qo_heads = 2; + constexpr uint32_t qo_heads_per_block = bdy * tile_size_qo_heads; + constexpr uint32_t num_threads = std::max(128U, bdx * bdy); + constexpr uint32_t bdz = num_threads / (bdx * bdy); + const uint32_t gdy = 
ceil_div(num_qo_heads, qo_heads_per_block); + + auto compute_capacity = GetCudaComputeCapability(); + DISPATCH_COMPUTE_CAP_DECODE_NUM_STAGES_SMEM(compute_capacity, NUM_STAGES_SMEM, { + const uint32_t smem_size = + NUM_STAGES_SMEM * bdy * bdz * (HEAD_DIM_CKV + HEAD_DIM_KPE) * sizeof(DTypeKV) + + std::max(num_threads * sizeof(size_t) * 2, 2 * bdy * bdz * sizeof(float)); + + auto kernel = + BatchDecodeWithPagedKVCacheKernelMLA; + FLASHINFER_CUDA_CALL( + hipFuncSetAttribute((void*)kernel, hipFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + + if (tmp_v == nullptr) { + // do not use partition-kv kernel + dim3 nblks(padded_batch_size, gdy); + dim3 nthrs(bdx, bdy, bdz); + params.partition_kv = false; + BatchDecodeWithPagedKVCacheKernelMLA + <<>>(params); + } else { + // use partition-kv kernel + params.partition_kv = true; + auto o = params.o; + auto lse = params.lse; + params.o = tmp_v; + params.lse = tmp_s; + void* args[] = {(void*)¶ms}; + dim3 nblks(padded_batch_size, gdy); + dim3 nthrs(bdx, bdy, bdz); + BatchDecodeWithPagedKVCacheKernelMLA + <<>>(params); + + CHECK_HIP_ERROR(VariableLengthMergeStates(tmp_v, tmp_s, params.o_indptr, o, lse, + params.paged_kv.batch_size, nullptr, num_qo_heads, + HEAD_DIM_CKV, stream)); + } + }); + return hipSuccess; } -} // namespace flashinfer +} // namespace flashinfer -#endif // FLASHINFER_DECODE_CUH_ +#endif // FLASHINFER_DECODE_CUH_ diff --git a/libflashinfer/include/flashinfer/hip/attention/default_decode_params.hip.h b/libflashinfer/include/flashinfer/hip/attention/default_decode_params.hip.h index 3078becf13..6cb4c8e767 100644 --- a/libflashinfer/include/flashinfer/hip/attention/default_decode_params.hip.h +++ b/libflashinfer/include/flashinfer/hip/attention/default_decode_params.hip.h @@ -7,259 +7,263 @@ #ifndef FLASHINFER_DECODE_PARAMS_CUH_ #define FLASHINFER_DECODE_PARAMS_CUH_ -#include "../page.hip.h" - #include #include -namespace flashinfer -{ +#include "../page.hip.h" + +namespace flashinfer { template -struct 
SingleDecodeParams -{ - using DTypeQ = DTypeQ_; - using DTypeKV = DTypeKV_; - using DTypeO = DTypeO_; - using IdType = int32_t; - DTypeQ *q; - DTypeKV *k; - DTypeKV *v; - DTypeO *o; - float *lse; - float *maybe_alibi_slopes; - uint32_t kv_len; - uint32_t num_qo_heads; - uint32_t num_kv_heads; - uint32_t q_stride_n; - uint32_t q_stride_h; - uint32_t kv_stride_n; - uint32_t kv_stride_h; - int32_t window_left; - float logits_soft_cap; - float sm_scale; - float rope_rcp_scale; - float rope_rcp_theta; - uint32_t kv_chunk_size; +struct SingleDecodeParams { + using DTypeQ = DTypeQ_; + using DTypeKV = DTypeKV_; + using DTypeO = DTypeO_; + using IdType = int32_t; + DTypeQ* q; + DTypeKV* k; + DTypeKV* v; + DTypeO* o; + float* lse; + float* maybe_alibi_slopes; + uint32_t kv_len; + uint32_t num_qo_heads; + uint32_t num_kv_heads; + uint32_t q_stride_n; + uint32_t q_stride_h; + uint32_t kv_stride_n; + uint32_t kv_stride_h; + int32_t window_left; + float logits_soft_cap; + float sm_scale; + float rope_rcp_scale; + float rope_rcp_theta; + uint32_t kv_chunk_size; - __device__ __host__ SingleDecodeParams() - : q(nullptr), k(nullptr), v(nullptr), o(nullptr), lse(nullptr), - maybe_alibi_slopes(nullptr), kv_len(0), num_qo_heads(0), - num_kv_heads(0), q_stride_n(0), q_stride_h(0), kv_stride_n(0), - kv_stride_h(0), window_left(0), logits_soft_cap(0.0f), sm_scale(0.0f), - rope_rcp_scale(0.0f), rope_rcp_theta(0.0f), kv_chunk_size(0) - { - } + __device__ __host__ SingleDecodeParams() + : q(nullptr), + k(nullptr), + v(nullptr), + o(nullptr), + lse(nullptr), + maybe_alibi_slopes(nullptr), + kv_len(0), + num_qo_heads(0), + num_kv_heads(0), + q_stride_n(0), + q_stride_h(0), + kv_stride_n(0), + kv_stride_h(0), + window_left(0), + logits_soft_cap(0.0f), + sm_scale(0.0f), + rope_rcp_scale(0.0f), + rope_rcp_theta(0.0f), + kv_chunk_size(0) {} - __device__ __host__ SingleDecodeParams(DTypeQ *q, - DTypeKV *k, - DTypeKV *v, - DTypeO *o, - float *maybe_alibi_slopes, - uint32_t seq_len, - uint32_t 
num_qo_heads, - uint32_t num_kv_heads, - QKVLayout kv_layout, - uint32_t head_dim, - int32_t window_left, - float logits_soft_cap, - float sm_scale, - float rope_scale, - float rope_theta) - : q(q), k(k), v(v), o(o), lse(nullptr), - maybe_alibi_slopes(maybe_alibi_slopes), kv_len(seq_len), - num_qo_heads(num_qo_heads), num_kv_heads(num_kv_heads), - q_stride_n(num_qo_heads * head_dim), q_stride_h(head_dim), - kv_stride_n((kv_layout == QKVLayout::kNHD) ? num_kv_heads * head_dim - : head_dim), - kv_stride_h((kv_layout == QKVLayout::kNHD) ? head_dim - : seq_len * head_dim), - window_left(window_left), logits_soft_cap(logits_soft_cap), - sm_scale(sm_scale), rope_rcp_scale(1.f / rope_scale), - rope_rcp_theta(1.f / rope_theta), kv_chunk_size(0) - { - } + __device__ __host__ SingleDecodeParams(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeO* o, + float* maybe_alibi_slopes, uint32_t seq_len, + uint32_t num_qo_heads, uint32_t num_kv_heads, + QKVLayout kv_layout, uint32_t head_dim, + int32_t window_left, float logits_soft_cap, float sm_scale, + float rope_scale, float rope_theta) + : q(q), + k(k), + v(v), + o(o), + lse(nullptr), + maybe_alibi_slopes(maybe_alibi_slopes), + kv_len(seq_len), + num_qo_heads(num_qo_heads), + num_kv_heads(num_kv_heads), + q_stride_n(num_qo_heads * head_dim), + q_stride_h(head_dim), + kv_stride_n((kv_layout == QKVLayout::kNHD) ? num_kv_heads * head_dim : head_dim), + kv_stride_h((kv_layout == QKVLayout::kNHD) ? 
head_dim : seq_len * head_dim), + window_left(window_left), + logits_soft_cap(logits_soft_cap), + sm_scale(sm_scale), + rope_rcp_scale(1.f / rope_scale), + rope_rcp_theta(1.f / rope_theta), + kv_chunk_size(0) {} - __host__ __device__ __forceinline__ uint32_t - get_qo_len(uint32_t batch_idx) const - { - return 1; - } + __host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t batch_idx) const { return 1; } - __host__ __device__ __forceinline__ uint32_t - get_kv_len(uint32_t batch_idx) const - { - return kv_len; - } + __host__ __device__ __forceinline__ uint32_t get_kv_len(uint32_t batch_idx) const { + return kv_len; + } }; -template -struct BatchDecodeParams -{ - using DTypeQ = DTypeQ_; - using DTypeKV = DTypeKV_; - using DTypeO = DTypeO_; - using IdType = IdType_; +template +struct BatchDecodeParams { + using DTypeQ = DTypeQ_; + using DTypeKV = DTypeKV_; + using DTypeO = DTypeO_; + using IdType = IdType_; - DTypeQ *q; - IdType *q_rope_offset; - paged_kv_t paged_kv; - DTypeO *o; - float *lse; - float *maybe_alibi_slopes; - uint32_t padded_batch_size; - uint32_t num_qo_heads; - IdType q_stride_n; - IdType q_stride_h; - int32_t window_left; - float logits_soft_cap; - float sm_scale; - float rope_rcp_scale; - float rope_rcp_theta; + DTypeQ* q; + IdType* q_rope_offset; + paged_kv_t paged_kv; + DTypeO* o; + float* lse; + float* maybe_alibi_slopes; + uint32_t padded_batch_size; + uint32_t num_qo_heads; + IdType q_stride_n; + IdType q_stride_h; + int32_t window_left; + float logits_soft_cap; + float sm_scale; + float rope_rcp_scale; + float rope_rcp_theta; - IdType *request_indices; - IdType *kv_tile_indices; - IdType *o_indptr; - IdType *kv_chunk_size_ptr; - bool *block_valid_mask; - bool partition_kv; + IdType* request_indices; + IdType* kv_tile_indices; + IdType* o_indptr; + IdType* kv_chunk_size_ptr; + bool* block_valid_mask; + bool partition_kv; - __device__ __host__ BatchDecodeParams() - : q(nullptr), q_rope_offset(nullptr), paged_kv(), o(nullptr), - 
lse(nullptr), maybe_alibi_slopes(nullptr), padded_batch_size(0), - num_qo_heads(0), q_stride_n(0), q_stride_h(0), window_left(0), - logits_soft_cap(0.0f), sm_scale(0.0f), rope_rcp_scale(0.0f), - rope_rcp_theta(0.0f), request_indices(nullptr), - kv_tile_indices(nullptr), o_indptr(nullptr), - kv_chunk_size_ptr(nullptr), block_valid_mask(nullptr), - partition_kv(false) - { - } + __device__ __host__ BatchDecodeParams() + : q(nullptr), + q_rope_offset(nullptr), + paged_kv(), + o(nullptr), + lse(nullptr), + maybe_alibi_slopes(nullptr), + padded_batch_size(0), + num_qo_heads(0), + q_stride_n(0), + q_stride_h(0), + window_left(0), + logits_soft_cap(0.0f), + sm_scale(0.0f), + rope_rcp_scale(0.0f), + rope_rcp_theta(0.0f), + request_indices(nullptr), + kv_tile_indices(nullptr), + o_indptr(nullptr), + kv_chunk_size_ptr(nullptr), + block_valid_mask(nullptr), + partition_kv(false) {} - __device__ __host__ BatchDecodeParams(DTypeQ *q, - IdType *q_rope_offset, - paged_kv_t paged_kv, - DTypeO *o, - float *lse, - float *maybe_alibi_slopes, - uint32_t num_qo_heads, - IdType q_stride_n, - IdType q_stride_h, - int32_t window_left, - float logits_soft_cap, - float sm_scale, - float rope_scale, - float rope_theta) - : q(q), q_rope_offset(q_rope_offset), paged_kv(paged_kv), o(o), - lse(lse), maybe_alibi_slopes(maybe_alibi_slopes), - padded_batch_size(0), num_qo_heads(num_qo_heads), - q_stride_n(q_stride_n), q_stride_h(q_stride_h), - window_left(window_left), logits_soft_cap(logits_soft_cap), - sm_scale(sm_scale), rope_rcp_scale(1.f / rope_scale), - rope_rcp_theta(1.f / rope_theta), request_indices(nullptr), - kv_tile_indices(nullptr), o_indptr(nullptr), - kv_chunk_size_ptr(nullptr), block_valid_mask(nullptr), - partition_kv(false) - { - } + __device__ __host__ BatchDecodeParams(DTypeQ* q, IdType* q_rope_offset, + paged_kv_t paged_kv, DTypeO* o, float* lse, + float* maybe_alibi_slopes, uint32_t num_qo_heads, + IdType q_stride_n, IdType q_stride_h, int32_t window_left, + float 
logits_soft_cap, float sm_scale, float rope_scale, + float rope_theta) + : q(q), + q_rope_offset(q_rope_offset), + paged_kv(paged_kv), + o(o), + lse(lse), + maybe_alibi_slopes(maybe_alibi_slopes), + padded_batch_size(0), + num_qo_heads(num_qo_heads), + q_stride_n(q_stride_n), + q_stride_h(q_stride_h), + window_left(window_left), + logits_soft_cap(logits_soft_cap), + sm_scale(sm_scale), + rope_rcp_scale(1.f / rope_scale), + rope_rcp_theta(1.f / rope_theta), + request_indices(nullptr), + kv_tile_indices(nullptr), + o_indptr(nullptr), + kv_chunk_size_ptr(nullptr), + block_valid_mask(nullptr), + partition_kv(false) {} - __host__ __device__ __forceinline__ int32_t - get_qo_len(int32_t batch_idx) const - { - return 1; - } + __host__ __device__ __forceinline__ int32_t get_qo_len(int32_t batch_idx) const { return 1; } - __host__ __device__ __forceinline__ int32_t - get_kv_len(int32_t batch_idx) const - { - return paged_kv.get_length(batch_idx); - } + __host__ __device__ __forceinline__ int32_t get_kv_len(int32_t batch_idx) const { + return paged_kv.get_length(batch_idx); + } }; -template -struct BatchDecodeParamsMLA -{ - using DTypeQ = DTypeQ_; - using DTypeKV = DTypeKV_; - using DTypeO = DTypeO_; - using IdType = IdType_; +template +struct BatchDecodeParamsMLA { + using DTypeQ = DTypeQ_; + using DTypeKV = DTypeKV_; + using DTypeO = DTypeO_; + using IdType = IdType_; - DTypeQ *q_nope; - DTypeQ *q_pe; - DTypeO *o; - float *lse; - float sm_scale; + DTypeQ* q_nope; + DTypeQ* q_pe; + DTypeO* o; + float* lse; + float sm_scale; - IdType *q_rope_offset; - paged_kv_mla_t paged_kv; - uint32_t padded_batch_size; - uint32_t num_qo_heads; - int32_t window_left; - float logits_soft_cap; - float rope_rcp_scale; - float rope_rcp_theta; + IdType* q_rope_offset; + paged_kv_mla_t paged_kv; + uint32_t padded_batch_size; + uint32_t num_qo_heads; + int32_t window_left; + float logits_soft_cap; + float rope_rcp_scale; + float rope_rcp_theta; - IdType *request_indices; - IdType *kv_tile_indices; 
- IdType *o_indptr; - IdType *kv_chunk_size_ptr; - bool *block_valid_mask; - bool partition_kv; + IdType* request_indices; + IdType* kv_tile_indices; + IdType* o_indptr; + IdType* kv_chunk_size_ptr; + bool* block_valid_mask; + bool partition_kv; - __device__ __host__ BatchDecodeParamsMLA() - : q_nope(nullptr), q_pe(nullptr), o(nullptr), lse(nullptr), - sm_scale(0.0f), q_rope_offset(nullptr), paged_kv(), - padded_batch_size(0), num_qo_heads(0), window_left(0), - logits_soft_cap(0.0f), rope_rcp_scale(0.0f), rope_rcp_theta(0.0f), - request_indices(nullptr), kv_tile_indices(nullptr), o_indptr(nullptr), - kv_chunk_size_ptr(nullptr), block_valid_mask(nullptr), - partition_kv(false) - { - } + __device__ __host__ BatchDecodeParamsMLA() + : q_nope(nullptr), + q_pe(nullptr), + o(nullptr), + lse(nullptr), + sm_scale(0.0f), + q_rope_offset(nullptr), + paged_kv(), + padded_batch_size(0), + num_qo_heads(0), + window_left(0), + logits_soft_cap(0.0f), + rope_rcp_scale(0.0f), + rope_rcp_theta(0.0f), + request_indices(nullptr), + kv_tile_indices(nullptr), + o_indptr(nullptr), + kv_chunk_size_ptr(nullptr), + block_valid_mask(nullptr), + partition_kv(false) {} - __device__ __host__ - BatchDecodeParamsMLA(DTypeQ *q_nope, - DTypeQ *q_pe, - IdType *q_rope_offset, - paged_kv_mla_t paged_kv, - DTypeO *o, - float *lse, - uint32_t num_qo_heads, - int32_t window_left, - float logits_soft_cap, - float sm_scale, - float rope_scale, - float rope_theta) - : q_nope(q_nope), q_pe(q_pe), o(o), lse(lse), sm_scale(sm_scale), - q_rope_offset(q_rope_offset), paged_kv(paged_kv), - padded_batch_size(0), num_qo_heads(num_qo_heads), - window_left(window_left), logits_soft_cap(logits_soft_cap), - rope_rcp_scale(1.f / rope_scale), rope_rcp_theta(1.f / rope_theta), - request_indices(nullptr), kv_tile_indices(nullptr), o_indptr(nullptr), - kv_chunk_size_ptr(nullptr), block_valid_mask(nullptr), - partition_kv(false) - { - } + __device__ __host__ BatchDecodeParamsMLA(DTypeQ* q_nope, DTypeQ* q_pe, IdType* 
q_rope_offset, + paged_kv_mla_t paged_kv, DTypeO* o, + float* lse, uint32_t num_qo_heads, int32_t window_left, + float logits_soft_cap, float sm_scale, float rope_scale, + float rope_theta) + : q_nope(q_nope), + q_pe(q_pe), + o(o), + lse(lse), + sm_scale(sm_scale), + q_rope_offset(q_rope_offset), + paged_kv(paged_kv), + padded_batch_size(0), + num_qo_heads(num_qo_heads), + window_left(window_left), + logits_soft_cap(logits_soft_cap), + rope_rcp_scale(1.f / rope_scale), + rope_rcp_theta(1.f / rope_theta), + request_indices(nullptr), + kv_tile_indices(nullptr), + o_indptr(nullptr), + kv_chunk_size_ptr(nullptr), + block_valid_mask(nullptr), + partition_kv(false) {} - __host__ __device__ __forceinline__ int32_t - get_qo_len(int32_t batch_idx) const - { - return 1; - } - __host__ __device__ __forceinline__ int32_t - get_kv_len(int32_t batch_idx) const - { - return paged_kv.get_length(batch_idx); - } + __host__ __device__ __forceinline__ int32_t get_qo_len(int32_t batch_idx) const { return 1; } + __host__ __device__ __forceinline__ int32_t get_kv_len(int32_t batch_idx) const { + return paged_kv.get_length(batch_idx); + } }; -} // namespace flashinfer +} // namespace flashinfer -#endif // FLASHINFER_DECODE_PARAMS_CUH_ +#endif // FLASHINFER_DECODE_PARAMS_CUH_ diff --git a/libflashinfer/include/flashinfer/hip/attention/heap.h b/libflashinfer/include/flashinfer/hip/attention/heap.h index 052d35ac4f..f703c6dbb4 100644 --- a/libflashinfer/include/flashinfer/hip/attention/heap.h +++ b/libflashinfer/include/flashinfer/hip/attention/heap.h @@ -12,52 +12,46 @@ #include #include -namespace flashinfer -{ +namespace flashinfer { /*! 
* \brief Heap data structure for (index, value) pairs * \note minimal element on top */ -class MinHeap -{ -public: - // first: index, second: cost - using Element = std::pair; - - MinHeap(int capacity) : heap_(capacity) - { - for (int i = 0; i < capacity; ++i) { - heap_[i] = std::make_pair(i, 0.f); - } +class MinHeap { + public: + // first: index, second: cost + using Element = std::pair; + + MinHeap(int capacity) : heap_(capacity) { + for (int i = 0; i < capacity; ++i) { + heap_[i] = std::make_pair(i, 0.f); } + } - void insert(const Element &element) - { - heap_.push_back(element); - std::push_heap(heap_.begin(), heap_.end(), compare); - } + void insert(const Element& element) { + heap_.push_back(element); + std::push_heap(heap_.begin(), heap_.end(), compare); + } - Element pop() - { - std::pop_heap(heap_.begin(), heap_.end(), compare); - Element minElement = heap_.back(); - heap_.pop_back(); - return minElement; - } + Element pop() { + std::pop_heap(heap_.begin(), heap_.end(), compare); + Element minElement = heap_.back(); + heap_.pop_back(); + return minElement; + } - std::vector getHeap() const { return heap_; } + std::vector getHeap() const { return heap_; } -private: - // Custom comparator for the min-heap: compare based on 'val' in the pair - static bool compare(const Element &a, const Element &b) - { - return a.second > b.second; // create a min-heap based on val - } + private: + // Custom comparator for the min-heap: compare based on 'val' in the pair + static bool compare(const Element& a, const Element& b) { + return a.second > b.second; // create a min-heap based on val + } - std::vector heap_; + std::vector heap_; }; -} // namespace flashinfer +} // namespace flashinfer -#endif // FLASHINFER_ATTENTION_HEAP_H +#endif // FLASHINFER_ATTENTION_HEAP_H diff --git a/libflashinfer/include/flashinfer/hip/attention/mask.hip.h b/libflashinfer/include/flashinfer/hip/attention/mask.hip.h index 55fe475d2d..c3a0d1563c 100644 --- 
a/libflashinfer/include/flashinfer/hip/attention/mask.hip.h +++ b/libflashinfer/include/flashinfer/hip/attention/mask.hip.h @@ -7,16 +7,14 @@ #ifndef FLASHINFER_ATTENTION_MASK_CUH_ #define FLASHINFER_ATTENTION_MASK_CUH_ -namespace flashinfer -{ +namespace flashinfer { -enum class MaskMode -{ - kNone = 0U, // No mask - kCausal = 1U, // Causal mask - kCustom = 2U, // Custom mask +enum class MaskMode { + kNone = 0U, // No mask + kCausal = 1U, // Causal mask + kCustom = 2U, // Custom mask }; -} // namespace flashinfer +} // namespace flashinfer -#endif // FLASHINFER_ATTENTION_MASK_CUH_ +#endif // FLASHINFER_ATTENTION_MASK_CUH_ diff --git a/libflashinfer/include/flashinfer/hip/attention/scheduler.hip.h b/libflashinfer/include/flashinfer/hip/attention/scheduler.hip.h index 6b514d9db7..350e120e3d 100644 --- a/libflashinfer/include/flashinfer/hip/attention/scheduler.hip.h +++ b/libflashinfer/include/flashinfer/hip/attention/scheduler.hip.h @@ -7,12 +7,6 @@ #ifndef FLASHINFER_ATTENTION_SCHEDULER_CUH_ #define FLASHINFER_ATTENTION_SCHEDULER_CUH_ -#include "../../allocator.h" -#include "../../exception.h" -#include "../pos_enc.hip.h" -#include "../utils.hip.h" -#include "heap.h" - #include #include @@ -21,42 +15,29 @@ #include #include -namespace flashinfer -{ - -template __global__ void BatchDecodeWithPagedKVCacheKernel(const Params params); -template __global__ void BatchDecodeWithPagedKVCacheKernelMLA(Params params); -template -std::tuple -LaunchSpecForDecodeKernelMlaCuteSM80(const uint32_t num_qo_heads); +template +std::tuple LaunchSpecForDecodeKernelMlaCuteSM80( + const uint32_t num_qo_heads); -template +template __global__ void BatchDecodeWithPagedKVCacheKernelMlaCuteSM80(Params params); /*! 
@@ -74,73 +55,65 @@ __global__ void BatchDecodeWithPagedKVCacheKernelMlaCuteSM80(Params params); */ template inline auto PartitionPagedKVCacheBinarySearchMinNumPagePerBatch( - const uint32_t max_grid_size, - const uint32_t gdy, - const std::vector &num_pages, - const uint32_t min_num_pages_per_batch = 1) -{ - uint32_t low = min_num_pages_per_batch, high = 0; - for (const IdType &elem : num_pages) { - high = max(high, elem); - } - uint32_t new_batch_size; - while (low < high) { - uint32_t mid = (low + high) / 2; - new_batch_size = 0; - assert(mid > 0); - for (const IdType &elem : num_pages) { - new_batch_size += ceil_div(elem, mid); - } - if (new_batch_size * gdy > max_grid_size) { - low = mid + 1; - } - else { - high = mid; - } - } + const uint32_t max_grid_size, const uint32_t gdy, const std::vector& num_pages, + const uint32_t min_num_pages_per_batch = 1) { + uint32_t low = min_num_pages_per_batch, high = 0; + for (const IdType& elem : num_pages) { + high = max(high, elem); + } + uint32_t new_batch_size; + while (low < high) { + uint32_t mid = (low + high) / 2; new_batch_size = 0; - assert(low > 0); - for (const IdType &elem : num_pages) { - new_batch_size += ceil_div(std::max(elem, 1), low); + assert(mid > 0); + for (const IdType& elem : num_pages) { + new_batch_size += ceil_div(elem, mid); } - return std::make_tuple(low, new_batch_size); + if (new_batch_size * gdy > max_grid_size) { + low = mid + 1; + } else { + high = mid; + } + } + new_batch_size = 0; + assert(low > 0); + for (const IdType& elem : num_pages) { + new_batch_size += ceil_div(std::max(elem, 1), low); + } + return std::make_tuple(low, new_batch_size); } -inline auto -PrefillBinarySearchKVChunkSize(const bool enable_cuda_graph, - const uint32_t max_batch_size_if_split, - const std::vector &packed_qo_len_arr, - const std::vector &kv_len_arr, - const uint32_t qo_chunk_size, - const uint32_t min_kv_chunk_size = 1) -{ - const int64_t batch_size = packed_qo_len_arr.size(); - int64_t max_kv_len = 1; - 
for (const int64_t &kv_len : kv_len_arr) { - max_kv_len = std::max(max_kv_len, kv_len); +inline auto PrefillBinarySearchKVChunkSize(const bool enable_cuda_graph, + const uint32_t max_batch_size_if_split, + const std::vector& packed_qo_len_arr, + const std::vector& kv_len_arr, + const uint32_t qo_chunk_size, + const uint32_t min_kv_chunk_size = 1) { + const int64_t batch_size = packed_qo_len_arr.size(); + int64_t max_kv_len = 1; + for (const int64_t& kv_len : kv_len_arr) { + max_kv_len = std::max(max_kv_len, kv_len); + } + + int64_t low = min_kv_chunk_size; + int64_t high = max_kv_len; + constexpr int64_t min_kv_len = 1; + assert(qo_chunk_size > 0); + while (low < high) { + const int64_t mid = (low + high) / 2; + int64_t new_batch_size = 0; + assert(mid > 0); + for (uint32_t i = 0; i < batch_size; ++i) { + new_batch_size += ceil_div(packed_qo_len_arr[i], qo_chunk_size) * + ceil_div(std::max(kv_len_arr[i], min_kv_len), mid); } - - int64_t low = min_kv_chunk_size; - int64_t high = max_kv_len; - constexpr int64_t min_kv_len = 1; - assert(qo_chunk_size > 0); - while (low < high) { - const int64_t mid = (low + high) / 2; - int64_t new_batch_size = 0; - assert(mid > 0); - for (uint32_t i = 0; i < batch_size; ++i) { - new_batch_size += - ceil_div(packed_qo_len_arr[i], qo_chunk_size) * - ceil_div(std::max(kv_len_arr[i], min_kv_len), mid); - } - if (new_batch_size > max_batch_size_if_split) { - low = mid + 1; - } - else { - high = mid; - } + if (new_batch_size > max_batch_size_if_split) { + low = mid + 1; + } else { + high = mid; } - return std::make_tuple(enable_cuda_graph || low < max_kv_len, low); + } + return std::make_tuple(enable_cuda_graph || low < max_kv_len, low); } /*! 
@@ -161,268 +134,209 @@ PrefillBinarySearchKVChunkSize(const bool enable_cuda_graph, * \param stream The cuda stream to launch the kernel * \return status Indicates whether CUDA calls are successful */ -template +template inline hipError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatched( - bool &split_kv, - uint32_t &max_grid_size, - uint32_t &max_num_pages_per_batch, - uint32_t &new_batch_size, - uint32_t &gdy, - uint32_t batch_size, - typename Params::IdType *kv_indptr_h, - const uint32_t num_qo_heads, - const uint32_t page_size, - bool enable_cuda_graph, - hipStream_t stream) -{ - using DTypeKV = typename Params::DTypeKV; - using IdType = typename Params::IdType; - constexpr uint32_t temp_1st = 16UL / sizeof(DTypeKV); - constexpr uint32_t temp_2nd = HEAD_DIM / 32UL; - constexpr uint32_t vec_size = temp_1st < temp_2nd ? temp_2nd : temp_1st; - auto compute_capacity = GetCudaComputeCapability(); - DISPATCH_COMPUTE_CAP_DECODE_NUM_STAGES_SMEM( - compute_capacity, NUM_STAGES_SMEM, { - constexpr uint32_t bdx = HEAD_DIM / vec_size; - static_assert(bdx <= 32); - constexpr uint32_t bdy = GROUP_SIZE; - constexpr uint32_t num_threads = 128U < bdx * bdy ? bdx * bdy : 128U; - static_assert(bdx > 0); - static_assert(bdy > 0); - constexpr uint32_t bdz = num_threads / (bdx * bdy); - constexpr uint32_t tile_size_per_bdx = - GROUP_SIZE == 1 ? (sizeof(DTypeKV) == 1 ? 
2U : 4U) : 1U; - const uint32_t num_kv_heads = num_qo_heads / GROUP_SIZE; - gdy = num_kv_heads; - const uint32_t smem_size = - 2 * NUM_STAGES_SMEM * tile_size_per_bdx * bdy * bdz * HEAD_DIM * - sizeof(DTypeKV) + - std::max(tile_size_per_bdx * num_threads * sizeof(DTypeKV *), - 2 * bdy * bdz * sizeof(float)); - - auto kernel = BatchDecodeWithPagedKVCacheKernel< - POS_ENCODING_MODE, NUM_STAGES_SMEM, tile_size_per_bdx, vec_size, - bdx, bdy, bdz, AttentionVariant, Params>; - int num_blocks_per_sm = 0; - int num_sm = 0; - int dev_id = 0; - FLASHINFER_CUDA_CALL(hipGetDevice(&dev_id)); - FLASHINFER_CUDA_CALL(hipDeviceGetAttribute( - &num_sm, hipDeviceAttributeMultiprocessorCount, dev_id)); - FLASHINFER_CUDA_CALL(hipOccupancyMaxActiveBlocksPerMultiprocessor( - &num_blocks_per_sm, kernel, num_threads, smem_size)); - max_grid_size = num_blocks_per_sm * num_sm; - if (batch_size * gdy >= max_grid_size) { - split_kv = false; - max_num_pages_per_batch = 1; - for (uint32_t batch_idx = 0; batch_idx < batch_size; - ++batch_idx) { - max_num_pages_per_batch = std::max( - max_num_pages_per_batch, - kv_indptr_h[batch_idx + 1] - kv_indptr_h[batch_idx]); - } - new_batch_size = batch_size; - } - else { - // compute max_num_pages_per_batch and new_batch_size - std::vector num_pages(batch_size); - for (uint32_t batch_idx = 0; batch_idx < batch_size; - ++batch_idx) { - num_pages[batch_idx] = - kv_indptr_h[batch_idx + 1] - kv_indptr_h[batch_idx]; - } - std::tie(max_num_pages_per_batch, new_batch_size) = - PartitionPagedKVCacheBinarySearchMinNumPagePerBatch( - max_grid_size, gdy, num_pages, - std::max(128 / page_size, 1U)); - if (new_batch_size == batch_size && !enable_cuda_graph) { - // do not use partition-kv kernel for short sequence, when - // not using CUDAGraph - split_kv = false; - } - else { - // when using CUDAGraph, we always use partition-kv kernel - split_kv = true; - } - } - return hipSuccess; - }) + bool& split_kv, uint32_t& max_grid_size, uint32_t& max_num_pages_per_batch, + 
uint32_t& new_batch_size, uint32_t& gdy, uint32_t batch_size, + typename Params::IdType* kv_indptr_h, const uint32_t num_qo_heads, const uint32_t page_size, + bool enable_cuda_graph, hipStream_t stream) { + using DTypeKV = typename Params::DTypeKV; + using IdType = typename Params::IdType; + constexpr uint32_t temp_1st = 16UL / sizeof(DTypeKV); + constexpr uint32_t temp_2nd = HEAD_DIM / 32UL; + constexpr uint32_t vec_size = temp_1st < temp_2nd ? temp_2nd : temp_1st; + auto compute_capacity = GetCudaComputeCapability(); + DISPATCH_COMPUTE_CAP_DECODE_NUM_STAGES_SMEM(compute_capacity, NUM_STAGES_SMEM, { + constexpr uint32_t bdx = HEAD_DIM / vec_size; + static_assert(bdx <= 32); + constexpr uint32_t bdy = GROUP_SIZE; + constexpr uint32_t num_threads = 128U < bdx * bdy ? bdx * bdy : 128U; + static_assert(bdx > 0); + static_assert(bdy > 0); + constexpr uint32_t bdz = num_threads / (bdx * bdy); + constexpr uint32_t tile_size_per_bdx = GROUP_SIZE == 1 ? (sizeof(DTypeKV) == 1 ? 2U : 4U) : 1U; + const uint32_t num_kv_heads = num_qo_heads / GROUP_SIZE; + gdy = num_kv_heads; + const uint32_t smem_size = + 2 * NUM_STAGES_SMEM * tile_size_per_bdx * bdy * bdz * HEAD_DIM * sizeof(DTypeKV) + + std::max(tile_size_per_bdx * num_threads * sizeof(DTypeKV*), 2 * bdy * bdz * sizeof(float)); + + auto kernel = + BatchDecodeWithPagedKVCacheKernel; + int num_blocks_per_sm = 0; + int num_sm = 0; + int dev_id = 0; + FLASHINFER_CUDA_CALL(hipGetDevice(&dev_id)); + FLASHINFER_CUDA_CALL( + hipDeviceGetAttribute(&num_sm, hipDeviceAttributeMultiprocessorCount, dev_id)); + FLASHINFER_CUDA_CALL(hipOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel, + num_threads, smem_size)); + max_grid_size = num_blocks_per_sm * num_sm; + if (batch_size * gdy >= max_grid_size) { + split_kv = false; + max_num_pages_per_batch = 1; + for (uint32_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { + max_num_pages_per_batch = std::max( + max_num_pages_per_batch, kv_indptr_h[batch_idx + 1] - 
kv_indptr_h[batch_idx]); + } + new_batch_size = batch_size; + } else { + // compute max_num_pages_per_batch and new_batch_size + std::vector num_pages(batch_size); + for (uint32_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { + num_pages[batch_idx] = kv_indptr_h[batch_idx + 1] - kv_indptr_h[batch_idx]; + } + std::tie(max_num_pages_per_batch, new_batch_size) = + PartitionPagedKVCacheBinarySearchMinNumPagePerBatch(max_grid_size, gdy, num_pages, + std::max(128 / page_size, 1U)); + if (new_batch_size == batch_size && !enable_cuda_graph) { + // do not use partition-kv kernel for short sequence, when + // not using CUDAGraph + split_kv = false; + } else { + // when using CUDAGraph, we always use partition-kv kernel + split_kv = true; + } + } + return hipSuccess; + }) } -template +template inline hipError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatchedMLA( - bool &split_kv, - uint32_t &max_grid_size, - uint32_t &max_num_pages_per_batch, - uint32_t &new_batch_size, - uint32_t &gdy, - uint32_t batch_size, - typename Params::IdType *kv_indptr_h, - const uint32_t num_qo_heads, - const uint32_t page_size, - bool enable_cuda_graph, - hipStream_t stream) -{ - using DTypeKV = typename Params::DTypeKV; - using IdType = typename Params::IdType; - - auto compute_capacity = GetCudaComputeCapability(); - DISPATCH_COMPUTE_CAP_DECODE_NUM_STAGES_SMEM( - compute_capacity, NUM_STAGES_SMEM, { - constexpr uint32_t temp_1st = 16UL / sizeof(DTypeKV); - constexpr uint32_t temp_2nd = HEAD_DIM_CKV / 32UL; - constexpr uint32_t vec_size_ckv = temp_1st < temp_2nd ? temp_2nd : temp_1st; - constexpr uint32_t bdx = HEAD_DIM_CKV / vec_size_ckv; - constexpr uint32_t vec_size_kpe = HEAD_DIM_KPE / bdx; - - constexpr uint32_t bdy = 8; - constexpr uint32_t tile_size_qo_heads = 2; - constexpr uint32_t qo_heads_per_block = bdy * tile_size_qo_heads; - static_assert(qo_heads_per_block > 0); - constexpr uint32_t num_threads = 128U < bdx * bdy ? 
bdx * bdy : 128U; - constexpr uint32_t bdz = num_threads / (bdx * bdy); - gdy = ceil_div(num_qo_heads, qo_heads_per_block); - - const uint32_t smem_size = - NUM_STAGES_SMEM * bdy * bdz * (HEAD_DIM_CKV + HEAD_DIM_KPE) * - sizeof(DTypeKV) + - std::max(num_threads * sizeof(size_t) * 2, - 2 * bdy * bdz * sizeof(float)); - - auto kernel = BatchDecodeWithPagedKVCacheKernelMLA< - NUM_STAGES_SMEM, vec_size_ckv, vec_size_kpe, bdx, bdy, bdz, - tile_size_qo_heads, AttentionVariant, Params>; - int num_blocks_per_sm = 0; - int num_sm = 0; - int dev_id = 0; - FLASHINFER_CUDA_CALL(hipGetDevice(&dev_id)); - FLASHINFER_CUDA_CALL(hipDeviceGetAttribute( - &num_sm, hipDeviceAttributeMultiprocessorCount, dev_id)); - FLASHINFER_CUDA_CALL(hipOccupancyMaxActiveBlocksPerMultiprocessor( - &num_blocks_per_sm, kernel, num_threads, smem_size)); - max_grid_size = num_blocks_per_sm * num_sm; - if (batch_size * gdy >= max_grid_size) { - split_kv = false; - max_num_pages_per_batch = 1; - for (uint32_t batch_idx = 0; batch_idx < batch_size; - ++batch_idx) { - max_num_pages_per_batch = std::max( - max_num_pages_per_batch, - kv_indptr_h[batch_idx + 1] - kv_indptr_h[batch_idx]); - } - new_batch_size = batch_size; - } - else { - // compute max_num_pages_per_batch and new_batch_size - std::vector num_pages(batch_size); - for (uint32_t batch_idx = 0; batch_idx < batch_size; - ++batch_idx) { - num_pages[batch_idx] = - kv_indptr_h[batch_idx + 1] - kv_indptr_h[batch_idx]; - } - std::tie(max_num_pages_per_batch, new_batch_size) = - PartitionPagedKVCacheBinarySearchMinNumPagePerBatch( - max_grid_size, gdy, num_pages, - std::max(128 / page_size, 1U)); - if (new_batch_size == batch_size && !enable_cuda_graph) { - // do not use partition-kv kernel for short sequence, when - // not using CUDAGraph - split_kv = false; - } - else { - // when using CUDAGraph, we always use partition-kv kernel - split_kv = true; - } - } - - return hipSuccess; - }); -} + bool& split_kv, uint32_t& max_grid_size, uint32_t& 
max_num_pages_per_batch, + uint32_t& new_batch_size, uint32_t& gdy, uint32_t batch_size, + typename Params::IdType* kv_indptr_h, const uint32_t num_qo_heads, const uint32_t page_size, + bool enable_cuda_graph, hipStream_t stream) { + using DTypeKV = typename Params::DTypeKV; + using IdType = typename Params::IdType; + + auto compute_capacity = GetCudaComputeCapability(); + DISPATCH_COMPUTE_CAP_DECODE_NUM_STAGES_SMEM(compute_capacity, NUM_STAGES_SMEM, { + constexpr uint32_t temp_1st = 16UL / sizeof(DTypeKV); + constexpr uint32_t temp_2nd = HEAD_DIM_CKV / 32UL; + constexpr uint32_t vec_size_ckv = temp_1st < temp_2nd ? temp_2nd : temp_1st; + constexpr uint32_t bdx = HEAD_DIM_CKV / vec_size_ckv; + constexpr uint32_t vec_size_kpe = HEAD_DIM_KPE / bdx; + + constexpr uint32_t bdy = 8; + constexpr uint32_t tile_size_qo_heads = 2; + constexpr uint32_t qo_heads_per_block = bdy * tile_size_qo_heads; + static_assert(qo_heads_per_block > 0); + constexpr uint32_t num_threads = 128U < bdx * bdy ? bdx * bdy : 128U; + constexpr uint32_t bdz = num_threads / (bdx * bdy); + gdy = ceil_div(num_qo_heads, qo_heads_per_block); + + const uint32_t smem_size = + NUM_STAGES_SMEM * bdy * bdz * (HEAD_DIM_CKV + HEAD_DIM_KPE) * sizeof(DTypeKV) + + std::max(num_threads * sizeof(size_t) * 2, 2 * bdy * bdz * sizeof(float)); -template -inline hipError_t -BatchDecodeWithPagedKVCacheWorkEstimationDispatchedMlaCuteSM80( - bool &split_kv, - uint32_t &max_grid_size, - uint32_t &max_num_pages_per_batch, - uint32_t &new_batch_size, - uint32_t &gdy_, - uint32_t batch_size, - typename Params::IdType *kv_indptr_h, - const uint32_t num_qo_heads, - const uint32_t page_size, - bool enable_cuda_graph, - hipStream_t stream) -{ - using DTypeKV = typename Params::DTypeKV; - using IdType = typename Params::IdType; - - auto [smem_size, gdy, k_warps] = LaunchSpecForDecodeKernelMlaCuteSM80< - HEAD_DIM_CKV, HEAD_DIM_KPE, QO_TILE_LEN, DTypeKV>(num_qo_heads); - gdy_ = gdy; - const uint32_t num_threads = k_warps * 32; auto 
kernel = - BatchDecodeWithPagedKVCacheKernelMlaCuteSM80; - int num_blocks_per_sm; + BatchDecodeWithPagedKVCacheKernelMLA; + int num_blocks_per_sm = 0; int num_sm = 0; int dev_id = 0; FLASHINFER_CUDA_CALL(hipGetDevice(&dev_id)); - FLASHINFER_CUDA_CALL(hipDeviceGetAttribute( - &num_sm, hipDeviceAttributeMultiprocessorCount, dev_id)); - - // FLASHINFER_CUDA_CALL(hipOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, - // kernel, - // num_threads, smem_size)); - // fixme: num_blocks_per_sm is 0 derived from - // hipOccupancyMaxActiveBlocksPerMultiprocessor at times, and we fill smem - // with q-heads as many as possible, so num_blocks_per_sm should be 1 - num_blocks_per_sm = 1; - + FLASHINFER_CUDA_CALL( + hipDeviceGetAttribute(&num_sm, hipDeviceAttributeMultiprocessorCount, dev_id)); + FLASHINFER_CUDA_CALL(hipOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel, + num_threads, smem_size)); max_grid_size = num_blocks_per_sm * num_sm; if (batch_size * gdy >= max_grid_size) { + split_kv = false; + max_num_pages_per_batch = 1; + for (uint32_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { + max_num_pages_per_batch = std::max( + max_num_pages_per_batch, kv_indptr_h[batch_idx + 1] - kv_indptr_h[batch_idx]); + } + new_batch_size = batch_size; + } else { + // compute max_num_pages_per_batch and new_batch_size + std::vector num_pages(batch_size); + for (uint32_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { + num_pages[batch_idx] = kv_indptr_h[batch_idx + 1] - kv_indptr_h[batch_idx]; + } + std::tie(max_num_pages_per_batch, new_batch_size) = + PartitionPagedKVCacheBinarySearchMinNumPagePerBatch(max_grid_size, gdy, num_pages, + std::max(128 / page_size, 1U)); + if (new_batch_size == batch_size && !enable_cuda_graph) { + // do not use partition-kv kernel for short sequence, when + // not using CUDAGraph split_kv = false; - max_num_pages_per_batch = 1; - for (uint32_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { - 
max_num_pages_per_batch = std::max( - max_num_pages_per_batch, - kv_indptr_h[batch_idx + 1] - kv_indptr_h[batch_idx]); - } - new_batch_size = batch_size; - } - else { - // compute max_num_pages_per_batch and new_batch_size - std::vector num_pages(batch_size); - for (uint32_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { - num_pages[batch_idx] = - kv_indptr_h[batch_idx + 1] - kv_indptr_h[batch_idx]; - } - std::tie(max_num_pages_per_batch, new_batch_size) = - PartitionPagedKVCacheBinarySearchMinNumPagePerBatch( - max_grid_size, gdy, num_pages, std::max(128 / page_size, 1U)); - if (new_batch_size == batch_size && !enable_cuda_graph) { - // do not use partition-kv kernel for short sequence, when not using - // CUDAGraph - split_kv = false; - } - else { - // when using CUDAGraph, we always use partition-kv kernel - split_kv = true; - } + } else { + // when using CUDAGraph, we always use partition-kv kernel + split_kv = true; + } } return hipSuccess; + }); +} + +template +inline hipError_t BatchDecodeWithPagedKVCacheWorkEstimationDispatchedMlaCuteSM80( + bool& split_kv, uint32_t& max_grid_size, uint32_t& max_num_pages_per_batch, + uint32_t& new_batch_size, uint32_t& gdy_, uint32_t batch_size, + typename Params::IdType* kv_indptr_h, const uint32_t num_qo_heads, const uint32_t page_size, + bool enable_cuda_graph, hipStream_t stream) { + using DTypeKV = typename Params::DTypeKV; + using IdType = typename Params::IdType; + + auto [smem_size, gdy, k_warps] = + LaunchSpecForDecodeKernelMlaCuteSM80( + num_qo_heads); + gdy_ = gdy; + const uint32_t num_threads = k_warps * 32; + auto kernel = + BatchDecodeWithPagedKVCacheKernelMlaCuteSM80; + int num_blocks_per_sm; + int num_sm = 0; + int dev_id = 0; + FLASHINFER_CUDA_CALL(hipGetDevice(&dev_id)); + FLASHINFER_CUDA_CALL( + hipDeviceGetAttribute(&num_sm, hipDeviceAttributeMultiprocessorCount, dev_id)); + + // FLASHINFER_CUDA_CALL(hipOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, + // kernel, + // num_threads, 
smem_size)); + // fixme: num_blocks_per_sm is 0 derived from + // hipOccupancyMaxActiveBlocksPerMultiprocessor at times, and we fill smem + // with q-heads as many as possible, so num_blocks_per_sm should be 1 + num_blocks_per_sm = 1; + + max_grid_size = num_blocks_per_sm * num_sm; + if (batch_size * gdy >= max_grid_size) { + split_kv = false; + max_num_pages_per_batch = 1; + for (uint32_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { + max_num_pages_per_batch = std::max( + max_num_pages_per_batch, kv_indptr_h[batch_idx + 1] - kv_indptr_h[batch_idx]); + } + new_batch_size = batch_size; + } else { + // compute max_num_pages_per_batch and new_batch_size + std::vector num_pages(batch_size); + for (uint32_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { + num_pages[batch_idx] = kv_indptr_h[batch_idx + 1] - kv_indptr_h[batch_idx]; + } + std::tie(max_num_pages_per_batch, new_batch_size) = + PartitionPagedKVCacheBinarySearchMinNumPagePerBatch(max_grid_size, gdy, num_pages, + std::max(128 / page_size, 1U)); + if (new_batch_size == batch_size && !enable_cuda_graph) { + // do not use partition-kv kernel for short sequence, when not using + // CUDAGraph + split_kv = false; + } else { + // when using CUDAGraph, we always use partition-kv kernel + split_kv = true; + } + } + + return hipSuccess; } /*! 
@@ -436,1178 +350,999 @@ BatchDecodeWithPagedKVCacheWorkEstimationDispatchedMlaCuteSM80( * \return status Indicates whether CUDA calls are successful */ template -inline auto DecodeSplitKVIndptr(IdType *indptr_h, - uint32_t batch_size, - uint32_t kv_chunk_size) -{ - std::vector request_indices, kv_tile_indices, o_indptr; - o_indptr.push_back(0); - - for (uint32_t batch_idx = 0; batch_idx < batch_size; batch_idx++) { - uint32_t num_tiles_kv = - ceil_div(std::max( - indptr_h[batch_idx + 1] - indptr_h[batch_idx], 1U), - kv_chunk_size); - for (uint32_t kv_tile_idx = 0; kv_tile_idx < num_tiles_kv; - ++kv_tile_idx) - { - request_indices.push_back(batch_idx); - kv_tile_indices.push_back(kv_tile_idx); - } - o_indptr.push_back(o_indptr.back() + num_tiles_kv); +inline auto DecodeSplitKVIndptr(IdType* indptr_h, uint32_t batch_size, uint32_t kv_chunk_size) { + std::vector request_indices, kv_tile_indices, o_indptr; + o_indptr.push_back(0); + + for (uint32_t batch_idx = 0; batch_idx < batch_size; batch_idx++) { + uint32_t num_tiles_kv = ceil_div( + std::max(indptr_h[batch_idx + 1] - indptr_h[batch_idx], 1U), kv_chunk_size); + for (uint32_t kv_tile_idx = 0; kv_tile_idx < num_tiles_kv; ++kv_tile_idx) { + request_indices.push_back(batch_idx); + kv_tile_indices.push_back(kv_tile_idx); } + o_indptr.push_back(o_indptr.back() + num_tiles_kv); + } - return std::make_tuple(request_indices, kv_tile_indices, o_indptr); + return std::make_tuple(request_indices, kv_tile_indices, o_indptr); } -struct DecodePlanInfo -{ - int64_t padded_batch_size; - int64_t v_offset; - int64_t s_offset; - int64_t request_indices_offset; - int64_t kv_tile_indices_offset; - int64_t o_indptr_offset; - int64_t block_valid_mask_offset; - int64_t kv_chunk_size_ptr_offset; - bool enable_cuda_graph; - bool split_kv; - - DecodePlanInfo() - : padded_batch_size(0), v_offset(0), s_offset(0), - request_indices_offset(0), kv_tile_indices_offset(0), - o_indptr_offset(0), block_valid_mask_offset(0), - 
kv_chunk_size_ptr_offset(0), enable_cuda_graph(false), split_kv(false) - { - } - - // convert DecodePlanInfo to std::vector - std::vector ToVector() const - { - return {padded_batch_size, - v_offset, - s_offset, - request_indices_offset, - kv_tile_indices_offset, - o_indptr_offset, - block_valid_mask_offset, - kv_chunk_size_ptr_offset, - enable_cuda_graph, - split_kv}; - } - - // From std::vector to DecodePlanInfo - void FromVector(const std::vector &vec) - { - if (vec.size() != 10) { - std::ostringstream err_msg; - err_msg << "DecodePlanInfo::FromVector: vec.size() should be 10, " - "but got " - << vec.size(); - FLASHINFER_ERROR(err_msg.str()); - } - padded_batch_size = vec[0]; - v_offset = vec[1]; - s_offset = vec[2]; - request_indices_offset = vec[3]; - kv_tile_indices_offset = vec[4]; - o_indptr_offset = vec[5]; - block_valid_mask_offset = vec[6]; - kv_chunk_size_ptr_offset = vec[7]; - enable_cuda_graph = vec[8]; - split_kv = vec[9]; +struct DecodePlanInfo { + int64_t padded_batch_size; + int64_t v_offset; + int64_t s_offset; + int64_t request_indices_offset; + int64_t kv_tile_indices_offset; + int64_t o_indptr_offset; + int64_t block_valid_mask_offset; + int64_t kv_chunk_size_ptr_offset; + bool enable_cuda_graph; + bool split_kv; + + DecodePlanInfo() + : padded_batch_size(0), + v_offset(0), + s_offset(0), + request_indices_offset(0), + kv_tile_indices_offset(0), + o_indptr_offset(0), + block_valid_mask_offset(0), + kv_chunk_size_ptr_offset(0), + enable_cuda_graph(false), + split_kv(false) {} + + // convert DecodePlanInfo to std::vector + std::vector ToVector() const { + return {padded_batch_size, + v_offset, + s_offset, + request_indices_offset, + kv_tile_indices_offset, + o_indptr_offset, + block_valid_mask_offset, + kv_chunk_size_ptr_offset, + enable_cuda_graph, + split_kv}; + } + + // From std::vector to DecodePlanInfo + void FromVector(const std::vector& vec) { + if (vec.size() != 10) { + std::ostringstream err_msg; + err_msg << 
"DecodePlanInfo::FromVector: vec.size() should be 10, " + "but got " + << vec.size(); + FLASHINFER_ERROR(err_msg.str()); } + padded_batch_size = vec[0]; + v_offset = vec[1]; + s_offset = vec[2]; + request_indices_offset = vec[3]; + kv_tile_indices_offset = vec[4]; + o_indptr_offset = vec[5]; + block_valid_mask_offset = vec[6]; + kv_chunk_size_ptr_offset = vec[7]; + enable_cuda_graph = vec[8]; + split_kv = vec[9]; + } }; -template -inline hipError_t DecodePlan(void *float_buffer, - size_t float_workspace_size_in_bytes, - void *int_buffer, - void *page_locked_int_buffer, - size_t int_workspace_size_in_bytes, - DecodePlanInfo &plan_info, - typename Params::IdType *indptr_h, - uint32_t batch_size, - uint32_t num_qo_heads, - uint32_t page_size, - bool enable_cuda_graph, - hipStream_t stream, - WorkEstimationFunc work_estimation_func) -{ - using DTypeO = typename Params::DTypeO; - using IdType = typename Params::IdType; - bool split_kv; - uint32_t max_grid_size, kv_chunk_size_in_pages, new_batch_size, gdy; - - FLASHINFER_CUDA_CALL(work_estimation_func( - split_kv, max_grid_size, kv_chunk_size_in_pages, new_batch_size, gdy, - batch_size, indptr_h, num_qo_heads, page_size, enable_cuda_graph, - stream)); - size_t padded_batch_size; - plan_info.enable_cuda_graph = enable_cuda_graph; - plan_info.split_kv = split_kv; - padded_batch_size = (enable_cuda_graph) - ? (split_kv ? 
max_grid_size / gdy : batch_size) - : new_batch_size; - plan_info.padded_batch_size = padded_batch_size; - - auto [request_indices_vec, kv_tile_indices_vec, o_indptr_vec] = - DecodeSplitKVIndptr(indptr_h, batch_size, kv_chunk_size_in_pages); - - AlignedAllocator int_allocator(int_buffer, int_workspace_size_in_bytes); - plan_info.request_indices_offset = int_allocator.aligned_alloc_offset( - padded_batch_size * sizeof(IdType), 16, "batch_decode_request_indices"); - plan_info.kv_tile_indices_offset = int_allocator.aligned_alloc_offset( - padded_batch_size * sizeof(IdType), 16, "batch_decode_kv_tile_indices"); - plan_info.o_indptr_offset = int_allocator.aligned_alloc_offset( - (padded_batch_size + 1) * sizeof(IdType), 16, "batch_decode_o_indptr"); - plan_info.kv_chunk_size_ptr_offset = int_allocator.aligned_alloc_offset( - sizeof(IdType), 1, "batch_decode_kv_chunk_size_ptr"); - IdType *request_indices_h = GetPtrFromBaseOffset( - page_locked_int_buffer, plan_info.request_indices_offset); - IdType *kv_tile_indices_h = GetPtrFromBaseOffset( - page_locked_int_buffer, plan_info.kv_tile_indices_offset); - IdType *o_indptr_h = GetPtrFromBaseOffset( - page_locked_int_buffer, plan_info.o_indptr_offset); - IdType *kv_chunk_size_ptr_h = GetPtrFromBaseOffset( - page_locked_int_buffer, plan_info.kv_chunk_size_ptr_offset); - std::copy(request_indices_vec.begin(), request_indices_vec.end(), - request_indices_h); - std::copy(kv_tile_indices_vec.begin(), kv_tile_indices_vec.end(), - kv_tile_indices_h); - std::copy(o_indptr_vec.begin(), o_indptr_vec.end(), o_indptr_h); - kv_chunk_size_ptr_h[0] = kv_chunk_size_in_pages * page_size; - - if (split_kv) { - AlignedAllocator float_allocator(float_buffer, - float_workspace_size_in_bytes); - plan_info.v_offset = float_allocator.aligned_alloc_offset( - num_qo_heads * padded_batch_size * HEAD_DIM * sizeof(float), 16, - "batch_decode_tmp_v"); - plan_info.s_offset = float_allocator.aligned_alloc_offset( - num_qo_heads * padded_batch_size * 
sizeof(float), 16, - "batch_decode_tmp_s"); - - plan_info.block_valid_mask_offset = int_allocator.aligned_alloc_offset( - padded_batch_size * sizeof(bool), 16, - "batch_decode_block_valid_mask"); - bool *block_valid_mask_h = GetPtrFromBaseOffset( - page_locked_int_buffer, plan_info.block_valid_mask_offset); - for (uint32_t i = 0; i < padded_batch_size; ++i) { - block_valid_mask_h[i] = i < new_batch_size; - } +template +inline hipError_t DecodePlan(void* float_buffer, size_t float_workspace_size_in_bytes, + void* int_buffer, void* page_locked_int_buffer, + size_t int_workspace_size_in_bytes, DecodePlanInfo& plan_info, + typename Params::IdType* indptr_h, uint32_t batch_size, + uint32_t num_qo_heads, uint32_t page_size, bool enable_cuda_graph, + hipStream_t stream, WorkEstimationFunc work_estimation_func) { + using DTypeO = typename Params::DTypeO; + using IdType = typename Params::IdType; + bool split_kv; + uint32_t max_grid_size, kv_chunk_size_in_pages, new_batch_size, gdy; + + FLASHINFER_CUDA_CALL(work_estimation_func(split_kv, max_grid_size, kv_chunk_size_in_pages, + new_batch_size, gdy, batch_size, indptr_h, num_qo_heads, + page_size, enable_cuda_graph, stream)); + size_t padded_batch_size; + plan_info.enable_cuda_graph = enable_cuda_graph; + plan_info.split_kv = split_kv; + padded_batch_size = + (enable_cuda_graph) ? (split_kv ? 
max_grid_size / gdy : batch_size) : new_batch_size; + plan_info.padded_batch_size = padded_batch_size; + + auto [request_indices_vec, kv_tile_indices_vec, o_indptr_vec] = + DecodeSplitKVIndptr(indptr_h, batch_size, kv_chunk_size_in_pages); + + AlignedAllocator int_allocator(int_buffer, int_workspace_size_in_bytes); + plan_info.request_indices_offset = int_allocator.aligned_alloc_offset( + padded_batch_size * sizeof(IdType), 16, "batch_decode_request_indices"); + plan_info.kv_tile_indices_offset = int_allocator.aligned_alloc_offset( + padded_batch_size * sizeof(IdType), 16, "batch_decode_kv_tile_indices"); + plan_info.o_indptr_offset = int_allocator.aligned_alloc_offset( + (padded_batch_size + 1) * sizeof(IdType), 16, "batch_decode_o_indptr"); + plan_info.kv_chunk_size_ptr_offset = + int_allocator.aligned_alloc_offset(sizeof(IdType), 1, "batch_decode_kv_chunk_size_ptr"); + IdType* request_indices_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.request_indices_offset); + IdType* kv_tile_indices_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.kv_tile_indices_offset); + IdType* o_indptr_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.o_indptr_offset); + IdType* kv_chunk_size_ptr_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.kv_chunk_size_ptr_offset); + std::copy(request_indices_vec.begin(), request_indices_vec.end(), request_indices_h); + std::copy(kv_tile_indices_vec.begin(), kv_tile_indices_vec.end(), kv_tile_indices_h); + std::copy(o_indptr_vec.begin(), o_indptr_vec.end(), o_indptr_h); + kv_chunk_size_ptr_h[0] = kv_chunk_size_in_pages * page_size; + + if (split_kv) { + AlignedAllocator float_allocator(float_buffer, float_workspace_size_in_bytes); + plan_info.v_offset = float_allocator.aligned_alloc_offset( + num_qo_heads * padded_batch_size * HEAD_DIM * sizeof(float), 16, "batch_decode_tmp_v"); + plan_info.s_offset = float_allocator.aligned_alloc_offset( + num_qo_heads * padded_batch_size * sizeof(float), 16, 
"batch_decode_tmp_s"); + + plan_info.block_valid_mask_offset = int_allocator.aligned_alloc_offset( + padded_batch_size * sizeof(bool), 16, "batch_decode_block_valid_mask"); + bool* block_valid_mask_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.block_valid_mask_offset); + for (uint32_t i = 0; i < padded_batch_size; ++i) { + block_valid_mask_h[i] = i < new_batch_size; } + } - size_t num_bytes_to_copy = int_allocator.num_allocated_bytes(); + size_t num_bytes_to_copy = int_allocator.num_allocated_bytes(); - FLASHINFER_CUDA_CALL(hipMemcpyAsync(int_buffer, page_locked_int_buffer, - num_bytes_to_copy, - hipMemcpyHostToDevice, stream)); - return hipSuccess; + FLASHINFER_CUDA_CALL(hipMemcpyAsync(int_buffer, page_locked_int_buffer, num_bytes_to_copy, + hipMemcpyHostToDevice, stream)); + return hipSuccess; } template -inline auto PrefillSplitQOKVIndptr(IdType *qo_indptr_h, - IdType *kv_indptr_h, - uint32_t total_num_rows, - uint32_t batch_size, - uint32_t num_qo_heads, - uint32_t num_kv_heads, - uint32_t head_dim, - uint32_t page_size, - uint32_t max_batch_size_if_split, - bool enable_cuda_graph) -{ - std::vector request_indices, qo_tile_indices, kv_tile_indices, - merge_indptr, o_indptr; - merge_indptr.push_back(0); - o_indptr.push_back(0); - - const uint32_t gqa_group_size = num_qo_heads / num_kv_heads; - - // step 1: determine packed_qo_len_arr and verify qo_indptr contents. 
- std::vector packed_qo_len_arr(batch_size), kv_len_arr(batch_size); +inline auto PrefillSplitQOKVIndptr(IdType* qo_indptr_h, IdType* kv_indptr_h, + uint32_t total_num_rows, uint32_t batch_size, + uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim, + uint32_t page_size, uint32_t max_batch_size_if_split, + bool enable_cuda_graph) { + std::vector request_indices, qo_tile_indices, kv_tile_indices, merge_indptr, o_indptr; + merge_indptr.push_back(0); + o_indptr.push_back(0); + + const uint32_t gqa_group_size = num_qo_heads / num_kv_heads; + + // step 1: determine packed_qo_len_arr and verify qo_indptr contents. + std::vector packed_qo_len_arr(batch_size), kv_len_arr(batch_size); + for (uint32_t i = 0; i < batch_size; ++i) { + packed_qo_len_arr[i] = int64_t(qo_indptr_h[i + 1] - qo_indptr_h[i]) * int64_t(gqa_group_size); + if (packed_qo_len_arr[i] < 0) { + std::ostringstream err_msg; + err_msg << "qo_indptr[" << i + 1 << "]" << qo_indptr_h[i + 1] << " - qo_indptr[" << i << "]" + << qo_indptr_h[i] << " should be non-negative"; + FLASHINFER_ERROR(err_msg.str()); + } + kv_len_arr[i] = int64_t(kv_indptr_h[i + 1] - kv_indptr_h[i]); + if (kv_len_arr[i] < 0) { + std::ostringstream err_msg; + err_msg << "kv_indptr[" << i + 1 << "]" << kv_indptr_h[i + 1] << " - kv_indptr[" << i << "]" + << kv_indptr_h[i] << " should be non-negative"; + FLASHINFER_ERROR(err_msg.str()); + } + } + + // step 2: determine cta_tile_q, kv_chunk_size and total_num_tiles_q + const uint32_t min_kv_chunk_size = std::max((128 / page_size), 1U); + uint32_t cta_tile_q; + uint32_t total_num_tiles_q; + if (enable_cuda_graph) { + // When CUDA graphs are enabled, the lengths of sequences determined by + // qo_indptr_h can vary. We assume that the dummy data based on which + // the CUDA graph is created fixes the maximum number of tokens. 
+ const uint64_t max_seq_len = total_num_rows - batch_size + 1; + uint64_t max_qo_len = uint64_t(max_seq_len) * gqa_group_size; + cta_tile_q = FA2DetermineCtaTileQ(max_qo_len, head_dim); + + // Find an upper bound for the number of tiles, derived from the total + // number of rows and the batch size. The sum of qo lengths rounded + // up to cta_tile_q will not exceed this number derived from the total + // number of rows. + total_num_tiles_q = ceil_div(total_num_rows * gqa_group_size, cta_tile_q) + batch_size - 1; + } else { + int64_t sum_packed_qo_len = 0; for (uint32_t i = 0; i < batch_size; ++i) { - packed_qo_len_arr[i] = int64_t(qo_indptr_h[i + 1] - qo_indptr_h[i]) * - int64_t(gqa_group_size); - if (packed_qo_len_arr[i] < 0) { - std::ostringstream err_msg; - err_msg << "qo_indptr[" << i + 1 << "]" << qo_indptr_h[i + 1] - << " - qo_indptr[" << i << "]" << qo_indptr_h[i] - << " should be non-negative"; - FLASHINFER_ERROR(err_msg.str()); - } - kv_len_arr[i] = int64_t(kv_indptr_h[i + 1] - kv_indptr_h[i]); - if (kv_len_arr[i] < 0) { - std::ostringstream err_msg; - err_msg << "kv_indptr[" << i + 1 << "]" << kv_indptr_h[i + 1] - << " - kv_indptr[" << i << "]" << kv_indptr_h[i] - << " should be non-negative"; - FLASHINFER_ERROR(err_msg.str()); - } + sum_packed_qo_len += packed_qo_len_arr[i]; } + const int64_t avg_packed_qo_len = sum_packed_qo_len / batch_size; + cta_tile_q = FA2DetermineCtaTileQ(avg_packed_qo_len, head_dim); - // step 2: determine cta_tile_q, kv_chunk_size and total_num_tiles_q - const uint32_t min_kv_chunk_size = std::max((128 / page_size), 1U); - uint32_t cta_tile_q; - uint32_t total_num_tiles_q; - if (enable_cuda_graph) { - // When CUDA graphs are enabled, the lengths of sequences determined by - // qo_indptr_h can vary. We assume that the dummy data based on which - // the CUDA graph is created fixes the maximum number of tokens. 
- const uint64_t max_seq_len = total_num_rows - batch_size + 1; - uint64_t max_qo_len = uint64_t(max_seq_len) * gqa_group_size; - cta_tile_q = FA2DetermineCtaTileQ(max_qo_len, head_dim); - - // Find an upper bound for the number of tiles, derived from the total - // number of rows and the batch size. The sum of qo lengths rounded - // up to cta_tile_q will not exceed this number derived from the total - // number of rows. - total_num_tiles_q = - ceil_div(total_num_rows * gqa_group_size, cta_tile_q) + batch_size - - 1; + total_num_tiles_q = 0; + for (uint32_t i = 0; i < batch_size; ++i) { + total_num_tiles_q += ceil_div(packed_qo_len_arr[i], cta_tile_q); } - else { - int64_t sum_packed_qo_len = 0; - for (uint32_t i = 0; i < batch_size; ++i) { - sum_packed_qo_len += packed_qo_len_arr[i]; - } - const int64_t avg_packed_qo_len = sum_packed_qo_len / batch_size; - cta_tile_q = FA2DetermineCtaTileQ(avg_packed_qo_len, head_dim); - - total_num_tiles_q = 0; - for (uint32_t i = 0; i < batch_size; ++i) { - total_num_tiles_q += ceil_div(packed_qo_len_arr[i], cta_tile_q); - } + } + + auto [split_kv, kv_chunk_size] = + PrefillBinarySearchKVChunkSize(enable_cuda_graph, max_batch_size_if_split, packed_qo_len_arr, + kv_len_arr, cta_tile_q, min_kv_chunk_size); + + // step 3: split qo_indptr and kv_indptr + uint32_t new_batch_size = 0; + for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) { + const int64_t packed_qo_len = packed_qo_len_arr[request_idx]; + const int64_t kv_len = std::max(int(kv_len_arr[request_idx]), 1); + const int64_t num_tiles_q = ceil_div(packed_qo_len, cta_tile_q); + const int64_t num_tiles_kv = ceil_div(kv_len, kv_chunk_size); + + for (uint32_t q_tile_idx = 0; q_tile_idx < num_tiles_q; ++q_tile_idx) { + for (uint32_t kv_tile_idx = 0; kv_tile_idx < num_tiles_kv; ++kv_tile_idx) { + new_batch_size += 1; + request_indices.push_back(request_idx); + qo_tile_indices.push_back(q_tile_idx); + kv_tile_indices.push_back(kv_tile_idx); + } } - auto 
[split_kv, kv_chunk_size] = PrefillBinarySearchKVChunkSize( - enable_cuda_graph, max_batch_size_if_split, packed_qo_len_arr, - kv_len_arr, cta_tile_q, min_kv_chunk_size); - - // step 3: split qo_indptr and kv_indptr - uint32_t new_batch_size = 0; - for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) { - const int64_t packed_qo_len = packed_qo_len_arr[request_idx]; - const int64_t kv_len = std::max(int(kv_len_arr[request_idx]), 1); - const int64_t num_tiles_q = ceil_div(packed_qo_len, cta_tile_q); - const int64_t num_tiles_kv = ceil_div(kv_len, kv_chunk_size); - - for (uint32_t q_tile_idx = 0; q_tile_idx < num_tiles_q; ++q_tile_idx) { - for (uint32_t kv_tile_idx = 0; kv_tile_idx < num_tiles_kv; - ++kv_tile_idx) - { - new_batch_size += 1; - request_indices.push_back(request_idx); - qo_tile_indices.push_back(q_tile_idx); - kv_tile_indices.push_back(kv_tile_idx); - } - } - - int64_t qo_len = packed_qo_len / gqa_group_size; - for (uint32_t row = 0; row < qo_len; ++row) { - merge_indptr.push_back(merge_indptr.back() + num_tiles_kv); - } - o_indptr.push_back(o_indptr.back() + qo_len * num_tiles_kv); + int64_t qo_len = packed_qo_len / gqa_group_size; + for (uint32_t row = 0; row < qo_len; ++row) { + merge_indptr.push_back(merge_indptr.back() + num_tiles_kv); } + o_indptr.push_back(o_indptr.back() + qo_len * num_tiles_kv); + } - const size_t padded_batch_size = - enable_cuda_graph ? std::max(max_batch_size_if_split, total_num_tiles_q) - : new_batch_size; - FLASHINFER_CHECK(new_batch_size <= padded_batch_size, - "new batch size should not exceed padded batch size"); + const size_t padded_batch_size = + enable_cuda_graph ? 
std::max(max_batch_size_if_split, total_num_tiles_q) : new_batch_size; + FLASHINFER_CHECK(new_batch_size <= padded_batch_size, + "new batch size should not exceed padded batch size"); - // step 4: multiply kv_chunk_size by page_size - kv_chunk_size *= page_size; + // step 4: multiply kv_chunk_size by page_size + kv_chunk_size *= page_size; - return std::make_tuple( - split_kv, new_batch_size, padded_batch_size, cta_tile_q, kv_chunk_size, - std::move(request_indices), std::move(qo_tile_indices), - std::move(kv_tile_indices), std::move(merge_indptr), - std::move(o_indptr)); + return std::make_tuple(split_kv, new_batch_size, padded_batch_size, cta_tile_q, kv_chunk_size, + std::move(request_indices), std::move(qo_tile_indices), + std::move(kv_tile_indices), std::move(merge_indptr), std::move(o_indptr)); } -struct PrefillPlanInfo -{ - int64_t padded_batch_size; - int64_t total_num_rows; - int64_t total_num_rows_offset; - int64_t cta_tile_q; - int64_t request_indices_offset; - int64_t qo_tile_indices_offset; - int64_t kv_tile_indices_offset; - int64_t merge_indptr_offset; - int64_t o_indptr_offset; - int64_t kv_chunk_size_ptr_offset; - int64_t v_offset; - int64_t s_offset; - int64_t block_valid_mask_offset; - bool enable_cuda_graph; - bool split_kv; - - PrefillPlanInfo() - : padded_batch_size(0), total_num_rows(0), total_num_rows_offset(0), - cta_tile_q(0), request_indices_offset(0), qo_tile_indices_offset(0), - kv_tile_indices_offset(0), merge_indptr_offset(0), o_indptr_offset(0), - kv_chunk_size_ptr_offset(0), v_offset(0), s_offset(0), - block_valid_mask_offset(0), enable_cuda_graph(false), split_kv(false) - { - } - - // convert PrefillPlanInfo to std::vector - std::vector ToVector() const - { - return {padded_batch_size, - total_num_rows, - total_num_rows_offset, - cta_tile_q, - request_indices_offset, - qo_tile_indices_offset, - kv_tile_indices_offset, - merge_indptr_offset, - o_indptr_offset, - kv_chunk_size_ptr_offset, - v_offset, - s_offset, - 
block_valid_mask_offset, - enable_cuda_graph, - split_kv}; - } - - // From std::vector to PrefillPlanInfo - void FromVector(const std::vector &vec) - { - if (vec.size() != 15) { - std::ostringstream err_msg; - err_msg << "PrefillPlanInfo::FromVector: vec.size() should be 15, " - "but got " - << vec.size(); - FLASHINFER_ERROR(err_msg.str()); - } - padded_batch_size = vec[0]; - total_num_rows = vec[1]; - total_num_rows_offset = vec[2]; - cta_tile_q = vec[3]; - request_indices_offset = vec[4]; - qo_tile_indices_offset = vec[5]; - kv_tile_indices_offset = vec[6]; - merge_indptr_offset = vec[7]; - o_indptr_offset = vec[8]; - kv_chunk_size_ptr_offset = vec[9]; - v_offset = vec[10]; - s_offset = vec[11]; - block_valid_mask_offset = vec[12]; - enable_cuda_graph = vec[13]; - split_kv = vec[14]; +struct PrefillPlanInfo { + int64_t padded_batch_size; + int64_t total_num_rows; + int64_t total_num_rows_offset; + int64_t cta_tile_q; + int64_t request_indices_offset; + int64_t qo_tile_indices_offset; + int64_t kv_tile_indices_offset; + int64_t merge_indptr_offset; + int64_t o_indptr_offset; + int64_t kv_chunk_size_ptr_offset; + int64_t v_offset; + int64_t s_offset; + int64_t block_valid_mask_offset; + bool enable_cuda_graph; + bool split_kv; + + PrefillPlanInfo() + : padded_batch_size(0), + total_num_rows(0), + total_num_rows_offset(0), + cta_tile_q(0), + request_indices_offset(0), + qo_tile_indices_offset(0), + kv_tile_indices_offset(0), + merge_indptr_offset(0), + o_indptr_offset(0), + kv_chunk_size_ptr_offset(0), + v_offset(0), + s_offset(0), + block_valid_mask_offset(0), + enable_cuda_graph(false), + split_kv(false) {} + + // convert PrefillPlanInfo to std::vector + std::vector ToVector() const { + return {padded_batch_size, + total_num_rows, + total_num_rows_offset, + cta_tile_q, + request_indices_offset, + qo_tile_indices_offset, + kv_tile_indices_offset, + merge_indptr_offset, + o_indptr_offset, + kv_chunk_size_ptr_offset, + v_offset, + s_offset, + block_valid_mask_offset, 
+ enable_cuda_graph, + split_kv}; + } + + // From std::vector to PrefillPlanInfo + void FromVector(const std::vector& vec) { + if (vec.size() != 15) { + std::ostringstream err_msg; + err_msg << "PrefillPlanInfo::FromVector: vec.size() should be 15, " + "but got " + << vec.size(); + FLASHINFER_ERROR(err_msg.str()); } + padded_batch_size = vec[0]; + total_num_rows = vec[1]; + total_num_rows_offset = vec[2]; + cta_tile_q = vec[3]; + request_indices_offset = vec[4]; + qo_tile_indices_offset = vec[5]; + kv_tile_indices_offset = vec[6]; + merge_indptr_offset = vec[7]; + o_indptr_offset = vec[8]; + kv_chunk_size_ptr_offset = vec[9]; + v_offset = vec[10]; + s_offset = vec[11]; + block_valid_mask_offset = vec[12]; + enable_cuda_graph = vec[13]; + split_kv = vec[14]; + } }; template -inline hipError_t PrefillPlan(void *float_buffer, - size_t float_workspace_size_in_bytes, - void *int_buffer, - void *page_locked_int_buffer, - size_t int_workspace_size_in_bytes, - PrefillPlanInfo &plan_info, - IdType *qo_indptr_h, - IdType *kv_indptr_h, - uint32_t total_num_rows, - uint32_t batch_size, - uint32_t num_qo_heads, - uint32_t num_kv_heads, - uint32_t head_dim_qk, - uint32_t head_dim_vo, - uint32_t page_size, - bool enable_cuda_graph, - uint32_t sizeof_dtype_o, - hipStream_t stream) -{ - if (num_qo_heads % num_kv_heads != 0) { - std::ostringstream err_msg; - err_msg << "num_qo_heads " << num_qo_heads - << " should be divisible by num_kv_heads " << num_kv_heads; - FLASHINFER_ERROR(err_msg.str()); +inline hipError_t PrefillPlan(void* float_buffer, size_t float_workspace_size_in_bytes, + void* int_buffer, void* page_locked_int_buffer, + size_t int_workspace_size_in_bytes, PrefillPlanInfo& plan_info, + IdType* qo_indptr_h, IdType* kv_indptr_h, uint32_t total_num_rows, + uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, + uint32_t head_dim_qk, uint32_t head_dim_vo, uint32_t page_size, + bool enable_cuda_graph, uint32_t sizeof_dtype_o, hipStream_t stream) { + if 
(num_qo_heads % num_kv_heads != 0) { + std::ostringstream err_msg; + err_msg << "num_qo_heads " << num_qo_heads << " should be divisible by num_kv_heads " + << num_kv_heads; + FLASHINFER_ERROR(err_msg.str()); + } + + // step 0: get the number of SMs + int num_sm = 0; + int dev_id = 0; + FLASHINFER_CUDA_CALL(hipGetDevice(&dev_id)); + FLASHINFER_CUDA_CALL( + hipDeviceGetAttribute(&num_sm, hipDeviceAttributeMultiprocessorCount, dev_id)); + int num_blocks_per_sm = 2; + int max_grid_size = num_blocks_per_sm * num_sm; + uint32_t max_batch_size_if_split = max_grid_size / num_kv_heads; + + // step 2: determine kv_chunk_size + auto [split_kv, new_batch_size, padded_batch_size, cta_tile_q, kv_chunk_size, request_indices_vec, + qo_tile_indices_vec, kv_tile_indices_vec, merge_indptr_vec, o_indptr_vec] = + PrefillSplitQOKVIndptr(qo_indptr_h, kv_indptr_h, total_num_rows, batch_size, num_qo_heads, + num_kv_heads, head_dim_vo, page_size, max_batch_size_if_split, + enable_cuda_graph); + + plan_info.cta_tile_q = cta_tile_q; + plan_info.total_num_rows = total_num_rows; + plan_info.enable_cuda_graph = enable_cuda_graph; + plan_info.padded_batch_size = padded_batch_size; + plan_info.split_kv = split_kv; + + AlignedAllocator int_allocator(int_buffer, int_workspace_size_in_bytes); + plan_info.request_indices_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * padded_batch_size, 16, "batch_prefill_request_indices"); + plan_info.qo_tile_indices_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * padded_batch_size, 16, "batch_prefill_qo_tile_indices"); + plan_info.kv_tile_indices_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * padded_batch_size, 16, "batch_prefill_kv_tile_indices"); + plan_info.o_indptr_offset = int_allocator.aligned_alloc_offset(sizeof(IdType) * (batch_size + 1), + 16, "batch_prefill_o_indptr"); + plan_info.kv_chunk_size_ptr_offset = + int_allocator.aligned_alloc_offset(sizeof(IdType), 1, "batch_prefill_kv_chunk_size_ptr"); + + if 
(plan_info.enable_cuda_graph) { + plan_info.total_num_rows_offset = + int_allocator.aligned_alloc_offset(sizeof(uint32_t), 16, "batch_prefill_total_num_rows"); + uint32_t* total_num_rows_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.total_num_rows_offset); + *total_num_rows_h = qo_indptr_h[batch_size]; + } + + IdType* request_indices_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.request_indices_offset); + IdType* qo_tile_indices_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.qo_tile_indices_offset); + IdType* kv_tile_indices_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.kv_tile_indices_offset); + IdType* o_indptr_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.o_indptr_offset); + IdType* kv_chunk_size_ptr_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.kv_chunk_size_ptr_offset); + std::copy(request_indices_vec.begin(), request_indices_vec.end(), request_indices_h); + std::copy(qo_tile_indices_vec.begin(), qo_tile_indices_vec.end(), qo_tile_indices_h); + std::copy(kv_tile_indices_vec.begin(), kv_tile_indices_vec.end(), kv_tile_indices_h); + std::copy(o_indptr_vec.begin(), o_indptr_vec.end(), o_indptr_h); + kv_chunk_size_ptr_h[0] = kv_chunk_size; + + if (split_kv) { + AlignedAllocator float_allocator(float_buffer, float_workspace_size_in_bytes); + plan_info.v_offset = float_allocator.aligned_alloc_offset( + num_qo_heads * padded_batch_size * cta_tile_q * head_dim_vo * sizeof(float), 16, + "batch_prefill_tmp_v"); + plan_info.s_offset = float_allocator.aligned_alloc_offset( + num_qo_heads * padded_batch_size * cta_tile_q * sizeof(float), 16, "batch_prefill_tmp_s"); + plan_info.merge_indptr_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * (plan_info.total_num_rows + 1), 16, "batch_prefill_merge_indptr"); + plan_info.block_valid_mask_offset = int_allocator.aligned_alloc_offset( + sizeof(bool) * padded_batch_size, 16, "batch_prefill_block_valid_mask"); + + IdType* 
merge_indptr_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.merge_indptr_offset); + bool* block_valid_mask_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.block_valid_mask_offset); + std::copy(merge_indptr_vec.begin(), merge_indptr_vec.end(), merge_indptr_h); + for (uint32_t i = 0; i < padded_batch_size; ++i) { + block_valid_mask_h[i] = i < new_batch_size; } + } - // step 0: get the number of SMs - int num_sm = 0; - int dev_id = 0; - FLASHINFER_CUDA_CALL(hipGetDevice(&dev_id)); - FLASHINFER_CUDA_CALL(hipDeviceGetAttribute( - &num_sm, hipDeviceAttributeMultiprocessorCount, dev_id)); - int num_blocks_per_sm = 2; - int max_grid_size = num_blocks_per_sm * num_sm; - uint32_t max_batch_size_if_split = max_grid_size / num_kv_heads; - - // step 2: determine kv_chunk_size - auto [split_kv, new_batch_size, padded_batch_size, cta_tile_q, - kv_chunk_size, request_indices_vec, qo_tile_indices_vec, - kv_tile_indices_vec, merge_indptr_vec, o_indptr_vec] = - PrefillSplitQOKVIndptr(qo_indptr_h, kv_indptr_h, total_num_rows, - batch_size, num_qo_heads, num_kv_heads, - head_dim_vo, page_size, max_batch_size_if_split, - enable_cuda_graph); - - plan_info.cta_tile_q = cta_tile_q; - plan_info.total_num_rows = total_num_rows; - plan_info.enable_cuda_graph = enable_cuda_graph; - plan_info.padded_batch_size = padded_batch_size; - plan_info.split_kv = split_kv; - - AlignedAllocator int_allocator(int_buffer, int_workspace_size_in_bytes); - plan_info.request_indices_offset = - int_allocator.aligned_alloc_offset(sizeof(IdType) * padded_batch_size, - 16, "batch_prefill_request_indices"); - plan_info.qo_tile_indices_offset = - int_allocator.aligned_alloc_offset(sizeof(IdType) * padded_batch_size, - 16, "batch_prefill_qo_tile_indices"); - plan_info.kv_tile_indices_offset = - int_allocator.aligned_alloc_offset(sizeof(IdType) * padded_batch_size, - 16, "batch_prefill_kv_tile_indices"); - plan_info.o_indptr_offset = int_allocator.aligned_alloc_offset( - sizeof(IdType) * 
(batch_size + 1), 16, "batch_prefill_o_indptr"); - plan_info.kv_chunk_size_ptr_offset = int_allocator.aligned_alloc_offset( - sizeof(IdType), 1, "batch_prefill_kv_chunk_size_ptr"); - - if (plan_info.enable_cuda_graph) { - plan_info.total_num_rows_offset = int_allocator.aligned_alloc_offset( - sizeof(uint32_t), 16, "batch_prefill_total_num_rows"); - uint32_t *total_num_rows_h = GetPtrFromBaseOffset( - page_locked_int_buffer, plan_info.total_num_rows_offset); - *total_num_rows_h = qo_indptr_h[batch_size]; - } + size_t num_bytes_to_copy = int_allocator.num_allocated_bytes(); + FLASHINFER_CUDA_CALL(hipMemcpyAsync(int_buffer, page_locked_int_buffer, num_bytes_to_copy, + hipMemcpyHostToDevice, stream)); - IdType *request_indices_h = GetPtrFromBaseOffset( - page_locked_int_buffer, plan_info.request_indices_offset); - IdType *qo_tile_indices_h = GetPtrFromBaseOffset( - page_locked_int_buffer, plan_info.qo_tile_indices_offset); - IdType *kv_tile_indices_h = GetPtrFromBaseOffset( - page_locked_int_buffer, plan_info.kv_tile_indices_offset); - IdType *o_indptr_h = GetPtrFromBaseOffset( - page_locked_int_buffer, plan_info.o_indptr_offset); - IdType *kv_chunk_size_ptr_h = GetPtrFromBaseOffset( - page_locked_int_buffer, plan_info.kv_chunk_size_ptr_offset); - std::copy(request_indices_vec.begin(), request_indices_vec.end(), - request_indices_h); - std::copy(qo_tile_indices_vec.begin(), qo_tile_indices_vec.end(), - qo_tile_indices_h); - std::copy(kv_tile_indices_vec.begin(), kv_tile_indices_vec.end(), - kv_tile_indices_h); - std::copy(o_indptr_vec.begin(), o_indptr_vec.end(), o_indptr_h); - kv_chunk_size_ptr_h[0] = kv_chunk_size; - - if (split_kv) { - AlignedAllocator float_allocator(float_buffer, - float_workspace_size_in_bytes); - plan_info.v_offset = float_allocator.aligned_alloc_offset( - num_qo_heads * padded_batch_size * cta_tile_q * head_dim_vo * - sizeof(float), - 16, "batch_prefill_tmp_v"); - plan_info.s_offset = float_allocator.aligned_alloc_offset( - num_qo_heads * 
padded_batch_size * cta_tile_q * sizeof(float), 16, - "batch_prefill_tmp_s"); - plan_info.merge_indptr_offset = int_allocator.aligned_alloc_offset( - sizeof(IdType) * (plan_info.total_num_rows + 1), 16, - "batch_prefill_merge_indptr"); - plan_info.block_valid_mask_offset = int_allocator.aligned_alloc_offset( - sizeof(bool) * padded_batch_size, 16, - "batch_prefill_block_valid_mask"); - - IdType *merge_indptr_h = GetPtrFromBaseOffset( - page_locked_int_buffer, plan_info.merge_indptr_offset); - bool *block_valid_mask_h = GetPtrFromBaseOffset( - page_locked_int_buffer, plan_info.block_valid_mask_offset); - std::copy(merge_indptr_vec.begin(), merge_indptr_vec.end(), - merge_indptr_h); - for (uint32_t i = 0; i < padded_batch_size; ++i) { - block_valid_mask_h[i] = i < new_batch_size; - } - } - - size_t num_bytes_to_copy = int_allocator.num_allocated_bytes(); - FLASHINFER_CUDA_CALL(hipMemcpyAsync(int_buffer, page_locked_int_buffer, - num_bytes_to_copy, - hipMemcpyHostToDevice, stream)); - - return hipSuccess; + return hipSuccess; } -inline float cost_function(int qo_len, int kv_len) -{ - return 2 * float(qo_len) + kv_len; -} +inline float cost_function(int qo_len, int kv_len) { return 2 * float(qo_len) + kv_len; } template -std::vector flatten(const std::vector> &vec, - int size_after_flatten) -{ - std::vector result; - result.reserve(size_after_flatten); - for (const auto &inner_vec : vec) { - result.insert(result.end(), inner_vec.begin(), inner_vec.end()); - } - return result; +std::vector flatten(const std::vector>& vec, int size_after_flatten) { + std::vector result; + result.reserve(size_after_flatten); + for (const auto& inner_vec : vec) { + result.insert(result.end(), inner_vec.begin(), inner_vec.end()); + } + return result; } -struct PrefillPlanSM90Info -{ - int64_t qo_tile_indices_offset; - int64_t qo_indptr_offset; - int64_t kv_indptr_offset; - int64_t qo_len_offset; - int64_t kv_len_offset; - int64_t head_indices_offset; - int64_t work_indptr_offset; - bool 
same_schedule_for_all_heads; - - PrefillPlanSM90Info() - : qo_tile_indices_offset(0), qo_indptr_offset(0), kv_indptr_offset(0), - qo_len_offset(0), kv_len_offset(0), head_indices_offset(0), - work_indptr_offset(0), same_schedule_for_all_heads(false) - { - } - - // convert PrefillPlanSM90Info to std::vector - std::vector ToVector() const - { - return {qo_tile_indices_offset, qo_indptr_offset, - kv_indptr_offset, qo_len_offset, - kv_len_offset, head_indices_offset, - work_indptr_offset, same_schedule_for_all_heads}; - } - - // From std::vector to PrefillPlanSM90Info - void FromVector(const std::vector &vec) - { - if (vec.size() != 8) { - std::ostringstream err_msg; - err_msg << "PrefillPlanSM90Info::FromVector: vec.size() should be " - "8, but got " - << vec.size(); - FLASHINFER_ERROR(err_msg.str()); - } - qo_tile_indices_offset = vec[0]; - qo_indptr_offset = vec[1]; - kv_indptr_offset = vec[2]; - qo_len_offset = vec[3]; - kv_len_offset = vec[4]; - head_indices_offset = vec[5]; - work_indptr_offset = vec[6]; - same_schedule_for_all_heads = vec[7]; +struct PrefillPlanSM90Info { + int64_t qo_tile_indices_offset; + int64_t qo_indptr_offset; + int64_t kv_indptr_offset; + int64_t qo_len_offset; + int64_t kv_len_offset; + int64_t head_indices_offset; + int64_t work_indptr_offset; + bool same_schedule_for_all_heads; + + PrefillPlanSM90Info() + : qo_tile_indices_offset(0), + qo_indptr_offset(0), + kv_indptr_offset(0), + qo_len_offset(0), + kv_len_offset(0), + head_indices_offset(0), + work_indptr_offset(0), + same_schedule_for_all_heads(false) {} + + // convert PrefillPlanSM90Info to std::vector + std::vector ToVector() const { + return {qo_tile_indices_offset, qo_indptr_offset, + kv_indptr_offset, qo_len_offset, + kv_len_offset, head_indices_offset, + work_indptr_offset, same_schedule_for_all_heads}; + } + + // From std::vector to PrefillPlanSM90Info + void FromVector(const std::vector& vec) { + if (vec.size() != 8) { + std::ostringstream err_msg; + err_msg << 
"PrefillPlanSM90Info::FromVector: vec.size() should be " + "8, but got " + << vec.size(); + FLASHINFER_ERROR(err_msg.str()); } + qo_tile_indices_offset = vec[0]; + qo_indptr_offset = vec[1]; + kv_indptr_offset = vec[2]; + qo_len_offset = vec[3]; + kv_len_offset = vec[4]; + head_indices_offset = vec[5]; + work_indptr_offset = vec[6]; + same_schedule_for_all_heads = vec[7]; + } }; template -inline hipError_t PrefillSM90Plan(void *float_buffer, - size_t float_workspace_size_in_bytes, - void *int_buffer, - void *page_locked_int_buffer, - size_t int_workspace_size_in_bytes, - PrefillPlanSM90Info &plan_info, - IdType *qo_indptr_h, - IdType *kv_indptr_h, - IdType *kv_len_arr_h, - uint32_t total_num_rows, - uint32_t batch_size, - uint32_t num_qo_heads, - uint32_t num_kv_heads, - uint32_t head_dim_qk, - uint32_t head_dim_vo, - uint32_t page_size, - bool causal, - bool enable_cuda_graph, - uint32_t sizeof_dtype_o, - hipStream_t stream) -{ - if (num_qo_heads % num_kv_heads != 0) { - std::ostringstream err_msg; - err_msg << "num_qo_heads " << num_qo_heads - << " should be divisible by num_kv_heads " << num_kv_heads; - FLASHINFER_ERROR(err_msg.str()); - } - - std::vector> idx_qo_kv_len_vec; - for (uint32_t i = 0; i < batch_size; ++i) { - int qo_len = qo_indptr_h[i + 1] - qo_indptr_h[i]; - int kv_len = kv_len_arr_h[i]; - if (kv_len < 0) { - std::ostringstream err_msg; - err_msg << "kv_len[" << i << "]" << kv_len - << " should be non-negative"; - FLASHINFER_ERROR(err_msg.str()); - } - if (qo_len < 0) { - std::ostringstream err_msg; - err_msg << "qo_indptr[" << i + 1 << "]" << qo_indptr_h[i + 1] - << " - qo_indptr[" << i << "]" << qo_indptr_h[i] - << " should be non-negative"; - FLASHINFER_ERROR(err_msg.str()); - } - idx_qo_kv_len_vec.push_back({i, qo_len, kv_len}); +inline hipError_t PrefillSM90Plan( + void* float_buffer, size_t float_workspace_size_in_bytes, void* int_buffer, + void* page_locked_int_buffer, size_t int_workspace_size_in_bytes, + PrefillPlanSM90Info& plan_info, 
IdType* qo_indptr_h, IdType* kv_indptr_h, IdType* kv_len_arr_h, + uint32_t total_num_rows, uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, + uint32_t head_dim_qk, uint32_t head_dim_vo, uint32_t page_size, bool causal, + bool enable_cuda_graph, uint32_t sizeof_dtype_o, hipStream_t stream) { + if (num_qo_heads % num_kv_heads != 0) { + std::ostringstream err_msg; + err_msg << "num_qo_heads " << num_qo_heads << " should be divisible by num_kv_heads " + << num_kv_heads; + FLASHINFER_ERROR(err_msg.str()); + } + + std::vector> idx_qo_kv_len_vec; + for (uint32_t i = 0; i < batch_size; ++i) { + int qo_len = qo_indptr_h[i + 1] - qo_indptr_h[i]; + int kv_len = kv_len_arr_h[i]; + if (kv_len < 0) { + std::ostringstream err_msg; + err_msg << "kv_len[" << i << "]" << kv_len << " should be non-negative"; + FLASHINFER_ERROR(err_msg.str()); } - - std::sort(idx_qo_kv_len_vec.begin(), idx_qo_kv_len_vec.end(), - [](const auto &a, const auto &b) { - return std::get<2>(a) > std::get<2>(b); - }); - int cta_tile_q = 128; - if (head_dim_vo == 64) { - cta_tile_q = 192; + if (qo_len < 0) { + std::ostringstream err_msg; + err_msg << "qo_indptr[" << i + 1 << "]" << qo_indptr_h[i + 1] << " - qo_indptr[" << i << "]" + << qo_indptr_h[i] << " should be non-negative"; + FLASHINFER_ERROR(err_msg.str()); } - - int device = 0; - FLASHINFER_CUDA_CALL(hipGetDevice(&device)); - int num_sm90_ctas = 0; - FLASHINFER_CUDA_CALL(hipDeviceGetAttribute( - &num_sm90_ctas, hipDeviceAttributeMultiprocessorCount, device)); - - MinHeap cta_cost_heap(num_sm90_ctas); - std::vector> cta_qo_tile_indices(num_sm90_ctas, - std::vector()), - cta_qo_indptr(num_sm90_ctas, std::vector()), - cta_kv_indptr(num_sm90_ctas, std::vector()), - cta_qo_len(num_sm90_ctas, std::vector()), - cta_kv_len(num_sm90_ctas, std::vector()), - cta_head_indices(num_sm90_ctas, std::vector()); - - int max_num_works_per_head = - ceil_div(total_num_rows, cta_tile_q) + batch_size - 1; - plan_info.same_schedule_for_all_heads = 
max_num_works_per_head > 4096; - - for (int qo_head_idx = 0; - qo_head_idx < - (plan_info.same_schedule_for_all_heads ? 1 : num_qo_heads); - ++qo_head_idx) - { - for (auto &[i, qo_len, kv_len] : idx_qo_kv_len_vec) { - int num_qo_tiles = ceil_div(qo_len, cta_tile_q); - for (int qo_tile_idx = num_qo_tiles - 1; qo_tile_idx >= 0; - --qo_tile_idx) - { - auto [cta_idx, accum_cost] = cta_cost_heap.pop(); - // NOTE(Zihao): our current FA3 implementation do not fuse query - // and group heads so the group_size in cost_function is always - // 1 - cta_cost_heap.insert( - {cta_idx, - accum_cost + - cost_function(cta_tile_q, - causal ? kv_len - (num_qo_tiles - - qo_tile_idx - 1) * - cta_tile_q - : kv_len)}); - cta_qo_tile_indices[cta_idx].push_back(qo_tile_idx); - cta_qo_indptr[cta_idx].push_back(qo_indptr_h[i]); - cta_qo_len[cta_idx].push_back(qo_len); - cta_kv_indptr[cta_idx].push_back(kv_indptr_h[i]); - cta_kv_len[cta_idx].push_back(kv_len); - cta_head_indices[cta_idx].push_back(qo_head_idx); - } - } - } - - std::vector work_indptr_vec(num_sm90_ctas + 1, 0); - for (uint32_t i = 0; i < num_sm90_ctas; ++i) { - work_indptr_vec[i + 1] = - work_indptr_vec[i] + cta_qo_tile_indices[i].size(); + idx_qo_kv_len_vec.push_back({i, qo_len, kv_len}); + } + + std::sort(idx_qo_kv_len_vec.begin(), idx_qo_kv_len_vec.end(), + [](const auto& a, const auto& b) { return std::get<2>(a) > std::get<2>(b); }); + int cta_tile_q = 128; + if (head_dim_vo == 64) { + cta_tile_q = 192; + } + + int device = 0; + FLASHINFER_CUDA_CALL(hipGetDevice(&device)); + int num_sm90_ctas = 0; + FLASHINFER_CUDA_CALL( + hipDeviceGetAttribute(&num_sm90_ctas, hipDeviceAttributeMultiprocessorCount, device)); + + MinHeap cta_cost_heap(num_sm90_ctas); + std::vector> cta_qo_tile_indices(num_sm90_ctas, std::vector()), + cta_qo_indptr(num_sm90_ctas, std::vector()), + cta_kv_indptr(num_sm90_ctas, std::vector()), + cta_qo_len(num_sm90_ctas, std::vector()), + cta_kv_len(num_sm90_ctas, std::vector()), + 
cta_head_indices(num_sm90_ctas, std::vector()); + + int max_num_works_per_head = ceil_div(total_num_rows, cta_tile_q) + batch_size - 1; + plan_info.same_schedule_for_all_heads = max_num_works_per_head > 4096; + + for (int qo_head_idx = 0; + qo_head_idx < (plan_info.same_schedule_for_all_heads ? 1 : num_qo_heads); ++qo_head_idx) { + for (auto& [i, qo_len, kv_len] : idx_qo_kv_len_vec) { + int num_qo_tiles = ceil_div(qo_len, cta_tile_q); + for (int qo_tile_idx = num_qo_tiles - 1; qo_tile_idx >= 0; --qo_tile_idx) { + auto [cta_idx, accum_cost] = cta_cost_heap.pop(); + // NOTE(Zihao): our current FA3 implementation do not fuse query + // and group heads so the group_size in cost_function is always + // 1 + cta_cost_heap.insert( + {cta_idx, accum_cost + cost_function(cta_tile_q, causal ? kv_len - (num_qo_tiles - + qo_tile_idx - 1) * + cta_tile_q + : kv_len)}); + cta_qo_tile_indices[cta_idx].push_back(qo_tile_idx); + cta_qo_indptr[cta_idx].push_back(qo_indptr_h[i]); + cta_qo_len[cta_idx].push_back(qo_len); + cta_kv_indptr[cta_idx].push_back(kv_indptr_h[i]); + cta_kv_len[cta_idx].push_back(kv_len); + cta_head_indices[cta_idx].push_back(qo_head_idx); + } } - int total_num_works = work_indptr_vec.back(); - auto qo_tile_indices_vec = flatten(cta_qo_tile_indices, total_num_works); - auto qo_indptr_vec = flatten(cta_qo_indptr, total_num_works); - auto kv_indptr_vec = flatten(cta_kv_indptr, total_num_works); - auto qo_len_vec = flatten(cta_qo_len, total_num_works); - auto kv_len_vec = flatten(cta_kv_len, total_num_works); - auto head_indices_vec = flatten(cta_head_indices, total_num_works); - - AlignedAllocator int_allocator(int_buffer, int_workspace_size_in_bytes); - int max_total_num_works; - - if (enable_cuda_graph) { - max_total_num_works = plan_info.same_schedule_for_all_heads - ? 
max_num_works_per_head - : max_num_works_per_head * num_qo_heads; - } - else { - max_total_num_works = total_num_works; - } - - plan_info.qo_tile_indices_offset = int_allocator.aligned_alloc_offset( - sizeof(IdType) * max_total_num_works, 16, - "batch_prefill_sm90_qo_tile_indices"); - plan_info.qo_indptr_offset = - int_allocator.aligned_alloc_offset(sizeof(IdType) * max_total_num_works, - 16, "batch_prefill_sm90_qo_offset"); - plan_info.kv_indptr_offset = - int_allocator.aligned_alloc_offset(sizeof(IdType) * max_total_num_works, - 16, "batch_prefill_sm90_kv_offset"); - plan_info.qo_len_offset = int_allocator.aligned_alloc_offset( - sizeof(IdType) * max_total_num_works, 16, "batch_prefill_sm90_qo_len"); - plan_info.kv_len_offset = int_allocator.aligned_alloc_offset( - sizeof(IdType) * max_total_num_works, 16, "batch_prefill_sm90_kv_len"); - plan_info.head_indices_offset = int_allocator.aligned_alloc_offset( - sizeof(IdType) * max_total_num_works, 16, - "batch_prefill_sm90_head_indices"); - plan_info.work_indptr_offset = int_allocator.aligned_alloc_offset( - sizeof(IdType) * (num_sm90_ctas + 1), 16, - "batch_prefill_sm90_work_indptr"); - - IdType *qo_tile_indices_h = GetPtrFromBaseOffset( - page_locked_int_buffer, plan_info.qo_tile_indices_offset); - IdType *qo_offset_h = GetPtrFromBaseOffset( - page_locked_int_buffer, plan_info.qo_indptr_offset); - IdType *kv_offset_h = GetPtrFromBaseOffset( - page_locked_int_buffer, plan_info.kv_indptr_offset); - IdType *qo_len_h = GetPtrFromBaseOffset(page_locked_int_buffer, - plan_info.qo_len_offset); - IdType *kv_len_h = GetPtrFromBaseOffset(page_locked_int_buffer, - plan_info.kv_len_offset); - IdType *head_indices_h = GetPtrFromBaseOffset( - page_locked_int_buffer, plan_info.head_indices_offset); - IdType *work_indptr_h = GetPtrFromBaseOffset( - page_locked_int_buffer, plan_info.work_indptr_offset); - - std::copy(qo_tile_indices_vec.begin(), qo_tile_indices_vec.end(), - qo_tile_indices_h); - std::copy(qo_indptr_vec.begin(), 
qo_indptr_vec.end(), qo_offset_h); - std::copy(kv_indptr_vec.begin(), kv_indptr_vec.end(), kv_offset_h); - std::copy(qo_len_vec.begin(), qo_len_vec.end(), qo_len_h); - std::copy(kv_len_vec.begin(), kv_len_vec.end(), kv_len_h); - std::copy(head_indices_vec.begin(), head_indices_vec.end(), head_indices_h); - std::copy(work_indptr_vec.begin(), work_indptr_vec.end(), work_indptr_h); - - size_t num_bytes_to_copy = int_allocator.num_allocated_bytes(); - FLASHINFER_CUDA_CALL(hipMemcpyAsync(int_buffer, page_locked_int_buffer, - num_bytes_to_copy, - hipMemcpyHostToDevice, stream)); - return hipSuccess; + } + + std::vector work_indptr_vec(num_sm90_ctas + 1, 0); + for (uint32_t i = 0; i < num_sm90_ctas; ++i) { + work_indptr_vec[i + 1] = work_indptr_vec[i] + cta_qo_tile_indices[i].size(); + } + int total_num_works = work_indptr_vec.back(); + auto qo_tile_indices_vec = flatten(cta_qo_tile_indices, total_num_works); + auto qo_indptr_vec = flatten(cta_qo_indptr, total_num_works); + auto kv_indptr_vec = flatten(cta_kv_indptr, total_num_works); + auto qo_len_vec = flatten(cta_qo_len, total_num_works); + auto kv_len_vec = flatten(cta_kv_len, total_num_works); + auto head_indices_vec = flatten(cta_head_indices, total_num_works); + + AlignedAllocator int_allocator(int_buffer, int_workspace_size_in_bytes); + int max_total_num_works; + + if (enable_cuda_graph) { + max_total_num_works = plan_info.same_schedule_for_all_heads + ? 
max_num_works_per_head + : max_num_works_per_head * num_qo_heads; + } else { + max_total_num_works = total_num_works; + } + + plan_info.qo_tile_indices_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * max_total_num_works, 16, "batch_prefill_sm90_qo_tile_indices"); + plan_info.qo_indptr_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * max_total_num_works, 16, "batch_prefill_sm90_qo_offset"); + plan_info.kv_indptr_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * max_total_num_works, 16, "batch_prefill_sm90_kv_offset"); + plan_info.qo_len_offset = int_allocator.aligned_alloc_offset(sizeof(IdType) * max_total_num_works, + 16, "batch_prefill_sm90_qo_len"); + plan_info.kv_len_offset = int_allocator.aligned_alloc_offset(sizeof(IdType) * max_total_num_works, + 16, "batch_prefill_sm90_kv_len"); + plan_info.head_indices_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * max_total_num_works, 16, "batch_prefill_sm90_head_indices"); + plan_info.work_indptr_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * (num_sm90_ctas + 1), 16, "batch_prefill_sm90_work_indptr"); + + IdType* qo_tile_indices_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.qo_tile_indices_offset); + IdType* qo_offset_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.qo_indptr_offset); + IdType* kv_offset_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.kv_indptr_offset); + IdType* qo_len_h = GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.qo_len_offset); + IdType* kv_len_h = GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.kv_len_offset); + IdType* head_indices_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.head_indices_offset); + IdType* work_indptr_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.work_indptr_offset); + + std::copy(qo_tile_indices_vec.begin(), qo_tile_indices_vec.end(), qo_tile_indices_h); + std::copy(qo_indptr_vec.begin(), qo_indptr_vec.end(), 
qo_offset_h); + std::copy(kv_indptr_vec.begin(), kv_indptr_vec.end(), kv_offset_h); + std::copy(qo_len_vec.begin(), qo_len_vec.end(), qo_len_h); + std::copy(kv_len_vec.begin(), kv_len_vec.end(), kv_len_h); + std::copy(head_indices_vec.begin(), head_indices_vec.end(), head_indices_h); + std::copy(work_indptr_vec.begin(), work_indptr_vec.end(), work_indptr_h); + + size_t num_bytes_to_copy = int_allocator.num_allocated_bytes(); + FLASHINFER_CUDA_CALL(hipMemcpyAsync(int_buffer, page_locked_int_buffer, num_bytes_to_copy, + hipMemcpyHostToDevice, stream)); + return hipSuccess; } -inline int packed_causal_kv_end(int qo_len, - int kv_len, - int qo_tile_idx, - int cluster_tile_q, - int num_qo_tiles, - int group_size) -{ - if (qo_tile_idx + 1 == num_qo_tiles) { - return kv_len; - } - int kv_len_init = kv_len - qo_len; - return kv_len_init + (qo_tile_idx + 1) * cluster_tile_q / group_size; +inline int packed_causal_kv_end(int qo_len, int kv_len, int qo_tile_idx, int cluster_tile_q, + int num_qo_tiles, int group_size) { + if (qo_tile_idx + 1 == num_qo_tiles) { + return kv_len; + } + int kv_len_init = kv_len - qo_len; + return kv_len_init + (qo_tile_idx + 1) * cluster_tile_q / group_size; } -struct MLAPlanInfo -{ - int64_t num_blks_x; - int64_t num_blks_y; - int64_t q_indptr_offset; - int64_t kv_indptr_offset; - int64_t partial_indptr_offset; - int64_t merge_packed_offset_start_offset; - int64_t merge_packed_offset_end_offset; - int64_t merge_partial_packed_offset_start_offset; - int64_t merge_partial_packed_offset_end_offset; - int64_t merge_partial_stride_offset; - int64_t q_len_offset; - int64_t kv_len_offset; - int64_t q_start_offset; - int64_t kv_start_offset; - int64_t kv_end_offset; - int64_t work_indptr_offset; - int64_t partial_o_offset; - int64_t partial_lse_offset; - - std::vector ToVector() const - { - return {num_blks_x, - num_blks_y, - q_indptr_offset, - kv_indptr_offset, - partial_indptr_offset, - merge_packed_offset_start_offset, - 
merge_packed_offset_end_offset, - merge_partial_packed_offset_start_offset, - merge_partial_packed_offset_end_offset, - merge_partial_stride_offset, - q_len_offset, - kv_len_offset, - q_start_offset, - kv_start_offset, - kv_end_offset, - work_indptr_offset, - partial_o_offset, - partial_lse_offset}; - } - - void FromVector(const std::vector &vec) - { - if (vec.size() != 18) { - std::ostringstream err_msg; - err_msg - << "MLAPlanInfo::FromVector: vec.size() should be 18, but got " - << vec.size(); - FLASHINFER_ERROR(err_msg.str()); - } - num_blks_x = vec[0]; - num_blks_y = vec[1]; - q_indptr_offset = vec[2]; - kv_indptr_offset = vec[3]; - partial_indptr_offset = vec[4]; - merge_packed_offset_start_offset = vec[5]; - merge_packed_offset_end_offset = vec[6]; - merge_partial_packed_offset_start_offset = vec[7]; - merge_partial_packed_offset_end_offset = vec[8]; - merge_partial_stride_offset = vec[9]; - q_len_offset = vec[10]; - kv_len_offset = vec[11]; - q_start_offset = vec[12]; - kv_start_offset = vec[13]; - kv_end_offset = vec[14]; - work_indptr_offset = vec[15]; - partial_o_offset = vec[16]; - partial_lse_offset = vec[17]; +struct MLAPlanInfo { + int64_t num_blks_x; + int64_t num_blks_y; + int64_t q_indptr_offset; + int64_t kv_indptr_offset; + int64_t partial_indptr_offset; + int64_t merge_packed_offset_start_offset; + int64_t merge_packed_offset_end_offset; + int64_t merge_partial_packed_offset_start_offset; + int64_t merge_partial_packed_offset_end_offset; + int64_t merge_partial_stride_offset; + int64_t q_len_offset; + int64_t kv_len_offset; + int64_t q_start_offset; + int64_t kv_start_offset; + int64_t kv_end_offset; + int64_t work_indptr_offset; + int64_t partial_o_offset; + int64_t partial_lse_offset; + + std::vector ToVector() const { + return {num_blks_x, + num_blks_y, + q_indptr_offset, + kv_indptr_offset, + partial_indptr_offset, + merge_packed_offset_start_offset, + merge_packed_offset_end_offset, + merge_partial_packed_offset_start_offset, + 
merge_partial_packed_offset_end_offset, + merge_partial_stride_offset, + q_len_offset, + kv_len_offset, + q_start_offset, + kv_start_offset, + kv_end_offset, + work_indptr_offset, + partial_o_offset, + partial_lse_offset}; + } + + void FromVector(const std::vector& vec) { + if (vec.size() != 18) { + std::ostringstream err_msg; + err_msg << "MLAPlanInfo::FromVector: vec.size() should be 18, but got " << vec.size(); + FLASHINFER_ERROR(err_msg.str()); } + num_blks_x = vec[0]; + num_blks_y = vec[1]; + q_indptr_offset = vec[2]; + kv_indptr_offset = vec[3]; + partial_indptr_offset = vec[4]; + merge_packed_offset_start_offset = vec[5]; + merge_packed_offset_end_offset = vec[6]; + merge_partial_packed_offset_start_offset = vec[7]; + merge_partial_packed_offset_end_offset = vec[8]; + merge_partial_stride_offset = vec[9]; + q_len_offset = vec[10]; + kv_len_offset = vec[11]; + q_start_offset = vec[12]; + kv_start_offset = vec[13]; + kv_end_offset = vec[14]; + work_indptr_offset = vec[15]; + partial_o_offset = vec[16]; + partial_lse_offset = vec[17]; + } }; template -inline hipError_t MLAPlan(void *float_buffer, - size_t float_workspace_size_in_bytes, - void *int_buffer, - void *page_locked_int_buffer, - size_t int_workspace_size_in_bytes, - MLAPlanInfo &plan_info, - IdType *qo_indptr_h, - IdType *kv_indptr_h, - IdType *kv_len_arr_h, - uint32_t batch_size, - uint32_t num_heads, - uint32_t head_dim_o, - bool causal, - hipStream_t stream) -{ - int num_sm = 0; - int dev_id = 0; - FLASHINFER_CUDA_CALL(hipGetDevice(&dev_id)); - FLASHINFER_CUDA_CALL(hipDeviceGetAttribute( - &num_sm, hipDeviceAttributeMultiprocessorCount, dev_id)); - - // step 0. 
determine the number of blocks in x and y dimensions - int accum_packed_qo_len = 0; - std::vector> idx_qo_kv_len_vec; - for (uint32_t i = 0; i < batch_size; ++i) { - if (qo_indptr_h[i + 1] - qo_indptr_h[i] < 0) { - std::ostringstream err_msg; - err_msg << "qo_indptr[" << i + 1 << "]" << qo_indptr_h[i + 1] - << " - qo_indptr[" << i << "]" << qo_indptr_h[i] - << " should be non-negative"; - FLASHINFER_ERROR(err_msg.str()); - } - - int qo_len = qo_indptr_h[i + 1] - qo_indptr_h[i]; - int packed_qo_len = qo_len * num_heads; - accum_packed_qo_len += packed_qo_len; - - int kv_len = kv_len_arr_h[i]; - idx_qo_kv_len_vec.push_back({i, qo_len, kv_len}); +inline hipError_t MLAPlan(void* float_buffer, size_t float_workspace_size_in_bytes, + void* int_buffer, void* page_locked_int_buffer, + size_t int_workspace_size_in_bytes, MLAPlanInfo& plan_info, + IdType* qo_indptr_h, IdType* kv_indptr_h, IdType* kv_len_arr_h, + uint32_t batch_size, uint32_t num_heads, uint32_t head_dim_o, bool causal, + hipStream_t stream) { + int num_sm = 0; + int dev_id = 0; + FLASHINFER_CUDA_CALL(hipGetDevice(&dev_id)); + FLASHINFER_CUDA_CALL( + hipDeviceGetAttribute(&num_sm, hipDeviceAttributeMultiprocessorCount, dev_id)); + + // step 0. 
determine the number of blocks in x and y dimensions + int accum_packed_qo_len = 0; + std::vector> idx_qo_kv_len_vec; + for (uint32_t i = 0; i < batch_size; ++i) { + if (qo_indptr_h[i + 1] - qo_indptr_h[i] < 0) { + std::ostringstream err_msg; + err_msg << "qo_indptr[" << i + 1 << "]" << qo_indptr_h[i + 1] << " - qo_indptr[" << i << "]" + << qo_indptr_h[i] << " should be non-negative"; + FLASHINFER_ERROR(err_msg.str()); } - int avg_packed_qo_len = accum_packed_qo_len / batch_size; - int cluster_size; - if (avg_packed_qo_len > 64) { - cluster_size = 2; // two ctas in a cluster + int qo_len = qo_indptr_h[i + 1] - qo_indptr_h[i]; + int packed_qo_len = qo_len * num_heads; + accum_packed_qo_len += packed_qo_len; + + int kv_len = kv_len_arr_h[i]; + idx_qo_kv_len_vec.push_back({i, qo_len, kv_len}); + } + int avg_packed_qo_len = accum_packed_qo_len / batch_size; + + int cluster_size; + if (avg_packed_qo_len > 64) { + cluster_size = 2; // two ctas in a cluster + } else { + cluster_size = 1; // one cta in a cluster + } + uint32_t num_clusters = num_sm / cluster_size; + plan_info.num_blks_x = cluster_size; + plan_info.num_blks_y = num_clusters; + const int cta_tile_q = 64; + int cluster_tile_q = cluster_size * cta_tile_q; + + int64_t total_kv_lens = 0; + for (auto& [_, qo_len, kv_len] : idx_qo_kv_len_vec) { + int packed_qo_len = qo_len * num_heads; + int num_qo_tiles = ceil_div(packed_qo_len, cluster_tile_q); + for (int qo_tile_idx = num_qo_tiles - 1; qo_tile_idx >= 0; --qo_tile_idx) { + int effective_kv_len = causal ? 
packed_causal_kv_end(qo_len, kv_len, qo_tile_idx, + cluster_tile_q, num_qo_tiles, num_heads) + : kv_len; + total_kv_lens += effective_kv_len; } - else { - cluster_size = 1; // one cta in a cluster + } + + auto f = [](int x) { + if (x <= 8) { + return 32; + } else if (x <= 16) { + return 64; + } else if (x <= 32) { + return 128; + } else if (x <= 64) { + return 192; } - uint32_t num_clusters = num_sm / cluster_size; - plan_info.num_blks_x = cluster_size; - plan_info.num_blks_y = num_clusters; - const int cta_tile_q = 64; - int cluster_tile_q = cluster_size * cta_tile_q; - - int64_t total_kv_lens = 0; - for (auto &[_, qo_len, kv_len] : idx_qo_kv_len_vec) { - int packed_qo_len = qo_len * num_heads; - int num_qo_tiles = ceil_div(packed_qo_len, cluster_tile_q); - for (int qo_tile_idx = num_qo_tiles - 1; qo_tile_idx >= 0; - --qo_tile_idx) - { - int effective_kv_len = - causal ? packed_causal_kv_end(qo_len, kv_len, qo_tile_idx, - cluster_tile_q, num_qo_tiles, - num_heads) - : kv_len; - total_kv_lens += effective_kv_len; - } - } - - auto f = [](int x) { - if (x <= 8) { - return 32; - } - else if (x <= 16) { - return 64; + return ceil_div(x, 256) * 256; + }; + + assert(num_clusters > 0); + int kv_len_limit = f(std::max(ceil_div(total_kv_lens, num_clusters), 1L)); + + // step 1. 
load-balancing scheduling algorithm + MinHeap cluster_cost_heap(num_clusters); + std::vector> cluster_q_indptr(num_clusters, std::vector()), + cluster_kv_indptr(num_clusters, std::vector()), + cluster_q_len(num_clusters, std::vector()), + cluster_kv_len(num_clusters, std::vector()), + cluster_q_start(num_clusters, std::vector()), + cluster_kv_start(num_clusters, std::vector()), + cluster_kv_end(num_clusters, std::vector()), + cluster_partial_indptr(num_clusters, std::vector()); + + std::vector merge_packed_offset_start(num_sm, 0), merge_packed_offset_end(num_sm, 0), + merge_partial_packed_offset_start(num_sm, 0), merge_partial_packed_offset_end(num_sm, 0), + merge_partial_stride(num_sm, 0); + + int merge_cta_counter = 0; + int partial_o_nnz = 0; + + for (auto& [i, qo_len, kv_len] : idx_qo_kv_len_vec) { + int packed_qo_len = qo_len * num_heads; + int num_qo_tiles = ceil_div(packed_qo_len, cluster_tile_q); + for (int qo_tile_idx = num_qo_tiles - 1; qo_tile_idx >= 0; --qo_tile_idx) { + int remaining_len = causal ? packed_causal_kv_end(qo_len, kv_len, qo_tile_idx, cluster_tile_q, + num_qo_tiles, num_heads) + : kv_len; + int kv_start = 0; + bool split_kv = remaining_len > kv_len_limit; + int row_tile_size = std::min(cluster_tile_q, packed_qo_len - qo_tile_idx * cluster_tile_q); + if (split_kv) { + /* + * Proof(Zihao): merge_cta_counter <= num_sm (num_sm == + num_clusters * cluster_size) + * + * Precondition: + * 1. kv_len_limit * num_clusters >= total_kv_lens == + sum(remaining_len) + * 2. num_qo_chunks <= max((remaining_len * cluster_size) // + kv_len_limit, 1) + * 3. num_qo_tiles_requires_split <= num_clusters + + * Implication: + * 1. sum(num_qo_chunks) <= max(sum(remaining_len) * + cluster_size / kv_len_limit, num_qo_tiles_requires_split) + * 2. 
sum(num_qo_chunks) <= max(cluster_size * num_clusters, + num_qo_tiles_requires_split) + */ + int num_qo_chunks = std::max(remaining_len * cluster_size / kv_len_limit, 1); + // row_chunk_size * num_qo_chunks >= row_tile_size + int row_chunk_size = ceil_div(row_tile_size, num_qo_chunks); + int current_q_tile_end = + std::min(cluster_tile_q, packed_qo_len - qo_tile_idx * cluster_tile_q); + for (int offset_start = 0; offset_start < row_tile_size; offset_start += row_chunk_size) { + merge_packed_offset_start[merge_cta_counter] = + qo_indptr_h[i] * num_heads + qo_tile_idx * cluster_tile_q + offset_start; + merge_packed_offset_end[merge_cta_counter] = + qo_indptr_h[i] * num_heads + qo_tile_idx * cluster_tile_q + + std::min(offset_start + row_chunk_size, current_q_tile_end); + merge_partial_packed_offset_start[merge_cta_counter] = partial_o_nnz + offset_start; + merge_partial_packed_offset_end[merge_cta_counter] = + partial_o_nnz + ceil_div(remaining_len, kv_len_limit) * row_tile_size; + merge_partial_stride[merge_cta_counter] = row_tile_size; + merge_cta_counter++; } - else if (x <= 32) { - return 128; + } + bool zero_kv_len = (remaining_len == 0); + while (remaining_len > 0 || zero_kv_len) { + auto [cluster_idx, accum_cost] = cluster_cost_heap.pop(); + int actual_len = std::min(remaining_len, kv_len_limit); + cluster_cost_heap.insert( + {cluster_idx, accum_cost + cost_function(cluster_tile_q, actual_len)}); + cluster_q_len[cluster_idx].push_back(qo_len); + cluster_kv_len[cluster_idx].push_back(kv_len); + cluster_q_indptr[cluster_idx].push_back(qo_indptr_h[i]); + cluster_kv_indptr[cluster_idx].push_back(kv_indptr_h[i]); + if (split_kv) { + cluster_partial_indptr[cluster_idx].push_back(partial_o_nnz); + partial_o_nnz += row_tile_size; + } else { + cluster_partial_indptr[cluster_idx].push_back(-1); } - else if (x <= 64) { - return 192; - } - return ceil_div(x, 256) * 256; - }; - - assert(num_clusters > 0); - int kv_len_limit = f(std::max(ceil_div(total_kv_lens, 
num_clusters), 1L)); - - // step 1. load-balancing scheduling algorithm - MinHeap cluster_cost_heap(num_clusters); - std::vector> cluster_q_indptr(num_clusters, - std::vector()), - cluster_kv_indptr(num_clusters, std::vector()), - cluster_q_len(num_clusters, std::vector()), - cluster_kv_len(num_clusters, std::vector()), - cluster_q_start(num_clusters, std::vector()), - cluster_kv_start(num_clusters, std::vector()), - cluster_kv_end(num_clusters, std::vector()), - cluster_partial_indptr(num_clusters, std::vector()); - - std::vector merge_packed_offset_start(num_sm, 0), - merge_packed_offset_end(num_sm, 0), - merge_partial_packed_offset_start(num_sm, 0), - merge_partial_packed_offset_end(num_sm, 0), - merge_partial_stride(num_sm, 0); - - int merge_cta_counter = 0; - int partial_o_nnz = 0; - - for (auto &[i, qo_len, kv_len] : idx_qo_kv_len_vec) { - int packed_qo_len = qo_len * num_heads; - int num_qo_tiles = ceil_div(packed_qo_len, cluster_tile_q); - for (int qo_tile_idx = num_qo_tiles - 1; qo_tile_idx >= 0; - --qo_tile_idx) - { - int remaining_len = - causal ? packed_causal_kv_end(qo_len, kv_len, qo_tile_idx, - cluster_tile_q, num_qo_tiles, - num_heads) - : kv_len; - int kv_start = 0; - bool split_kv = remaining_len > kv_len_limit; - int row_tile_size = std::min( - cluster_tile_q, packed_qo_len - qo_tile_idx * cluster_tile_q); - if (split_kv) { - /* - * Proof(Zihao): merge_cta_counter <= num_sm (num_sm == - num_clusters * cluster_size) - * - * Precondition: - * 1. kv_len_limit * num_clusters >= total_kv_lens == - sum(remaining_len) - * 2. num_qo_chunks <= max((remaining_len * cluster_size) // - kv_len_limit, 1) - * 3. num_qo_tiles_requires_split <= num_clusters - - * Implication: - * 1. sum(num_qo_chunks) <= max(sum(remaining_len) * - cluster_size / kv_len_limit, num_qo_tiles_requires_split) - * 2. 
sum(num_qo_chunks) <= max(cluster_size * num_clusters, - num_qo_tiles_requires_split) - */ - int num_qo_chunks = - std::max(remaining_len * cluster_size / kv_len_limit, 1); - // row_chunk_size * num_qo_chunks >= row_tile_size - int row_chunk_size = ceil_div(row_tile_size, num_qo_chunks); - int current_q_tile_end = - std::min(cluster_tile_q, - packed_qo_len - qo_tile_idx * cluster_tile_q); - for (int offset_start = 0; offset_start < row_tile_size; - offset_start += row_chunk_size) - { - merge_packed_offset_start[merge_cta_counter] = - qo_indptr_h[i] * num_heads + - qo_tile_idx * cluster_tile_q + offset_start; - merge_packed_offset_end[merge_cta_counter] = - qo_indptr_h[i] * num_heads + - qo_tile_idx * cluster_tile_q + - std::min(offset_start + row_chunk_size, - current_q_tile_end); - merge_partial_packed_offset_start[merge_cta_counter] = - partial_o_nnz + offset_start; - merge_partial_packed_offset_end[merge_cta_counter] = - partial_o_nnz + - ceil_div(remaining_len, kv_len_limit) * row_tile_size; - merge_partial_stride[merge_cta_counter] = row_tile_size; - merge_cta_counter++; - } - } - bool zero_kv_len = (remaining_len == 0); - while (remaining_len > 0 || zero_kv_len) { - auto [cluster_idx, accum_cost] = cluster_cost_heap.pop(); - int actual_len = std::min(remaining_len, kv_len_limit); - cluster_cost_heap.insert( - {cluster_idx, - accum_cost + cost_function(cluster_tile_q, actual_len)}); - cluster_q_len[cluster_idx].push_back(qo_len); - cluster_kv_len[cluster_idx].push_back(kv_len); - cluster_q_indptr[cluster_idx].push_back(qo_indptr_h[i]); - cluster_kv_indptr[cluster_idx].push_back(kv_indptr_h[i]); - if (split_kv) { - cluster_partial_indptr[cluster_idx].push_back( - partial_o_nnz); - partial_o_nnz += row_tile_size; - } - else { - cluster_partial_indptr[cluster_idx].push_back(-1); - } - cluster_q_start[cluster_idx].push_back(qo_tile_idx * - cluster_tile_q); - cluster_kv_start[cluster_idx].push_back(kv_start); - cluster_kv_end[cluster_idx].push_back(kv_start + 
actual_len); - remaining_len -= actual_len; - kv_start += actual_len; - if (zero_kv_len) - break; - } - } - } - - FLASHINFER_CHECK(merge_cta_counter <= num_sm, - "Internal Error: merge_cta_counter should be less than or " - "equal to num_sm, " - "please report this bug to the developers"); - - int max_total_num_works = 16384; // NOTE(Zihao): adjust it later - - std::vector work_indptr_vec(num_clusters + 1, 0); - for (uint32_t i = 0; i < num_clusters; ++i) { - work_indptr_vec[i + 1] = - work_indptr_vec[i] + cluster_q_indptr[i].size(); + cluster_q_start[cluster_idx].push_back(qo_tile_idx * cluster_tile_q); + cluster_kv_start[cluster_idx].push_back(kv_start); + cluster_kv_end[cluster_idx].push_back(kv_start + actual_len); + remaining_len -= actual_len; + kv_start += actual_len; + if (zero_kv_len) break; + } } - int total_num_works = work_indptr_vec.back(); - auto q_indptr_vec = flatten(cluster_q_indptr, total_num_works); - auto kv_indptr_vec = flatten(cluster_kv_indptr, total_num_works); - auto partial_indptr_vec = flatten(cluster_partial_indptr, total_num_works); - auto q_len_vec = flatten(cluster_q_len, total_num_works); - auto kv_len_vec = flatten(cluster_kv_len, total_num_works); - auto q_start_vec = flatten(cluster_q_start, total_num_works); - auto kv_start_vec = flatten(cluster_kv_start, total_num_works); - auto kv_end_vec = flatten(cluster_kv_end, total_num_works); - - AlignedAllocator int_allocator(int_buffer, int_workspace_size_in_bytes); - plan_info.q_indptr_offset = int_allocator.aligned_alloc_offset( - sizeof(IdType) * max_total_num_works, 16, "mla_q_indptr"); - plan_info.kv_indptr_offset = int_allocator.aligned_alloc_offset( - sizeof(IdType) * max_total_num_works, 16, "mla_kv_indptr"); - plan_info.partial_indptr_offset = int_allocator.aligned_alloc_offset( - sizeof(IdType) * max_total_num_works, 16, "mla_partial_indptr"); - plan_info.merge_packed_offset_start_offset = - int_allocator.aligned_alloc_offset(sizeof(IdType) * num_sm, 16, - 
"mla_merge_packed_offset_start"); - plan_info.merge_packed_offset_end_offset = - int_allocator.aligned_alloc_offset(sizeof(IdType) * num_sm, 16, - "mla_merge_packed_offset_end"); - plan_info.merge_partial_packed_offset_start_offset = - int_allocator.aligned_alloc_offset( - sizeof(IdType) * num_sm, 16, - "mla_merge_partial_packed_offset_start"); - plan_info.merge_partial_packed_offset_end_offset = - int_allocator.aligned_alloc_offset( - sizeof(IdType) * num_sm, 16, "mla_merge_partial_packed_offset_end"); - plan_info.merge_partial_stride_offset = int_allocator.aligned_alloc_offset( - sizeof(IdType) * num_sm, 16, "mla_merge_partial_stride"); - plan_info.q_len_offset = int_allocator.aligned_alloc_offset( - sizeof(IdType) * max_total_num_works, 16, "mla_q_len"); - plan_info.kv_len_offset = int_allocator.aligned_alloc_offset( - sizeof(IdType) * max_total_num_works, 16, "mla_kv_len"); - plan_info.q_start_offset = int_allocator.aligned_alloc_offset( - sizeof(IdType) * max_total_num_works, 16, "mla_q_start"); - plan_info.kv_start_offset = int_allocator.aligned_alloc_offset( - sizeof(IdType) * max_total_num_works, 16, "mla_kv_start"); - plan_info.kv_end_offset = int_allocator.aligned_alloc_offset( - sizeof(IdType) * max_total_num_works, 16, "mla_kv_end"); - plan_info.work_indptr_offset = int_allocator.aligned_alloc_offset( - sizeof(IdType) * max_total_num_works, 16, "mla_work_indptr"); - - IdType *cluster_q_indptr_h = GetPtrFromBaseOffset( - page_locked_int_buffer, plan_info.q_indptr_offset); - IdType *cluster_kv_indptr_h = GetPtrFromBaseOffset( - page_locked_int_buffer, plan_info.kv_indptr_offset); - IdType *cluster_partial_indptr_h = GetPtrFromBaseOffset( - page_locked_int_buffer, plan_info.partial_indptr_offset); - IdType *cluster_merge_packed_offset_start_h = GetPtrFromBaseOffset( - page_locked_int_buffer, plan_info.merge_packed_offset_start_offset); - IdType *cluster_merge_packed_offset_end_h = GetPtrFromBaseOffset( - page_locked_int_buffer, 
plan_info.merge_packed_offset_end_offset); - IdType *cluster_merge_partial_packed_offset_start_h = - GetPtrFromBaseOffset( - page_locked_int_buffer, - plan_info.merge_partial_packed_offset_start_offset); - IdType *cluster_merge_partial_packed_offset_end_h = - GetPtrFromBaseOffset( - page_locked_int_buffer, - plan_info.merge_partial_packed_offset_end_offset); - IdType *cluster_merge_partial_stride_h = GetPtrFromBaseOffset( - page_locked_int_buffer, plan_info.merge_partial_stride_offset); - IdType *cluster_q_len_h = GetPtrFromBaseOffset( - page_locked_int_buffer, plan_info.q_len_offset); - IdType *cluster_kv_len_h = GetPtrFromBaseOffset( - page_locked_int_buffer, plan_info.kv_len_offset); - IdType *cluster_q_start_h = GetPtrFromBaseOffset( - page_locked_int_buffer, plan_info.q_start_offset); - IdType *cluster_kv_start_h = GetPtrFromBaseOffset( - page_locked_int_buffer, plan_info.kv_start_offset); - IdType *cluster_kv_end_h = GetPtrFromBaseOffset( - page_locked_int_buffer, plan_info.kv_end_offset); - IdType *cluster_work_indptr_h = GetPtrFromBaseOffset( - page_locked_int_buffer, plan_info.work_indptr_offset); - - std::copy(q_indptr_vec.begin(), q_indptr_vec.end(), cluster_q_indptr_h); - std::copy(kv_indptr_vec.begin(), kv_indptr_vec.end(), cluster_kv_indptr_h); - std::copy(partial_indptr_vec.begin(), partial_indptr_vec.end(), - cluster_partial_indptr_h); - std::copy(merge_packed_offset_start.begin(), - merge_packed_offset_start.end(), - cluster_merge_packed_offset_start_h); - std::copy(merge_packed_offset_end.begin(), merge_packed_offset_end.end(), - cluster_merge_packed_offset_end_h); - std::copy(merge_partial_packed_offset_start.begin(), - merge_partial_packed_offset_start.end(), - cluster_merge_partial_packed_offset_start_h); - std::copy(merge_partial_packed_offset_end.begin(), - merge_partial_packed_offset_end.end(), - cluster_merge_partial_packed_offset_end_h); - std::copy(merge_partial_stride.begin(), merge_partial_stride.end(), - 
cluster_merge_partial_stride_h); - std::copy(q_len_vec.begin(), q_len_vec.end(), cluster_q_len_h); - std::copy(kv_len_vec.begin(), kv_len_vec.end(), cluster_kv_len_h); - std::copy(q_start_vec.begin(), q_start_vec.end(), cluster_q_start_h); - std::copy(kv_start_vec.begin(), kv_start_vec.end(), cluster_kv_start_h); - std::copy(kv_end_vec.begin(), kv_end_vec.end(), cluster_kv_end_h); - std::copy(work_indptr_vec.begin(), work_indptr_vec.end(), - cluster_work_indptr_h); - - size_t num_bytes_to_copy = int_allocator.num_allocated_bytes(); - FLASHINFER_CUDA_CALL(hipMemcpyAsync(int_buffer, page_locked_int_buffer, - num_bytes_to_copy, - hipMemcpyHostToDevice, stream)); - - constexpr size_t sizeof_dtype_o = 2; - AlignedAllocator float_allocator(float_buffer, - float_workspace_size_in_bytes); - plan_info.partial_o_offset = float_allocator.aligned_alloc_offset( - 2 * num_clusters * cluster_tile_q * sizeof_dtype_o * head_dim_o, 16, - "mla_partial_o"); - plan_info.partial_lse_offset = float_allocator.aligned_alloc_offset( - 2 * num_clusters * cluster_tile_q * sizeof(float), 16, - "mla_partial_lse"); - - return hipSuccess; + } + + FLASHINFER_CHECK(merge_cta_counter <= num_sm, + "Internal Error: merge_cta_counter should be less than or " + "equal to num_sm, " + "please report this bug to the developers"); + + int max_total_num_works = 16384; // NOTE(Zihao): adjust it later + + std::vector work_indptr_vec(num_clusters + 1, 0); + for (uint32_t i = 0; i < num_clusters; ++i) { + work_indptr_vec[i + 1] = work_indptr_vec[i] + cluster_q_indptr[i].size(); + } + int total_num_works = work_indptr_vec.back(); + auto q_indptr_vec = flatten(cluster_q_indptr, total_num_works); + auto kv_indptr_vec = flatten(cluster_kv_indptr, total_num_works); + auto partial_indptr_vec = flatten(cluster_partial_indptr, total_num_works); + auto q_len_vec = flatten(cluster_q_len, total_num_works); + auto kv_len_vec = flatten(cluster_kv_len, total_num_works); + auto q_start_vec = flatten(cluster_q_start, 
total_num_works); + auto kv_start_vec = flatten(cluster_kv_start, total_num_works); + auto kv_end_vec = flatten(cluster_kv_end, total_num_works); + + AlignedAllocator int_allocator(int_buffer, int_workspace_size_in_bytes); + plan_info.q_indptr_offset = + int_allocator.aligned_alloc_offset(sizeof(IdType) * max_total_num_works, 16, "mla_q_indptr"); + plan_info.kv_indptr_offset = + int_allocator.aligned_alloc_offset(sizeof(IdType) * max_total_num_works, 16, "mla_kv_indptr"); + plan_info.partial_indptr_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * max_total_num_works, 16, "mla_partial_indptr"); + plan_info.merge_packed_offset_start_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * num_sm, 16, "mla_merge_packed_offset_start"); + plan_info.merge_packed_offset_end_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * num_sm, 16, "mla_merge_packed_offset_end"); + plan_info.merge_partial_packed_offset_start_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * num_sm, 16, "mla_merge_partial_packed_offset_start"); + plan_info.merge_partial_packed_offset_end_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * num_sm, 16, "mla_merge_partial_packed_offset_end"); + plan_info.merge_partial_stride_offset = + int_allocator.aligned_alloc_offset(sizeof(IdType) * num_sm, 16, "mla_merge_partial_stride"); + plan_info.q_len_offset = + int_allocator.aligned_alloc_offset(sizeof(IdType) * max_total_num_works, 16, "mla_q_len"); + plan_info.kv_len_offset = + int_allocator.aligned_alloc_offset(sizeof(IdType) * max_total_num_works, 16, "mla_kv_len"); + plan_info.q_start_offset = + int_allocator.aligned_alloc_offset(sizeof(IdType) * max_total_num_works, 16, "mla_q_start"); + plan_info.kv_start_offset = + int_allocator.aligned_alloc_offset(sizeof(IdType) * max_total_num_works, 16, "mla_kv_start"); + plan_info.kv_end_offset = + int_allocator.aligned_alloc_offset(sizeof(IdType) * max_total_num_works, 16, "mla_kv_end"); + 
plan_info.work_indptr_offset = int_allocator.aligned_alloc_offset( + sizeof(IdType) * max_total_num_works, 16, "mla_work_indptr"); + + IdType* cluster_q_indptr_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.q_indptr_offset); + IdType* cluster_kv_indptr_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.kv_indptr_offset); + IdType* cluster_partial_indptr_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.partial_indptr_offset); + IdType* cluster_merge_packed_offset_start_h = GetPtrFromBaseOffset( + page_locked_int_buffer, plan_info.merge_packed_offset_start_offset); + IdType* cluster_merge_packed_offset_end_h = GetPtrFromBaseOffset( + page_locked_int_buffer, plan_info.merge_packed_offset_end_offset); + IdType* cluster_merge_partial_packed_offset_start_h = GetPtrFromBaseOffset( + page_locked_int_buffer, plan_info.merge_partial_packed_offset_start_offset); + IdType* cluster_merge_partial_packed_offset_end_h = GetPtrFromBaseOffset( + page_locked_int_buffer, plan_info.merge_partial_packed_offset_end_offset); + IdType* cluster_merge_partial_stride_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.merge_partial_stride_offset); + IdType* cluster_q_len_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.q_len_offset); + IdType* cluster_kv_len_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.kv_len_offset); + IdType* cluster_q_start_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.q_start_offset); + IdType* cluster_kv_start_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.kv_start_offset); + IdType* cluster_kv_end_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.kv_end_offset); + IdType* cluster_work_indptr_h = + GetPtrFromBaseOffset(page_locked_int_buffer, plan_info.work_indptr_offset); + + std::copy(q_indptr_vec.begin(), q_indptr_vec.end(), cluster_q_indptr_h); + std::copy(kv_indptr_vec.begin(), kv_indptr_vec.end(), cluster_kv_indptr_h); + 
std::copy(partial_indptr_vec.begin(), partial_indptr_vec.end(), cluster_partial_indptr_h); + std::copy(merge_packed_offset_start.begin(), merge_packed_offset_start.end(), + cluster_merge_packed_offset_start_h); + std::copy(merge_packed_offset_end.begin(), merge_packed_offset_end.end(), + cluster_merge_packed_offset_end_h); + std::copy(merge_partial_packed_offset_start.begin(), merge_partial_packed_offset_start.end(), + cluster_merge_partial_packed_offset_start_h); + std::copy(merge_partial_packed_offset_end.begin(), merge_partial_packed_offset_end.end(), + cluster_merge_partial_packed_offset_end_h); + std::copy(merge_partial_stride.begin(), merge_partial_stride.end(), + cluster_merge_partial_stride_h); + std::copy(q_len_vec.begin(), q_len_vec.end(), cluster_q_len_h); + std::copy(kv_len_vec.begin(), kv_len_vec.end(), cluster_kv_len_h); + std::copy(q_start_vec.begin(), q_start_vec.end(), cluster_q_start_h); + std::copy(kv_start_vec.begin(), kv_start_vec.end(), cluster_kv_start_h); + std::copy(kv_end_vec.begin(), kv_end_vec.end(), cluster_kv_end_h); + std::copy(work_indptr_vec.begin(), work_indptr_vec.end(), cluster_work_indptr_h); + + size_t num_bytes_to_copy = int_allocator.num_allocated_bytes(); + FLASHINFER_CUDA_CALL(hipMemcpyAsync(int_buffer, page_locked_int_buffer, num_bytes_to_copy, + hipMemcpyHostToDevice, stream)); + + constexpr size_t sizeof_dtype_o = 2; + AlignedAllocator float_allocator(float_buffer, float_workspace_size_in_bytes); + plan_info.partial_o_offset = float_allocator.aligned_alloc_offset( + 2 * num_clusters * cluster_tile_q * sizeof_dtype_o * head_dim_o, 16, "mla_partial_o"); + plan_info.partial_lse_offset = float_allocator.aligned_alloc_offset( + 2 * num_clusters * cluster_tile_q * sizeof(float), 16, "mla_partial_lse"); + + return hipSuccess; } -} // namespace flashinfer -#endif // FLASHINFER_ATTENTION_SCHEDULER_CUH_ +} // namespace flashinfer +#endif // FLASHINFER_ATTENTION_SCHEDULER_CUH_ diff --git 
a/libflashinfer/include/flashinfer/hip/attention/state.hip.h b/libflashinfer/include/flashinfer/hip/attention/state.hip.h index f7973a84e2..8d1c72e389 100644 --- a/libflashinfer/include/flashinfer/hip/attention/state.hip.h +++ b/libflashinfer/include/flashinfer/hip/attention/state.hip.h @@ -10,76 +10,66 @@ #include "../math.hip.h" #include "../vec_dtypes.hip.h" -namespace flashinfer -{ +namespace flashinfer { /*! * \brief The flashattention state. * \tparam vec_size The size of the vector used in o. */ -template struct state_t -{ - /* the weighted sum of v: exp(pre-softmax logit - m) * v / d */ - vec_t o; - /* maximum value of pre-softmax logits */ - float m; - /* sum of exp(pre-softmax logits - m) */ - float d; +template +struct state_t { + /* the weighted sum of v: exp(pre-softmax logit - m) * v / d */ + vec_t o; + /* maximum value of pre-softmax logits */ + float m; + /* sum of exp(pre-softmax logits - m) */ + float d; - __device__ __forceinline__ void init() - { - o.fill(0.f); - m = -math::inf; - d = 1.f; - } + __device__ __forceinline__ void init() { + o.fill(0.f); + m = -math::inf; + d = 1.f; + } - __device__ __forceinline__ state_t() { init(); } + __device__ __forceinline__ state_t() { init(); } - __device__ __forceinline__ float get_lse() const - { - return m + math::ptx_log2(d); - } + __device__ __forceinline__ float get_lse() const { return m + math::ptx_log2(d); } - /*! - * \brief Merge the state with another state. - * \param other_m The maximum value of pre-softmax logits of the other - * state. - * \param other_d The sum of exp(pre-softmax logits - m) of the other state. - * \param other_o The weighted sum of v of the other state. - */ - __device__ __forceinline__ void - merge(const vec_t &other_o, float other_m, float other_d) - { - float m_prev = m, d_prev = d; - m = max(m_prev, other_m); - d = d_prev * math::ptx_exp2(m_prev - m) + - other_d * math::ptx_exp2(other_m - m); + /*! + * \brief Merge the state with another state. 
+ * \param other_m The maximum value of pre-softmax logits of the other + * state. + * \param other_d The sum of exp(pre-softmax logits - m) of the other state. + * \param other_o The weighted sum of v of the other state. + */ + __device__ __forceinline__ void merge(const vec_t& other_o, float other_m, + float other_d) { + float m_prev = m, d_prev = d; + m = max(m_prev, other_m); + d = d_prev * math::ptx_exp2(m_prev - m) + other_d * math::ptx_exp2(other_m - m); #pragma unroll - for (size_t i = 0; i < vec_size; ++i) { - o[i] = o[i] * math::ptx_exp2(m_prev - m) + - other_o[i] * math::ptx_exp2(other_m - m); - } + for (size_t i = 0; i < vec_size; ++i) { + o[i] = o[i] * math::ptx_exp2(m_prev - m) + other_o[i] * math::ptx_exp2(other_m - m); } + } - /*! - * \brief Merge the state with another state. - * \param other The other state. - */ - __device__ __forceinline__ void merge(const state_t &other) - { - merge(other.o, other.m, other.d); - } + /*! + * \brief Merge the state with another state. + * \param other The other state. 
+ */ + __device__ __forceinline__ void merge(const state_t& other) { + merge(other.o, other.m, other.d); + } - __device__ __forceinline__ void normalize() - { - // only normalize by d when not normalized on the fly + __device__ __forceinline__ void normalize() { + // only normalize by d when not normalized on the fly #pragma unroll - for (size_t i = 0; i < vec_size; ++i) { - o[i] = __fdividef(o[i], d); - } + for (size_t i = 0; i < vec_size; ++i) { + o[i] = __fdividef(o[i], d); } + } }; -} // namespace flashinfer +} // namespace flashinfer -#endif // FLASHINFER_STATE_CUH_ +#endif // FLASHINFER_STATE_CUH_ diff --git a/libflashinfer/include/flashinfer/hip/attention/variant_helper.hip.h b/libflashinfer/include/flashinfer/hip/attention/variant_helper.hip.h index b63c31ff4e..78a1a73ba5 100644 --- a/libflashinfer/include/flashinfer/hip/attention/variant_helper.hip.h +++ b/libflashinfer/include/flashinfer/hip/attention/variant_helper.hip.h @@ -11,66 +11,46 @@ #include -namespace flashinfer -{ - -#define REGISTER_QUERY_TRANSFORM(params, q, ...) \ - template \ - __device__ __forceinline__ T QueryTransform(const Params ¶ms, \ - void *q_smem) \ - { \ - __VA_ARGS__ \ - } - -#define REGISTER_KEY_TRANSFORM(params, k, ...) \ - template \ - __device__ __forceinline__ T KeyTransform(const Params ¶ms, \ - void *k_smem) \ - { \ - __VA_ARGS__ \ - } - -#define REGISTER_LOGITS_TRANSFORM(params, logits, batch_idx, qo_idx, kv_idx, \ - qo_head_idx, kv_head_idx, ...) \ - template \ - __device__ __forceinline__ T LogitsTransform( \ - const Params ¶ms, T logits, uint32_t batch_idx, uint32_t qo_idx, \ - uint32_t kv_idx, uint32_t qo_head_idx, uint32_t kv_head_idx) \ - { \ - __VA_ARGS__ \ - } - -#define REGISTER_LOGITS_MASK(params, batch_idx, qo_idx, kv_idx, qo_head_idx, \ - kv_head_idx, ...) 
\ - template \ - __device__ __forceinline__ bool LogitsMask( \ - const Params ¶ms, uint32_t batch_idx, uint32_t qo_idx, \ - uint32_t kv_idx, uint32_t qo_head_idx, uint32_t kv_head_idx) \ - { \ - __VA_ARGS__ \ - } - -struct AttentionVariantBase -{ - constexpr static bool use_softmax = true; - REGISTER_LOGITS_TRANSFORM(params, - logits, - batch_idx, - qo_idx, - kv_idx, - qo_head_idx, - kv_head_idx, - { return logits; }) - - REGISTER_LOGITS_MASK(params, - batch_idx, - qo_idx, - kv_idx, - qo_head_idx, - kv_head_idx, - { return true; }) +namespace flashinfer { + +#define REGISTER_QUERY_TRANSFORM(params, q, ...) \ + template \ + __device__ __forceinline__ T QueryTransform(const Params& params, void* q_smem) { \ + __VA_ARGS__ \ + } + +#define REGISTER_KEY_TRANSFORM(params, k, ...) \ + template \ + __device__ __forceinline__ T KeyTransform(const Params& params, void* k_smem) { \ + __VA_ARGS__ \ + } + +#define REGISTER_LOGITS_TRANSFORM(params, logits, batch_idx, qo_idx, kv_idx, qo_head_idx, \ + kv_head_idx, ...) \ + template \ + __device__ __forceinline__ T LogitsTransform(const Params& params, T logits, uint32_t batch_idx, \ + uint32_t qo_idx, uint32_t kv_idx, \ + uint32_t qo_head_idx, uint32_t kv_head_idx) { \ + __VA_ARGS__ \ + } + +#define REGISTER_LOGITS_MASK(params, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, ...) 
\ + template \ + __device__ __forceinline__ bool LogitsMask(const Params& params, uint32_t batch_idx, \ + uint32_t qo_idx, uint32_t kv_idx, \ + uint32_t qo_head_idx, uint32_t kv_head_idx) { \ + __VA_ARGS__ \ + } + +struct AttentionVariantBase { + constexpr static bool use_softmax = true; + REGISTER_LOGITS_TRANSFORM(params, logits, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, + { return logits; }) + + REGISTER_LOGITS_MASK(params, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, + { return true; }) }; -} // namespace flashinfer +} // namespace flashinfer -#endif // FLASHINFER_ATTENTION_VARIANT_HELPER_H +#endif // FLASHINFER_ATTENTION_VARIANT_HELPER_H diff --git a/libflashinfer/include/flashinfer/hip/attention/variants.hip.h b/libflashinfer/include/flashinfer/hip/attention/variants.hip.h index 65135f7c80..e2ec1453b0 100644 --- a/libflashinfer/include/flashinfer/hip/attention/variants.hip.h +++ b/libflashinfer/include/flashinfer/hip/attention/variants.hip.h @@ -7,111 +7,81 @@ #ifndef FLASHINFER_ATTENTION_VARIANTS_CUH_ #define FLASHINFER_ATTENTION_VARIANTS_CUH_ -#include "../math.hip.h" -#include "../utils.hip.h" - -#include "variant_helper.hip.h" - #include #include #include -namespace flashinfer -{ +#include "../math.hip.h" +#include "../utils.hip.h" +#include "variant_helper.hip.h" + +namespace flashinfer { DEFINE_HAS_MEMBER(maybe_mask_indptr) -template -struct DefaultAttention : AttentionVariantBase -{ - static constexpr bool use_softmax = true; +template +struct DefaultAttention : AttentionVariantBase { + static constexpr bool use_softmax = true; - uint8_t *custom_mask_ptr; - uint32_t qo_len, kv_len; - uint32_t window_left; - float sm_scale_log2; - float soft_cap_pre_tanh_scale; + uint8_t* custom_mask_ptr; + uint32_t qo_len, kv_len; + uint32_t window_left; + float sm_scale_log2; + float soft_cap_pre_tanh_scale; - // Create closure - template - __device__ __host__ DefaultAttention(const Params ¶ms, - uint32_t batch_idx, - uint8_t *smem_ptr) - { - qo_len 
= params.get_qo_len(batch_idx); - kv_len = params.get_kv_len(batch_idx); - if constexpr (use_logits_soft_cap) { - soft_cap_pre_tanh_scale = - params.sm_scale * math::ptx_rcp(params.logits_soft_cap); - sm_scale_log2 = math::log2e * params.logits_soft_cap; - } - else { - if constexpr (use_alibi) { - sm_scale_log2 = math::log2e; - } - else { - sm_scale_log2 = params.sm_scale * math::log2e; - } - } - if constexpr (use_custom_mask) { - if constexpr (has_maybe_mask_indptr_v) { - custom_mask_ptr = params.maybe_custom_mask + - params.maybe_mask_indptr[batch_idx]; - } - else { - custom_mask_ptr = params.maybe_custom_mask; - } - } - if constexpr (use_sliding_window) { - window_left = - (params.window_left >= 0) ? params.window_left : kv_len; - } + // Create closure + template + __device__ __host__ DefaultAttention(const Params& params, uint32_t batch_idx, + uint8_t* smem_ptr) { + qo_len = params.get_qo_len(batch_idx); + kv_len = params.get_kv_len(batch_idx); + if constexpr (use_logits_soft_cap) { + soft_cap_pre_tanh_scale = params.sm_scale * math::ptx_rcp(params.logits_soft_cap); + sm_scale_log2 = math::log2e * params.logits_soft_cap; + } else { + if constexpr (use_alibi) { + sm_scale_log2 = math::log2e; + } else { + sm_scale_log2 = params.sm_scale * math::log2e; + } + } + if constexpr (use_custom_mask) { + if constexpr (has_maybe_mask_indptr_v) { + custom_mask_ptr = params.maybe_custom_mask + params.maybe_mask_indptr[batch_idx]; + } else { + custom_mask_ptr = params.maybe_custom_mask; + } } + if constexpr (use_sliding_window) { + window_left = (params.window_left >= 0) ? 
params.window_left : kv_len; + } + } - REGISTER_LOGITS_TRANSFORM( - params, - logits, - batch_idx, - qo_idx, - kv_idx, - qo_head_idx, - kv_head_idx, - { - if constexpr (use_alibi) { - logits = logits * params.sm_scale + - params.maybe_alibi_slopes[qo_head_idx] * - float(int(kv_idx) - int(qo_idx)); - } - if constexpr (use_logits_soft_cap) { - logits = float(math::tanh(logits * soft_cap_pre_tanh_scale)); - } - return logits; - }) + REGISTER_LOGITS_TRANSFORM(params, logits, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, { + if constexpr (use_alibi) { + logits = logits * params.sm_scale + + params.maybe_alibi_slopes[qo_head_idx] * float(int(kv_idx) - int(qo_idx)); + } + if constexpr (use_logits_soft_cap) { + logits = float(math::tanh(logits * soft_cap_pre_tanh_scale)); + } + return logits; + }) - REGISTER_LOGITS_MASK( - params, - batch_idx, - qo_idx, - kv_idx, - qo_head_idx, - kv_head_idx, - { - bool mask = true; - if constexpr (use_custom_mask) { - const uint32_t offset = qo_idx * kv_len + kv_idx; - mask &= ((custom_mask_ptr[offset / 8] >> (offset % 8)) & 1); - } - if constexpr (use_sliding_window) { - mask &= (kv_idx + qo_len + window_left >= kv_len + qo_idx); - } - return mask; - }) + REGISTER_LOGITS_MASK(params, batch_idx, qo_idx, kv_idx, qo_head_idx, kv_head_idx, { + bool mask = true; + if constexpr (use_custom_mask) { + const uint32_t offset = qo_idx * kv_len + kv_idx; + mask &= ((custom_mask_ptr[offset / 8] >> (offset % 8)) & 1); + } + if constexpr (use_sliding_window) { + mask &= (kv_idx + qo_len + window_left >= kv_len + qo_idx); + } + return mask; + }) }; -}; // namespace flashinfer +}; // namespace flashinfer -#endif // FLASHINFER_ATTENTION_VARIANTS_CUH_ +#endif // FLASHINFER_ATTENTION_VARIANTS_CUH_ diff --git a/libflashinfer/include/flashinfer/hip/attention_impl.hip.h b/libflashinfer/include/flashinfer/hip/attention_impl.hip.h index ecca6df656..27dac29b45 100644 --- a/libflashinfer/include/flashinfer/hip/attention_impl.hip.h +++ 
b/libflashinfer/include/flashinfer/hip/attention_impl.hip.h @@ -12,4 +12,4 @@ #include "attention/default_decode_params.hip.h" #include "attention/variants.hip.h" -#endif // FLASHINFER_ATTENTION_IMPL_CUH_ +#endif // FLASHINFER_ATTENTION_IMPL_CUH_ diff --git a/libflashinfer/include/flashinfer/hip/cp_async.hip.h b/libflashinfer/include/flashinfer/hip/cp_async.hip.h index cf8df86baa..495827e7ce 100644 --- a/libflashinfer/include/flashinfer/hip/cp_async.hip.h +++ b/libflashinfer/include/flashinfer/hip/cp_async.hip.h @@ -11,19 +11,16 @@ #include -namespace flashinfer::cp_async -{ +namespace flashinfer::cp_async { -enum class SharedMemFillMode -{ - kFillZero, // Fill zero to shared memory when predicate is false - kNoFill // Do not fill zero to shared memory when predicate is false +enum class SharedMemFillMode { + kFillZero, // Fill zero to shared memory when predicate is false + kNoFill // Do not fill zero to shared memory when predicate is false }; -enum class PrefetchMode -{ - kNoPrefetch, // Do not fetch additional data from global memory to L2 - kPrefetch // Fetch additional data from global memory to L2 +enum class PrefetchMode { + kNoPrefetch, // Do not fetch additional data from global memory to L2 + kPrefetch // Fetch additional data from global memory to L2 }; /// @brief orders memory accesses for all threads within a thread block. 
@@ -31,9 +28,9 @@ __device__ __forceinline__ void commit_group() { __threadfence_block(); } /// @brief Wrapper of PTX cp.async.wait_group instruction /// @param n Wait till most recent n groups are committed -template __device__ __forceinline__ void wait_group() -{ - __syncthreads(); +template +__device__ __forceinline__ void wait_group() { + __syncthreads(); } /// @brief Wrapper of PTX cp.async.cg.shared.global instruction, asynchronously @@ -44,9 +41,8 @@ template __device__ __forceinline__ void wait_group() /// @param smem_ptr Pointer to shared memory /// @param gmem_ptr Pointer to global memory template -__device__ __forceinline__ void load_128b(T *smem_ptr, const T *gmem_ptr) -{ - *((uint4 *)smem_ptr) = *((uint4 *)gmem_ptr); +__device__ __forceinline__ void load_128b(T* smem_ptr, const T* gmem_ptr) { + *((uint4*)smem_ptr) = *((uint4*)gmem_ptr); } /// @brief Wrapper of PTX cp.async.cg.shared.global instruction, asynchronously @@ -61,17 +57,14 @@ __device__ __forceinline__ void load_128b(T *smem_ptr, const T *gmem_ptr) /// @param predicate Predicate value /// @note fill zero is slower than not fill zero template -__device__ __forceinline__ void -pred_load_128b(T *smem_ptr, const T *gmem_ptr, bool predicate) -{ - if (predicate) { - *((uint4 *)smem_ptr) = *((uint4 *)gmem_ptr); - } - else { - if constexpr (fill_mode == SharedMemFillMode::kFillZero) { - *((uint4 *)smem_ptr) = make_uint4(0, 0, 0, 0); - } +__device__ __forceinline__ void pred_load_128b(T* smem_ptr, const T* gmem_ptr, bool predicate) { + if (predicate) { + *((uint4*)smem_ptr) = *((uint4*)gmem_ptr); + } else { + if constexpr (fill_mode == SharedMemFillMode::kFillZero) { + *((uint4*)smem_ptr) = make_uint4(0, 0, 0, 0); } + } } /// @brief Load specified number of bits per thread from global memory to shared @@ -83,18 +76,14 @@ pred_load_128b(T *smem_ptr, const T *gmem_ptr, bool predicate) /// @param smem_ptr Pointer to shared memory /// @param gmem_ptr Pointer to global memory template -__device__ 
__forceinline__ void load(T *smem_ptr, const T *gmem_ptr) -{ - static_assert(num_bits == 128 || num_bits == 256, - "num_bits must be 128 or 256"); - if constexpr (num_bits == 128) { - load_128b(smem_ptr, gmem_ptr); - } - else { - load_128b(smem_ptr, gmem_ptr); - load_128b(smem_ptr + 16 / sizeof(T), - gmem_ptr + 16 / sizeof(T)); - } +__device__ __forceinline__ void load(T* smem_ptr, const T* gmem_ptr) { + static_assert(num_bits == 128 || num_bits == 256, "num_bits must be 128 or 256"); + if constexpr (num_bits == 128) { + load_128b(smem_ptr, gmem_ptr); + } else { + load_128b(smem_ptr, gmem_ptr); + load_128b(smem_ptr + 16 / sizeof(T), gmem_ptr + 16 / sizeof(T)); + } } /// @brief Load specified number of bits per thread from global memory to shared @@ -109,25 +98,19 @@ __device__ __forceinline__ void load(T *smem_ptr, const T *gmem_ptr) /// @param gmem_ptr Pointer to global memory /// @param predicate Predicate value /// @note fill zero is slower than not fill zero -template -__device__ __forceinline__ void -pred_load(T *smem_ptr, const T *gmem_ptr, bool predicate) -{ - // static_assert(num_bits == 128 || num_bits == 256, "num_bits must be 128 - // or 256"); - if constexpr (num_bits == 128) { - pred_load_128b(smem_ptr, gmem_ptr, predicate); - } - else { - pred_load_128b(smem_ptr, gmem_ptr, predicate); - pred_load_128b( - smem_ptr + 16 / sizeof(T), gmem_ptr + 16 / sizeof(T), predicate); - } +template +__device__ __forceinline__ void pred_load(T* smem_ptr, const T* gmem_ptr, bool predicate) { + // static_assert(num_bits == 128 || num_bits == 256, "num_bits must be 128 + // or 256"); + if constexpr (num_bits == 128) { + pred_load_128b(smem_ptr, gmem_ptr, predicate); + } else { + pred_load_128b(smem_ptr, gmem_ptr, predicate); + pred_load_128b(smem_ptr + 16 / sizeof(T), gmem_ptr + 16 / sizeof(T), + predicate); + } } -} // namespace flashinfer::cp_async +} // namespace flashinfer::cp_async -#endif // FLASHINFER_CP_ASYNC_CUH_ +#endif // FLASHINFER_CP_ASYNC_CUH_ diff --git 
a/libflashinfer/include/flashinfer/hip/fastdiv.hip.h b/libflashinfer/include/flashinfer/hip/fastdiv.hip.h index 0700104ac2..bd838de5cd 100644 --- a/libflashinfer/include/flashinfer/hip/fastdiv.hip.h +++ b/libflashinfer/include/flashinfer/hip/fastdiv.hip.h @@ -11,100 +11,84 @@ #include -namespace flashinfer -{ +namespace flashinfer { -struct uint_fastdiv -{ - uint32_t d; - uint32_t m; - uint32_t s; - uint32_t a; +struct uint_fastdiv { + uint32_t d; + uint32_t m; + uint32_t s; + uint32_t a; - __host__ __device__ uint_fastdiv() : d(0), m(0), s(0), a(0) {} + __host__ __device__ uint_fastdiv() : d(0), m(0), s(0), a(0) {} - __host__ uint_fastdiv(uint32_t d) : d(d) - { - unsigned int p, nc, delta, q1, r1, q2, r2; - a = 0; - nc = unsigned(-1) - unsigned(-d) % d; - p = 31; - q1 = 0x80000000 / nc; - r1 = 0x80000000 - q1 * nc; - q2 = 0x7FFFFFFF / d; - r2 = 0x7FFFFFFF - q2 * d; - do { - p++; - if (r1 >= nc - r1) { - q1 = 2 * q1 + 1; - r1 = 2 * r1 - nc; - } - else { - q1 = 2 * q1; - r1 = 2 * r1; - } - if (r2 + 1 >= d - r2) { - if (q2 >= 0x7FFFFFFF) - a = 1; - q2 = 2 * q2 + 1; - r2 = 2 * r2 + 1 - d; - } - else { - if (q2 >= 0x80000000) - a = 1; - q2 = 2 * q2; - r2 = 2 * r2 + 1; - } - delta = d - 1 - r2; - } while (p < 64 && (q1 < delta || (q1 == delta && r1 == 0))); - m = q2 + 1; - s = p - 32; - } + __host__ uint_fastdiv(uint32_t d) : d(d) { + unsigned int p, nc, delta, q1, r1, q2, r2; + a = 0; + nc = unsigned(-1) - unsigned(-d) % d; + p = 31; + q1 = 0x80000000 / nc; + r1 = 0x80000000 - q1 * nc; + q2 = 0x7FFFFFFF / d; + r2 = 0x7FFFFFFF - q2 * d; + do { + p++; + if (r1 >= nc - r1) { + q1 = 2 * q1 + 1; + r1 = 2 * r1 - nc; + } else { + q1 = 2 * q1; + r1 = 2 * r1; + } + if (r2 + 1 >= d - r2) { + if (q2 >= 0x7FFFFFFF) a = 1; + q2 = 2 * q2 + 1; + r2 = 2 * r2 + 1 - d; + } else { + if (q2 >= 0x80000000) a = 1; + q2 = 2 * q2; + r2 = 2 * r2 + 1; + } + delta = d - 1 - r2; + } while (p < 64 && (q1 < delta || (q1 == delta && r1 == 0))); + m = q2 + 1; + s = p - 32; + } - __host__ __device__ 
__forceinline__ operator unsigned int() const - { - return d; - } + __host__ __device__ __forceinline__ operator unsigned int() const { return d; } - __host__ __device__ __forceinline__ void - divmod(uint32_t n, uint32_t &q, uint32_t &r) const - { - if (d == 1) { - q = n; - } - else { - q = __umulhi(m, n); + __host__ __device__ __forceinline__ void divmod(uint32_t n, uint32_t& q, uint32_t& r) const { + if (d == 1) { + q = n; + } else { + q = __umulhi(m, n); - q += a * n; - q >>= s; - } - r = n - q * d; + q += a * n; + q >>= s; } + r = n - q * d; + } }; -__host__ __device__ __forceinline__ uint32_t -operator/(const uint32_t n, const uint_fastdiv &divisor) -{ - uint32_t q; - if (divisor.d == 1) { - q = n; - } - else { - q = __umulhi(divisor.m, n); - q += divisor.a * n; - q >>= divisor.s; - } - return q; +__host__ __device__ __forceinline__ uint32_t operator/(const uint32_t n, + const uint_fastdiv& divisor) { + uint32_t q; + if (divisor.d == 1) { + q = n; + } else { + q = __umulhi(divisor.m, n); + q += divisor.a * n; + q >>= divisor.s; + } + return q; } -__host__ __device__ __forceinline__ uint32_t -operator%(const uint32_t n, const uint_fastdiv &divisor) -{ - uint32_t quotient = n / divisor; - uint32_t remainder = n - quotient * divisor; - return remainder; +__host__ __device__ __forceinline__ uint32_t operator%(const uint32_t n, + const uint_fastdiv& divisor) { + uint32_t quotient = n / divisor; + uint32_t remainder = n - quotient * divisor; + return remainder; } -} // namespace flashinfer +} // namespace flashinfer -#endif // FLASHINFER_FASTDIV_CUH_ +#endif // FLASHINFER_FASTDIV_CUH_ diff --git a/libflashinfer/include/flashinfer/hip/hip_platform.h b/libflashinfer/include/flashinfer/hip/hip_platform.h index 172083db4e..4dfb7030e7 100644 --- a/libflashinfer/include/flashinfer/hip/hip_platform.h +++ b/libflashinfer/include/flashinfer/hip/hip_platform.h @@ -18,13 +18,12 @@ #include #endif -#define FI_GPU_CALL(call) \ - do { \ - hipError_t err = (call); \ - if (err != 
hipSuccess) { \ - std::ostringstream err_msg; \ - err_msg << "GPU error: " << hipGetErrorString(err) << " at " \ - << __FILE__ << ":" << __LINE__; \ - throw std::runtime_error(err_msg.str()); \ - } \ - } while (0) +#define FI_GPU_CALL(call) \ + do { \ + hipError_t err = (call); \ + if (err != hipSuccess) { \ + std::ostringstream err_msg; \ + err_msg << "GPU error: " << hipGetErrorString(err) << " at " << __FILE__ << ":" << __LINE__; \ + throw std::runtime_error(err_msg.str()); \ + } \ + } while (0) diff --git a/libflashinfer/include/flashinfer/hip/layout.hip.h b/libflashinfer/include/flashinfer/hip/layout.hip.h index 395408e621..3a060394ee 100644 --- a/libflashinfer/include/flashinfer/hip/layout.hip.h +++ b/libflashinfer/include/flashinfer/hip/layout.hip.h @@ -11,135 +11,103 @@ #include #include -namespace flashinfer -{ +namespace flashinfer { /*! * \brief The Layout of QKV matrices */ -enum class QKVLayout -{ - // [seq_len, num_heads, head_dim] - kNHD = 0U, - // [num_heads, seq_len, head_dim] - kHND = 1U +enum class QKVLayout { + // [seq_len, num_heads, head_dim] + kNHD = 0U, + // [num_heads, seq_len, head_dim] + kHND = 1U }; -__host__ __device__ __inline__ size_t get_elem_offset_impl(size_t elem_idx, - size_t head_idx, - size_t feat_idx, - size_t stride_n, - size_t stride_h) -{ - return elem_idx * stride_n + head_idx * stride_h + feat_idx; +__host__ __device__ __inline__ size_t get_elem_offset_impl(size_t elem_idx, size_t head_idx, + size_t feat_idx, size_t stride_n, + size_t stride_h) { + return elem_idx * stride_n + head_idx * stride_h + feat_idx; } -__host__ __inline__ auto get_qkv_strides(QKVLayout kv_layout, - uint32_t kv_len, - uint32_t num_qo_heads, - uint32_t num_kv_heads, - uint32_t head_dim) -{ - const uint32_t q_stride_n = num_qo_heads * head_dim, q_stride_h = head_dim, - kv_stride_n = (kv_layout == QKVLayout::kNHD) - ? num_kv_heads * head_dim - : head_dim, - kv_stride_h = (kv_layout == QKVLayout::kNHD) - ? 
head_dim - : kv_len * head_dim; - return std::make_tuple(q_stride_n, q_stride_h, kv_stride_n, kv_stride_h); +__host__ __inline__ auto get_qkv_strides(QKVLayout kv_layout, uint32_t kv_len, + uint32_t num_qo_heads, uint32_t num_kv_heads, + uint32_t head_dim) { + const uint32_t q_stride_n = num_qo_heads * head_dim, q_stride_h = head_dim, + kv_stride_n = (kv_layout == QKVLayout::kNHD) ? num_kv_heads * head_dim : head_dim, + kv_stride_h = (kv_layout == QKVLayout::kNHD) ? head_dim : kv_len * head_dim; + return std::make_tuple(q_stride_n, q_stride_h, kv_stride_n, kv_stride_h); } -struct tensor_info_t -{ - uint32_t qo_len; - uint32_t kv_len; - uint32_t num_qo_heads; - uint32_t num_kv_heads; - uint32_t q_stride_n; - uint32_t q_stride_h; - uint32_t kv_stride_n; - uint32_t kv_stride_h; - uint32_t head_dim; - __host__ __device__ inline tensor_info_t(uint32_t qo_len, - uint32_t kv_len, - uint32_t num_qo_heads, - uint32_t num_kv_heads, - uint32_t q_stride_n, - uint32_t q_stride_h, - uint32_t kv_stride_n, - uint32_t kv_stride_h, - uint32_t head_dim) - : qo_len(qo_len), kv_len(kv_len), num_qo_heads(num_qo_heads), - num_kv_heads(num_kv_heads), q_stride_n(q_stride_n), - q_stride_h(q_stride_h), kv_stride_n(kv_stride_n), - kv_stride_h(kv_stride_h), head_dim(head_dim) - { - } +struct tensor_info_t { + uint32_t qo_len; + uint32_t kv_len; + uint32_t num_qo_heads; + uint32_t num_kv_heads; + uint32_t q_stride_n; + uint32_t q_stride_h; + uint32_t kv_stride_n; + uint32_t kv_stride_h; + uint32_t head_dim; + __host__ __device__ inline tensor_info_t(uint32_t qo_len, uint32_t kv_len, uint32_t num_qo_heads, + uint32_t num_kv_heads, uint32_t q_stride_n, + uint32_t q_stride_h, uint32_t kv_stride_n, + uint32_t kv_stride_h, uint32_t head_dim) + : qo_len(qo_len), + kv_len(kv_len), + num_qo_heads(num_qo_heads), + num_kv_heads(num_kv_heads), + q_stride_n(q_stride_n), + q_stride_h(q_stride_h), + kv_stride_n(kv_stride_n), + kv_stride_h(kv_stride_h), + head_dim(head_dim) {} - __host__ __device__ inline 
tensor_info_t(uint32_t qo_len, - uint32_t kv_len, - uint32_t num_qo_heads, - uint32_t num_kv_heads, - QKVLayout kv_layout, - uint32_t head_dim) - : qo_len(qo_len), kv_len(kv_len), num_qo_heads(num_qo_heads), - num_kv_heads(num_kv_heads), head_dim(head_dim) - { - q_stride_n = num_qo_heads * head_dim; - q_stride_h = head_dim; - kv_stride_n = - (kv_layout == QKVLayout::kNHD) ? num_kv_heads * head_dim : head_dim; - kv_stride_h = - (kv_layout == QKVLayout::kNHD) ? head_dim : kv_len * head_dim; - } + __host__ __device__ inline tensor_info_t(uint32_t qo_len, uint32_t kv_len, uint32_t num_qo_heads, + uint32_t num_kv_heads, QKVLayout kv_layout, + uint32_t head_dim) + : qo_len(qo_len), + kv_len(kv_len), + num_qo_heads(num_qo_heads), + num_kv_heads(num_kv_heads), + head_dim(head_dim) { + q_stride_n = num_qo_heads * head_dim; + q_stride_h = head_dim; + kv_stride_n = (kv_layout == QKVLayout::kNHD) ? num_kv_heads * head_dim : head_dim; + kv_stride_h = (kv_layout == QKVLayout::kNHD) ? head_dim : kv_len * head_dim; + } - __host__ __device__ inline size_t get_q_elem_offset(uint32_t qo_idx, - uint32_t qo_head_idx, - uint32_t feat_idx) const - { - return get_elem_offset_impl(qo_idx, qo_head_idx, feat_idx, q_stride_n, - q_stride_h); - } + __host__ __device__ inline size_t get_q_elem_offset(uint32_t qo_idx, uint32_t qo_head_idx, + uint32_t feat_idx) const { + return get_elem_offset_impl(qo_idx, qo_head_idx, feat_idx, q_stride_n, q_stride_h); + } - __host__ __device__ inline size_t get_o_elem_offset(uint32_t qo_idx, - uint32_t qo_head_idx, - uint32_t feat_idx) const - { - return get_elem_offset_impl(qo_idx, qo_head_idx, feat_idx, - num_qo_heads * head_dim, head_dim); - } + __host__ __device__ inline size_t get_o_elem_offset(uint32_t qo_idx, uint32_t qo_head_idx, + uint32_t feat_idx) const { + return get_elem_offset_impl(qo_idx, qo_head_idx, feat_idx, num_qo_heads * head_dim, head_dim); + } - __host__ __device__ inline size_t - get_kv_elem_offset(uint32_t kv_idx, - uint32_t kv_head_idx, 
- uint32_t feat_idx) const - { - return get_elem_offset_impl(kv_idx, kv_head_idx, feat_idx, kv_stride_n, - kv_stride_h); - } + __host__ __device__ inline size_t get_kv_elem_offset(uint32_t kv_idx, uint32_t kv_head_idx, + uint32_t feat_idx) const { + return get_elem_offset_impl(kv_idx, kv_head_idx, feat_idx, kv_stride_n, kv_stride_h); + } - __host__ __device__ inline uint32_t get_group_size() const - { - return num_qo_heads / num_kv_heads; - } + __host__ __device__ inline uint32_t get_group_size() const { return num_qo_heads / num_kv_heads; } }; /*! * \brief Convert QKVLayout to string * \param layout The QKVLayout to convert */ -inline std::string QKVLayoutToString(const QKVLayout &layout) -{ - switch (layout) { +inline std::string QKVLayoutToString(const QKVLayout& layout) { + switch (layout) { case QKVLayout::kNHD: - return "NHD"; + return "NHD"; case QKVLayout::kHND: - return "HND"; + return "HND"; default: - return "Unknown"; - } + return "Unknown"; + } } -} // namespace flashinfer -#endif // FLASHINFER_LAYOUT_CUH_ +} // namespace flashinfer +#endif // FLASHINFER_LAYOUT_CUH_ diff --git a/libflashinfer/include/flashinfer/hip/math.hip.h b/libflashinfer/include/flashinfer/hip/math.hip.h index 103430082e..a65fcdfaa0 100644 --- a/libflashinfer/include/flashinfer/hip/math.hip.h +++ b/libflashinfer/include/flashinfer/hip/math.hip.h @@ -9,15 +9,13 @@ #define HIP_ENABLE_WARP_SYNC_BUILTINS 1 +#include #include #include -#include - #include -namespace flashinfer::math -{ +namespace flashinfer::math { // log2(e) constexpr float log2e = 1.44269504088896340736f; @@ -26,33 +24,34 @@ constexpr float loge2 = 0.693147180559945309417f; constexpr float inf = 5e4; -template __forceinline__ __device__ T ptx_exp2(T x); +template +__forceinline__ __device__ T ptx_exp2(T x); /// @brief Wrapper for computing 2 ^ x. 
We currently do not support a direct /// equivalent of __exp2f() /// @param x Input power to exponentiate /// @return Computes 2 ^ x -template <> __forceinline__ __device__ float ptx_exp2(float x) -{ - return __exp10f(x * __log10f(2.0f)); // Writing 2^x = 10 ^ (x * log_10(2)) +template <> +__forceinline__ __device__ float ptx_exp2(float x) { + return __exp10f(x * __log10f(2.0f)); // Writing 2^x = 10 ^ (x * log_10(2)) } /// @brief Wrapper for computing 2 ^ x. We currently do not support a direct /// equivalent of __exp2f() /// @param x Input power to exponentiate /// @return Computes 2 ^ x -template <> __forceinline__ __device__ __half ptx_exp2<__half>(__half x) -{ - return hexp2(x); +template <> +__forceinline__ __device__ __half ptx_exp2<__half>(__half x) { + return hexp2(x); } /// @brief Wrapper for computing 2 ^ x. We currently do not support a direct /// equivalent of __exp2f() /// @param x Vector of two half dtypes to exponentiate /// @return Computes 2 ^ x -template <> __forceinline__ __device__ __half2 ptx_exp2<__half2>(__half2 x) -{ - return half2(ptx_exp2(x.x), ptx_exp2(x.y)); +template <> +__forceinline__ __device__ __half2 ptx_exp2<__half2>(__half2 x) { + return half2(ptx_exp2(x.x), ptx_exp2(x.y)); } /// @brief Compute log2 @@ -66,16 +65,15 @@ __forceinline__ __device__ float ptx_log2(float x) { return __log2f(x); } __forceinline__ __device__ float ptx_rcp(float x) { return __frcp_rn(x); } template -__forceinline__ __device__ T shfl_xor_sync(T x, int lane_mask) -{ - // FIXME (diptorupd): The shfl_xor_sync is used to implement a butterfly - // reduction pattern. The caller in decode.cuh most likely assumes that the - // warp size is 32 and the lane_mask is going from 16, 8, 4, 2, 1. - // Given that AMDGPU for CDNA3 has a warp size of 64, the lane_mask based on - // the warp size of 32 might lead to incorrect exchanges between the - // threads. 
The issue requires further investigation, for now I have hard - // coded the warp size to 32 when calling shfl_xor. - return __shfl_xor(x, lane_mask, 32); +__forceinline__ __device__ T shfl_xor_sync(T x, int lane_mask) { + // FIXME (diptorupd): The shfl_xor_sync is used to implement a butterfly + // reduction pattern. The caller in decode.cuh most likely assumes that the + // warp size is 32 and the lane_mask is going from 16, 8, 4, 2, 1. + // Given that AMDGPU for CDNA3 has a warp size of 64, the lane_mask based on + // the warp size of 32 might lead to incorrect exchanges between the + // threads. The issue requires further investigation, for now I have hard + // coded the warp size to 32 when calling shfl_xor. + return __shfl_xor(x, lane_mask, 32); } /// @brief Wrapper for math intrinsic 1/sqrt(x) @@ -83,32 +81,33 @@ __forceinline__ __device__ T shfl_xor_sync(T x, int lane_mask) /// @return Returns 1 / sqrt(x) in round to nearest even mode __forceinline__ __device__ float rsqrt(float x) { return __frsqrt_rn(x); } -template __forceinline__ __device__ T tanh(T x); +template +__forceinline__ __device__ T tanh(T x); /// @brief Compute tanhf(x) /// @param x Input param - float dtype /// @return Returns tanhf(x) /// @note ROCm6.3 does not have a fast tanh or instrincs to support this -template <> __forceinline__ __device__ float tanh(float x) -{ - return tanhf(x); +template <> +__forceinline__ __device__ float tanh(float x) { + return tanhf(x); } /// @brief A utility function to compute tanh for half dtype /// @param x Input param - half /// @return Hyperbolic tangent of x -template <> __forceinline__ __device__ __half tanh<__half>(__half x) -{ - return __float2half(tanh(__half2float(x))); +template <> +__forceinline__ __device__ __half tanh<__half>(__half x) { + return __float2half(tanh(__half2float(x))); } /// @brief Compute hyperbolic tangent for a vector of two half dtype /// @param x Vector of two half dtypes /// @return Hyperbolic tangent of x -template <> 
__forceinline__ __device__ __half2 tanh<__half2>(__half2 x) -{ - return __half2(tanh(x.x), tanh(x.y)); +template <> +__forceinline__ __device__ __half2 tanh<__half2>(__half2 x) { + return __half2(tanh(x.x), tanh(x.y)); } -} // namespace flashinfer::math -#endif // FLASHINFER_MATH_CUH_ +} // namespace flashinfer::math +#endif // FLASHINFER_MATH_CUH_ diff --git a/libflashinfer/include/flashinfer/hip/norm.hip.h b/libflashinfer/include/flashinfer/hip/norm.hip.h index c8e7ef4a96..961df1474a 100644 --- a/libflashinfer/include/flashinfer/hip/norm.hip.h +++ b/libflashinfer/include/flashinfer/hip/norm.hip.h @@ -12,344 +12,276 @@ #include "utils.hip.h" #include "vec_dtypes.hip.h" -namespace flashinfer -{ +namespace flashinfer { -namespace norm -{ +namespace norm { template -__global__ void RMSNormKernel(T *__restrict__ input, - T *__restrict__ weight, - T *__restrict__ output, - const uint32_t d, - const uint32_t stride_input, - const uint32_t stride_output, - float weight_bias, - float eps) -{ - const uint32_t bx = blockIdx.x; - const uint32_t tx = threadIdx.x, ty = threadIdx.y; - constexpr uint32_t warp_size = 32; - const uint32_t num_warps = blockDim.y; - // NOTE(Zihao): it's guaranteed that num_warps should be smaller than 32 - const uint32_t thread_id = tx + ty * warp_size; - const uint32_t num_threads = num_warps * warp_size; - const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads); - extern __shared__ float smem[]; - - float sum_sq = 0.f; - -#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && \ - (__CUDA_ARCH__ >= 900)) - asm volatile("griddepcontrol.wait;"); +__global__ void RMSNormKernel(T* __restrict__ input, T* __restrict__ weight, T* __restrict__ output, + const uint32_t d, const uint32_t stride_input, + const uint32_t stride_output, float weight_bias, float eps) { + const uint32_t bx = blockIdx.x; + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + constexpr uint32_t warp_size = 32; + const uint32_t num_warps = blockDim.y; + // NOTE(Zihao): it's 
guaranteed that num_warps should be smaller than 32 + const uint32_t thread_id = tx + ty * warp_size; + const uint32_t num_threads = num_warps * warp_size; + const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads); + extern __shared__ float smem[]; + + float sum_sq = 0.f; + +#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.wait;"); #endif - for (uint32_t i = 0; i < rounds; i++) { - vec_t input_vec; - input_vec.fill(0.0); - if ((i * num_threads + thread_id) * VEC_SIZE < d) { - input_vec.load(input + bx * stride_input + - i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); - } + for (uint32_t i = 0; i < rounds; i++) { + vec_t input_vec; + input_vec.fill(0.0); + if ((i * num_threads + thread_id) * VEC_SIZE < d) { + input_vec.load(input + bx * stride_input + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); + } #pragma unroll - for (uint32_t j = 0; j < VEC_SIZE; j++) { - sum_sq += float(input_vec[j]) * float(input_vec[j]); - } + for (uint32_t j = 0; j < VEC_SIZE; j++) { + sum_sq += float(input_vec[j]) * float(input_vec[j]); } + } - // first, warp reduce sum + // first, warp reduce sum +#pragma unroll + for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) { + sum_sq += math::shfl_xor_sync(sum_sq, offset); + } + + smem[ty] = sum_sq; + __syncthreads(); + // then, cross warp reduce sum using only the first warp + if (ty == 0) { + sum_sq = (tx < num_warps) ? smem[tx] : 0.f; #pragma unroll for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) { - sum_sq += math::shfl_xor_sync(sum_sq, offset); + sum_sq += math::shfl_xor_sync(sum_sq, offset); } - - smem[ty] = sum_sq; - __syncthreads(); - // then, cross warp reduce sum using only the first warp - if (ty == 0) { - sum_sq = (tx < num_warps) ? 
smem[tx] : 0.f; -#pragma unroll - for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) { - sum_sq += math::shfl_xor_sync(sum_sq, offset); - } - smem[0] = sum_sq; + smem[0] = sum_sq; + } + __syncthreads(); + + float rms_rcp = math::rsqrt(smem[0] / float(d) + eps); + + for (uint32_t i = 0; i < rounds; i++) { + vec_t input_vec; + vec_t weight_vec; + vec_t output_vec; + input_vec.fill(0.f); + weight_vec.fill(0.f); + if ((i * num_threads + thread_id) * VEC_SIZE < d) { + input_vec.load(input + bx * stride_input + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); + weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); } - __syncthreads(); - - float rms_rcp = math::rsqrt(smem[0] / float(d) + eps); - - for (uint32_t i = 0; i < rounds; i++) { - vec_t input_vec; - vec_t weight_vec; - vec_t output_vec; - input_vec.fill(0.f); - weight_vec.fill(0.f); - if ((i * num_threads + thread_id) * VEC_SIZE < d) { - input_vec.load(input + bx * stride_input + - i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); - weight_vec.load(weight + i * num_threads * VEC_SIZE + - thread_id * VEC_SIZE); - } #pragma unroll - for (uint32_t j = 0; j < VEC_SIZE; j++) { - output_vec[j] = float(input_vec[j]) * rms_rcp * - (weight_bias + float(weight_vec[j])); - } - if ((i * num_threads + thread_id) * VEC_SIZE < d) { - output_vec.store(output + bx * stride_output + - i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); - } + for (uint32_t j = 0; j < VEC_SIZE; j++) { + output_vec[j] = float(input_vec[j]) * rms_rcp * (weight_bias + float(weight_vec[j])); } -#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && \ - (__CUDA_ARCH__ >= 900)) - asm volatile("griddepcontrol.launch_dependents;"); + if ((i * num_threads + thread_id) * VEC_SIZE < d) { + output_vec.store(output + bx * stride_output + i * num_threads * VEC_SIZE + + thread_id * VEC_SIZE); + } + } +#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm 
volatile("griddepcontrol.launch_dependents;"); #endif } template -hipError_t RMSNorm(T *input, - T *weight, - T *output, - uint32_t batch_size, - uint32_t d, - uint32_t stride_input, - uint32_t stride_output, - float eps = 1e-5, - bool enable_pdl = false, - hipStream_t stream = 0) -{ - const uint32_t vec_size = std::gcd(16 / sizeof(T), d); - - const uint32_t block_size = std::min(1024, d / vec_size); - const uint32_t num_warps = ceil_div(block_size, 32); - dim3 nblks(batch_size); - dim3 nthrs(32, num_warps); - const uint32_t smem_size = num_warps * sizeof(float); - float weight_bias = 0.f; - - DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { - auto kernel = RMSNormKernel; - hipFuncSetAttribute( - (void*)kernel, hipFuncAttributeMaxDynamicSharedMemorySize, smem_size); - RMSNormKernel<<>>( - input, weight, output, d, stride_input, stride_output, weight_bias, - eps); - }); - return hipSuccess; +hipError_t RMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_t d, + uint32_t stride_input, uint32_t stride_output, float eps = 1e-5, + bool enable_pdl = false, hipStream_t stream = 0) { + const uint32_t vec_size = std::gcd(16 / sizeof(T), d); + + const uint32_t block_size = std::min(1024, d / vec_size); + const uint32_t num_warps = ceil_div(block_size, 32); + dim3 nblks(batch_size); + dim3 nthrs(32, num_warps); + const uint32_t smem_size = num_warps * sizeof(float); + float weight_bias = 0.f; + + DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { + auto kernel = RMSNormKernel; + hipFuncSetAttribute((void*)kernel, hipFuncAttributeMaxDynamicSharedMemorySize, smem_size); + RMSNormKernel<<>>( + input, weight, output, d, stride_input, stride_output, weight_bias, eps); + }); + return hipSuccess; } template -__global__ void FusedAddRMSNormKernel(T *__restrict__ input, - T *__restrict__ residual, - T *__restrict__ weight, - const uint32_t d, - const uint32_t stride_input, - const uint32_t stride_residual, - float weight_bias, - float eps) -{ - const uint32_t bx = 
blockIdx.x; - const uint32_t tx = threadIdx.x, ty = threadIdx.y; - constexpr uint32_t warp_size = 32; - const uint32_t num_warps = blockDim.y; - const uint32_t thread_id = tx + ty * warp_size; - const uint32_t num_threads = num_warps * warp_size; - const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads); - extern __shared__ float smem[]; - float *smem_x = smem + ceil_div(num_warps, 4) * 4; - - float sum_sq = 0.f; -#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && \ - (__CUDA_ARCH__ >= 900)) - asm volatile("griddepcontrol.wait;"); +__global__ void FusedAddRMSNormKernel(T* __restrict__ input, T* __restrict__ residual, + T* __restrict__ weight, const uint32_t d, + const uint32_t stride_input, const uint32_t stride_residual, + float weight_bias, float eps) { + const uint32_t bx = blockIdx.x; + const uint32_t tx = threadIdx.x, ty = threadIdx.y; + constexpr uint32_t warp_size = 32; + const uint32_t num_warps = blockDim.y; + const uint32_t thread_id = tx + ty * warp_size; + const uint32_t num_threads = num_warps * warp_size; + const uint32_t rounds = ceil_div(d, VEC_SIZE * num_threads); + extern __shared__ float smem[]; + float* smem_x = smem + ceil_div(num_warps, 4) * 4; + + float sum_sq = 0.f; +#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.wait;"); #endif - for (uint32_t i = 0; i < rounds; i++) { - vec_t input_vec; - input_vec.fill(0.f); - vec_t residual_vec; - residual_vec.fill(0.f); - vec_t x_vec; - x_vec.fill(0.f); - if ((i * num_threads + thread_id) * VEC_SIZE < d) { - input_vec.load(input + bx * stride_input + - i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); - residual_vec.load(residual + bx * stride_residual + - i * num_threads * VEC_SIZE + - thread_id * VEC_SIZE); - } -#pragma unroll - for (uint32_t j = 0; j < VEC_SIZE; j++) { - float x = float(input_vec[j]); - x += float(residual_vec[j]); - sum_sq += x * x; - residual_vec[j] = (T)x; - x_vec[j] = x; - } - if ((i * 
num_threads + thread_id) * VEC_SIZE < d) { - residual_vec.store(residual + bx * stride_residual + - i * num_threads * VEC_SIZE + - thread_id * VEC_SIZE); - x_vec.store(smem_x + i * num_threads * VEC_SIZE + + for (uint32_t i = 0; i < rounds; i++) { + vec_t input_vec; + input_vec.fill(0.f); + vec_t residual_vec; + residual_vec.fill(0.f); + vec_t x_vec; + x_vec.fill(0.f); + if ((i * num_threads + thread_id) * VEC_SIZE < d) { + input_vec.load(input + bx * stride_input + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); + residual_vec.load(residual + bx * stride_residual + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); - } } - - // first, warp reduce sum #pragma unroll - for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) { - sum_sq += math::shfl_xor_sync(sum_sq, offset); + for (uint32_t j = 0; j < VEC_SIZE; j++) { + float x = float(input_vec[j]); + x += float(residual_vec[j]); + sum_sq += x * x; + residual_vec[j] = (T)x; + x_vec[j] = x; + } + if ((i * num_threads + thread_id) * VEC_SIZE < d) { + residual_vec.store(residual + bx * stride_residual + i * num_threads * VEC_SIZE + + thread_id * VEC_SIZE); + x_vec.store(smem_x + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); } + } - smem[ty] = sum_sq; - __syncthreads(); - // then, cross warp reduce sum using only the first warp - if (ty == 0) { - sum_sq = (tx < num_warps) ? smem[tx] : 0.f; + // first, warp reduce sum #pragma unroll - for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) { - sum_sq += math::shfl_xor_sync(sum_sq, offset); - } - smem[0] = sum_sq; + for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) { + sum_sq += math::shfl_xor_sync(sum_sq, offset); + } + + smem[ty] = sum_sq; + __syncthreads(); + // then, cross warp reduce sum using only the first warp + if (ty == 0) { + sum_sq = (tx < num_warps) ? 
smem[tx] : 0.f; +#pragma unroll + for (uint32_t offset = warp_size / 2; offset > 0; offset /= 2) { + sum_sq += math::shfl_xor_sync(sum_sq, offset); + } + smem[0] = sum_sq; + } + __syncthreads(); + + float rms_rcp = math::rsqrt(smem[0] / float(d) + eps); + + for (uint32_t i = 0; i < rounds; i++) { + vec_t input_vec; + vec_t weight_vec; + vec_t x_vec; + input_vec.fill(0.f); + weight_vec.fill(0.f); + x_vec.fill(0.f); + if ((i * num_threads + thread_id) * VEC_SIZE < d) { + weight_vec.load(weight + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); + x_vec.load(smem_x + i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); } - __syncthreads(); - - float rms_rcp = math::rsqrt(smem[0] / float(d) + eps); - - for (uint32_t i = 0; i < rounds; i++) { - vec_t input_vec; - vec_t weight_vec; - vec_t x_vec; - input_vec.fill(0.f); - weight_vec.fill(0.f); - x_vec.fill(0.f); - if ((i * num_threads + thread_id) * VEC_SIZE < d) { - weight_vec.load(weight + i * num_threads * VEC_SIZE + - thread_id * VEC_SIZE); - x_vec.load(smem_x + i * num_threads * VEC_SIZE + - thread_id * VEC_SIZE); - } #pragma unroll - for (uint32_t j = 0; j < VEC_SIZE; j++) { - input_vec[j] = - x_vec[j] * rms_rcp * (weight_bias + float(weight_vec[j])); - } - if ((i * num_threads + thread_id) * VEC_SIZE < d) { - input_vec.store(input + bx * stride_input + - i * num_threads * VEC_SIZE + thread_id * VEC_SIZE); - } + for (uint32_t j = 0; j < VEC_SIZE; j++) { + input_vec[j] = x_vec[j] * rms_rcp * (weight_bias + float(weight_vec[j])); + } + if ((i * num_threads + thread_id) * VEC_SIZE < d) { + input_vec.store(input + bx * stride_input + i * num_threads * VEC_SIZE + + thread_id * VEC_SIZE); } -#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && \ - (__CUDA_ARCH__ >= 900)) - asm volatile("griddepcontrol.launch_dependents;"); + } +#if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) + asm volatile("griddepcontrol.launch_dependents;"); #endif } template -hipError_t FusedAddRMSNorm(T 
*input, - T *residual, - T *weight, - uint32_t batch_size, - uint32_t d, - uint32_t stride_input, - uint32_t stride_residual, - float eps = 1e-5, - bool enable_pdl = false, - hipStream_t stream = 0) -{ - const uint32_t vec_size = std::gcd(16 / sizeof(T), d); - - const uint32_t block_size = std::min(1024, d / vec_size); - const uint32_t num_warps = ceil_div(block_size, 32); - dim3 nblks(batch_size); - dim3 nthrs(32, num_warps); - const uint32_t smem_size = (ceil_div(num_warps, 4) * 4 + d) * sizeof(float); - float weight_bias = 0.f; - void *args[] = {&input, &residual, &weight, &d, - &stride_input, &stride_residual, &weight_bias, &eps}; - - DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { - auto kernel = FusedAddRMSNormKernel; - hipFuncSetAttribute( - (void*)kernel, hipFuncAttributeMaxDynamicSharedMemorySize, smem_size); - FusedAddRMSNormKernel<<>>( - input, residual, weight, d, stride_input, stride_residual, - weight_bias, eps); - }); - - return hipSuccess; +hipError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_size, uint32_t d, + uint32_t stride_input, uint32_t stride_residual, float eps = 1e-5, + bool enable_pdl = false, hipStream_t stream = 0) { + const uint32_t vec_size = std::gcd(16 / sizeof(T), d); + + const uint32_t block_size = std::min(1024, d / vec_size); + const uint32_t num_warps = ceil_div(block_size, 32); + dim3 nblks(batch_size); + dim3 nthrs(32, num_warps); + const uint32_t smem_size = (ceil_div(num_warps, 4) * 4 + d) * sizeof(float); + float weight_bias = 0.f; + void* args[] = {&input, &residual, &weight, &d, + &stride_input, &stride_residual, &weight_bias, &eps}; + + DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { + auto kernel = FusedAddRMSNormKernel; + hipFuncSetAttribute((void*)kernel, hipFuncAttributeMaxDynamicSharedMemorySize, smem_size); + FusedAddRMSNormKernel<<>>( + input, residual, weight, d, stride_input, stride_residual, weight_bias, eps); + }); + + return hipSuccess; } template -hipError_t GemmaRMSNorm(T *input, - T 
*weight, - T *output, - uint32_t batch_size, - uint32_t d, - uint32_t stride_input, - uint32_t stride_output, - float eps = 1e-5, - bool enable_pdl = false, - hipStream_t stream = 0) -{ - const uint32_t vec_size = std::gcd(16 / sizeof(T), d); - - const uint32_t block_size = std::min(1024, d / vec_size); - const uint32_t num_warps = ceil_div(block_size, 32); - dim3 nblks(batch_size); - dim3 nthrs(32, num_warps); - const uint32_t smem_size = num_warps * sizeof(float); - float weight_bias = 1.f; - - DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { - auto kernel = RMSNormKernel; - hipFuncSetAttribute( - (void*)kernel, hipFuncAttributeMaxDynamicSharedMemorySize, smem_size); - RMSNormKernel<<>>( - input, weight, output, d, stride_input, stride_output, weight_bias, - eps); - }); - return hipSuccess; +hipError_t GemmaRMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_t d, + uint32_t stride_input, uint32_t stride_output, float eps = 1e-5, + bool enable_pdl = false, hipStream_t stream = 0) { + const uint32_t vec_size = std::gcd(16 / sizeof(T), d); + + const uint32_t block_size = std::min(1024, d / vec_size); + const uint32_t num_warps = ceil_div(block_size, 32); + dim3 nblks(batch_size); + dim3 nthrs(32, num_warps); + const uint32_t smem_size = num_warps * sizeof(float); + float weight_bias = 1.f; + + DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { + auto kernel = RMSNormKernel; + hipFuncSetAttribute((void*)kernel, hipFuncAttributeMaxDynamicSharedMemorySize, smem_size); + RMSNormKernel<<>>( + input, weight, output, d, stride_input, stride_output, weight_bias, eps); + }); + return hipSuccess; } template -hipError_t GemmaFusedAddRMSNorm(T *input, - T *residual, - T *weight, - uint32_t batch_size, - uint32_t d, - uint32_t stride_input, - uint32_t stride_residual, - float eps = 1e-5, - bool enable_pdl = false, - hipStream_t stream = 0) -{ - const uint32_t vec_size = std::gcd(16 / sizeof(T), d); - - const uint32_t block_size = std::min(1024, d / vec_size); - const 
uint32_t num_warps = ceil_div(block_size, 32); - dim3 nblks(batch_size); - dim3 nthrs(32, num_warps); - // NOTE(Zihao): use ceil_div(num_warps, 4) * 4 for address alignment to 16 - // bytes - const uint32_t smem_size = (ceil_div(num_warps, 4) * 4 + d) * sizeof(float); - float weight_bias = 1.f; - - DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { - auto kernel = FusedAddRMSNormKernel; - hipFuncSetAttribute( - (void*)kernel, hipFuncAttributeMaxDynamicSharedMemorySize, smem_size); - FusedAddRMSNormKernel<<>>( - input, residual, weight, d, stride_input, stride_residual, - weight_bias, eps); - }); - - return hipSuccess; +hipError_t GemmaFusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_size, uint32_t d, + uint32_t stride_input, uint32_t stride_residual, float eps = 1e-5, + bool enable_pdl = false, hipStream_t stream = 0) { + const uint32_t vec_size = std::gcd(16 / sizeof(T), d); + + const uint32_t block_size = std::min(1024, d / vec_size); + const uint32_t num_warps = ceil_div(block_size, 32); + dim3 nblks(batch_size); + dim3 nthrs(32, num_warps); + // NOTE(Zihao): use ceil_div(num_warps, 4) * 4 for address alignment to 16 + // bytes + const uint32_t smem_size = (ceil_div(num_warps, 4) * 4 + d) * sizeof(float); + float weight_bias = 1.f; + + DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { + auto kernel = FusedAddRMSNormKernel; + hipFuncSetAttribute((void*)kernel, hipFuncAttributeMaxDynamicSharedMemorySize, smem_size); + FusedAddRMSNormKernel<<>>( + input, residual, weight, d, stride_input, stride_residual, weight_bias, eps); + }); + + return hipSuccess; } -} // namespace norm +} // namespace norm -} // namespace flashinfer +} // namespace flashinfer -#endif // FLASHINFER_NORM_CUH_ +#endif // FLASHINFER_NORM_CUH_ diff --git a/libflashinfer/include/flashinfer/hip/page.hip.h b/libflashinfer/include/flashinfer/hip/page.hip.h index 544990ebfe..6d9c656015 100644 --- a/libflashinfer/include/flashinfer/hip/page.hip.h +++ 
b/libflashinfer/include/flashinfer/hip/page.hip.h @@ -18,10 +18,10 @@ #include #include #include + #include -namespace flashinfer -{ +namespace flashinfer { /*! * \brief Paged key-value cache @@ -29,218 +29,186 @@ namespace flashinfer * \tparam DType The data type of the key-value cache * \tparam IdType The index data type of the kv-cache */ -template struct paged_kv_t -{ - uint_fastdiv page_size; - uint32_t num_heads; - uint32_t head_dim; - uint32_t batch_size; - uint32_t stride_page; - uint32_t stride_n; - uint32_t stride_h; - - // Internal layout: - // [max_num_pages, num_heads, page_size, head_dim] if layout == HND - // [max_num_pages, page_size, num_heads, head_dim] if layout == NHD - DType *k_data; - DType *v_data; - IdType *indices; - - // [batch_size + 1] The page indptr array, with the first element 0, the - // last element nnz_pages - IdType *indptr; - // [batch_size] The offset of the last page for each request in the batch - IdType *last_page_len; - // [batch_size] The start position of each request in the batch. - IdType *rope_pos_offset; - - /*! - * \brief Construct an empty paged key-value cache - */ - __host__ __device__ __forceinline__ paged_kv_t() - : num_heads(0), page_size(), head_dim(0), batch_size(0), stride_page(0), - stride_n(0), stride_h(0), k_data(nullptr), v_data(nullptr), - indices(nullptr), indptr(nullptr), last_page_len(nullptr), - rope_pos_offset(nullptr) - { - } - - /*! - * \brief Construct a paged key-value cache - * \param num_heads The number of heads - * \param page_size The size of each page - * \param head_dim The dimension of each head - * \param batch_size The batch size - * \param layout The layout of last 3 dimensions in KV-Cache. 
- * \param k_data The start pointer of key cache, k_cache should be - * contiguous - * \param v_data The start pointer of value cache, v_cache should be - * contiguous - * \param indices The page indices array - * \param indptr The page indptr array - * \param last_page_len The offset of the last page for each request in the - * batch - * \param rope_pos_offset The start position of each request in the batch. - */ - __host__ __forceinline__ paged_kv_t(uint32_t num_heads, - uint32_t page_size, - uint32_t head_dim, - uint32_t batch_size, - QKVLayout layout, - DType *k_data, - DType *v_data, - IdType *indices, - IdType *indptr, - IdType *last_page_len, - IdType *rope_pos_offset = nullptr) - : num_heads(num_heads), page_size(page_size), head_dim(head_dim), - batch_size(batch_size), indices(indices), indptr(indptr), - last_page_len(last_page_len), rope_pos_offset(rope_pos_offset) - { - stride_page = num_heads * page_size * head_dim; - this->k_data = k_data; - this->v_data = v_data; - stride_n = layout == QKVLayout::kHND ? head_dim : num_heads * head_dim; - stride_h = layout == QKVLayout::kHND ? page_size * head_dim : head_dim; - } - - /*! - * \brief Construct a paged key-value cache with custom kv-cache strides - * \param num_heads The number of heads - * \param page_size The size of each page - * \param head_dim The dimension of each head - * \param batch_size The batch size - * \param layout The layout of last 3 dimensions in KV-Cache. - * \param k_data The start pointer of key cache, k_cache doesn't have to be - * contiguous - * \param v_data The start pointer of value cache, v_cache doesn't have to - * be contiguous - * \param kv_strides custom strides of each dimensions of k_data and v_data - * \param indices The page indices array - * \param indptr The page indptr array - * \param last_page_len The offset of the last page for each request in the - * batch - * \param rope_pos_offset The start position of each request in the batch. 
- */ - __host__ __forceinline__ paged_kv_t(uint32_t num_heads, - uint32_t page_size, - uint32_t head_dim, - uint32_t batch_size, - QKVLayout layout, - DType *k_data, - DType *v_data, - const int64_t *kv_strides, - IdType *indices, - IdType *indptr, - IdType *last_page_len, - IdType *rope_pos_offset = nullptr) - : num_heads(num_heads), page_size(page_size), head_dim(head_dim), - batch_size(batch_size), indices(indices), indptr(indptr), - last_page_len(last_page_len), rope_pos_offset(rope_pos_offset) - { - stride_page = kv_strides[0]; - this->k_data = k_data; - this->v_data = v_data; - stride_n = layout == QKVLayout::kHND ? kv_strides[2] : kv_strides[1]; - stride_h = layout == QKVLayout::kHND ? kv_strides[1] : kv_strides[2]; - } - - __host__ __device__ __forceinline__ uint32_t - get_length(uint32_t batch_idx) const - { - if (indptr[batch_idx + 1] == indptr[batch_idx]) { - return 0; - } - return (indptr[batch_idx + 1] - indptr[batch_idx] - 1) * page_size + - last_page_len[batch_idx]; - } - - /*! - * \brief Compute the offset of element in the allocated buffer. - * \param page_idx The page index - * \param head_idx The head index - * \param entry_idx The page entry index - * \param feat_idx The feature index - */ - __host__ __device__ __forceinline__ size_t - get_elem_offset(size_t page_idx, - size_t head_idx, - size_t entry_idx, - size_t feat_idx) const - { - return page_idx * stride_page + head_idx * stride_h + - entry_idx * stride_n + feat_idx; - } - - /*! - * \brief Compute the offset of element inside the page. 
- * \param head_idx The head index - * \param entry_idx The page entry index - * \param feat_idx The feature index - */ - __host__ __device__ __forceinline__ size_t - get_elem_offset_in_page(size_t head_idx, - size_t entry_idx, - size_t feat_idx) const - { - return head_idx * stride_h + entry_idx * stride_n + feat_idx; - } - - __device__ __forceinline__ DType *get_k_ptr(IdType page_iter, - uint32_t head_idx, - uint32_t entry_idx, - uint32_t feat_idx) const - { - return k_data + get_elem_offset(__ldg(indices + page_iter), head_idx, - entry_idx, feat_idx); - } - - __device__ __forceinline__ size_t - protective_get_kv_offset(IdType page_iter, - uint32_t head_idx, - uint32_t entry_idx, - uint32_t feat_idx, - IdType last_indptr) const - { - if (page_iter < last_indptr) { - return get_elem_offset(__ldg(indices + page_iter), head_idx, - entry_idx, feat_idx); - } - else { - return 0; - } - } - - __device__ __forceinline__ DType * - protective_get_k_ptr(IdType page_iter, - uint32_t head_idx, - uint32_t entry_idx, - uint32_t feat_idx, - IdType last_indptr) const - { - return k_data + protective_get_kv_offset(page_iter, head_idx, entry_idx, - feat_idx, last_indptr); - } - - __device__ __forceinline__ DType *get_v_ptr(IdType page_iter, - uint32_t head_idx, - uint32_t entry_idx, - uint32_t feat_idx) const - { - return v_data + get_elem_offset(__ldg(indices + page_iter), head_idx, - entry_idx, feat_idx); +template +struct paged_kv_t { + uint_fastdiv page_size; + uint32_t num_heads; + uint32_t head_dim; + uint32_t batch_size; + uint32_t stride_page; + uint32_t stride_n; + uint32_t stride_h; + + // Internal layout: + // [max_num_pages, num_heads, page_size, head_dim] if layout == HND + // [max_num_pages, page_size, num_heads, head_dim] if layout == NHD + DType* k_data; + DType* v_data; + IdType* indices; + + // [batch_size + 1] The page indptr array, with the first element 0, the + // last element nnz_pages + IdType* indptr; + // [batch_size] The offset of the last page for each 
request in the batch + IdType* last_page_len; + // [batch_size] The start position of each request in the batch. + IdType* rope_pos_offset; + + /*! + * \brief Construct an empty paged key-value cache + */ + __host__ __device__ __forceinline__ paged_kv_t() + : num_heads(0), + page_size(), + head_dim(0), + batch_size(0), + stride_page(0), + stride_n(0), + stride_h(0), + k_data(nullptr), + v_data(nullptr), + indices(nullptr), + indptr(nullptr), + last_page_len(nullptr), + rope_pos_offset(nullptr) {} + + /*! + * \brief Construct a paged key-value cache + * \param num_heads The number of heads + * \param page_size The size of each page + * \param head_dim The dimension of each head + * \param batch_size The batch size + * \param layout The layout of last 3 dimensions in KV-Cache. + * \param k_data The start pointer of key cache, k_cache should be + * contiguous + * \param v_data The start pointer of value cache, v_cache should be + * contiguous + * \param indices The page indices array + * \param indptr The page indptr array + * \param last_page_len The offset of the last page for each request in the + * batch + * \param rope_pos_offset The start position of each request in the batch. + */ + __host__ __forceinline__ paged_kv_t(uint32_t num_heads, uint32_t page_size, uint32_t head_dim, + uint32_t batch_size, QKVLayout layout, DType* k_data, + DType* v_data, IdType* indices, IdType* indptr, + IdType* last_page_len, IdType* rope_pos_offset = nullptr) + : num_heads(num_heads), + page_size(page_size), + head_dim(head_dim), + batch_size(batch_size), + indices(indices), + indptr(indptr), + last_page_len(last_page_len), + rope_pos_offset(rope_pos_offset) { + stride_page = num_heads * page_size * head_dim; + this->k_data = k_data; + this->v_data = v_data; + stride_n = layout == QKVLayout::kHND ? head_dim : num_heads * head_dim; + stride_h = layout == QKVLayout::kHND ? page_size * head_dim : head_dim; + } + + /*! 
+ * \brief Construct a paged key-value cache with custom kv-cache strides + * \param num_heads The number of heads + * \param page_size The size of each page + * \param head_dim The dimension of each head + * \param batch_size The batch size + * \param layout The layout of last 3 dimensions in KV-Cache. + * \param k_data The start pointer of key cache, k_cache doesn't have to be + * contiguous + * \param v_data The start pointer of value cache, v_cache doesn't have to + * be contiguous + * \param kv_strides custom strides of each dimensions of k_data and v_data + * \param indices The page indices array + * \param indptr The page indptr array + * \param last_page_len The offset of the last page for each request in the + * batch + * \param rope_pos_offset The start position of each request in the batch. + */ + __host__ __forceinline__ paged_kv_t(uint32_t num_heads, uint32_t page_size, uint32_t head_dim, + uint32_t batch_size, QKVLayout layout, DType* k_data, + DType* v_data, const int64_t* kv_strides, IdType* indices, + IdType* indptr, IdType* last_page_len, + IdType* rope_pos_offset = nullptr) + : num_heads(num_heads), + page_size(page_size), + head_dim(head_dim), + batch_size(batch_size), + indices(indices), + indptr(indptr), + last_page_len(last_page_len), + rope_pos_offset(rope_pos_offset) { + stride_page = kv_strides[0]; + this->k_data = k_data; + this->v_data = v_data; + stride_n = layout == QKVLayout::kHND ? kv_strides[2] : kv_strides[1]; + stride_h = layout == QKVLayout::kHND ? 
kv_strides[1] : kv_strides[2]; + } + + __host__ __device__ __forceinline__ uint32_t get_length(uint32_t batch_idx) const { + if (indptr[batch_idx + 1] == indptr[batch_idx]) { + return 0; } - - __device__ __forceinline__ DType * - protective_get_v_ptr(IdType page_iter, - uint32_t head_idx, - uint32_t entry_idx, - uint32_t feat_idx, - IdType last_indptr) const - { - return v_data + protective_get_kv_offset(page_iter, head_idx, entry_idx, - feat_idx, last_indptr); + return (indptr[batch_idx + 1] - indptr[batch_idx] - 1) * page_size + last_page_len[batch_idx]; + } + + /*! + * \brief Compute the offset of element in the allocated buffer. + * \param page_idx The page index + * \param head_idx The head index + * \param entry_idx The page entry index + * \param feat_idx The feature index + */ + __host__ __device__ __forceinline__ size_t get_elem_offset(size_t page_idx, size_t head_idx, + size_t entry_idx, + size_t feat_idx) const { + return page_idx * stride_page + head_idx * stride_h + entry_idx * stride_n + feat_idx; + } + + /*! + * \brief Compute the offset of element inside the page. 
+ * \param head_idx The head index + * \param entry_idx The page entry index + * \param feat_idx The feature index + */ + __host__ __device__ __forceinline__ size_t get_elem_offset_in_page(size_t head_idx, + size_t entry_idx, + size_t feat_idx) const { + return head_idx * stride_h + entry_idx * stride_n + feat_idx; + } + + __device__ __forceinline__ DType* get_k_ptr(IdType page_iter, uint32_t head_idx, + uint32_t entry_idx, uint32_t feat_idx) const { + return k_data + get_elem_offset(__ldg(indices + page_iter), head_idx, entry_idx, feat_idx); + } + + __device__ __forceinline__ size_t protective_get_kv_offset(IdType page_iter, uint32_t head_idx, + uint32_t entry_idx, uint32_t feat_idx, + IdType last_indptr) const { + if (page_iter < last_indptr) { + return get_elem_offset(__ldg(indices + page_iter), head_idx, entry_idx, feat_idx); + } else { + return 0; } + } + + __device__ __forceinline__ DType* protective_get_k_ptr(IdType page_iter, uint32_t head_idx, + uint32_t entry_idx, uint32_t feat_idx, + IdType last_indptr) const { + return k_data + protective_get_kv_offset(page_iter, head_idx, entry_idx, feat_idx, last_indptr); + } + + __device__ __forceinline__ DType* get_v_ptr(IdType page_iter, uint32_t head_idx, + uint32_t entry_idx, uint32_t feat_idx) const { + return v_data + get_elem_offset(__ldg(indices + page_iter), head_idx, entry_idx, feat_idx); + } + + __device__ __forceinline__ DType* protective_get_v_ptr(IdType page_iter, uint32_t head_idx, + uint32_t entry_idx, uint32_t feat_idx, + IdType last_indptr) const { + return v_data + protective_get_kv_offset(page_iter, head_idx, entry_idx, feat_idx, last_indptr); + } }; /*! 
@@ -255,36 +223,27 @@ template struct paged_kv_t * \param value The value to be appended */ template -__global__ void -AppendPagedKVCacheDecodeKernel(paged_kv_t paged_kv, - DType *__restrict__ key, - DType *__restrict__ value) -{ - uint32_t tx = threadIdx.x, ty = threadIdx.y; - uint32_t num_heads = paged_kv.num_heads; - uint32_t batch_idx = blockIdx.x; - uint32_t head_idx = ty; - - uint32_t seq_len = - (paged_kv.indptr[batch_idx + 1] - paged_kv.indptr[batch_idx] - 1) * - paged_kv.page_size + - paged_kv.last_page_len[batch_idx]; - - uint32_t page_iter = - paged_kv.indptr[batch_idx] + (seq_len - 1) / paged_kv.page_size; - uint32_t entry_idx = (seq_len - 1) % paged_kv.page_size; - - DType *k_ptr = - paged_kv.get_k_ptr(page_iter, head_idx, entry_idx, tx * vec_size); - DType *v_ptr = - paged_kv.get_v_ptr(page_iter, head_idx, entry_idx, tx * vec_size); - vec_t::memcpy( - k_ptr, - key + (batch_idx * num_heads + head_idx) * head_dim + tx * vec_size); - - vec_t::memcpy( - v_ptr, - value + (batch_idx * num_heads + head_idx) * head_dim + tx * vec_size); +__global__ void AppendPagedKVCacheDecodeKernel(paged_kv_t paged_kv, + DType* __restrict__ key, DType* __restrict__ value) { + uint32_t tx = threadIdx.x, ty = threadIdx.y; + uint32_t num_heads = paged_kv.num_heads; + uint32_t batch_idx = blockIdx.x; + uint32_t head_idx = ty; + + uint32_t seq_len = + (paged_kv.indptr[batch_idx + 1] - paged_kv.indptr[batch_idx] - 1) * paged_kv.page_size + + paged_kv.last_page_len[batch_idx]; + + uint32_t page_iter = paged_kv.indptr[batch_idx] + (seq_len - 1) / paged_kv.page_size; + uint32_t entry_idx = (seq_len - 1) % paged_kv.page_size; + + DType* k_ptr = paged_kv.get_k_ptr(page_iter, head_idx, entry_idx, tx * vec_size); + DType* v_ptr = paged_kv.get_v_ptr(page_iter, head_idx, entry_idx, tx * vec_size); + vec_t::memcpy( + k_ptr, key + (batch_idx * num_heads + head_idx) * head_dim + tx * vec_size); + + vec_t::memcpy( + v_ptr, value + (batch_idx * num_heads + head_idx) * head_dim + tx * 
vec_size); } /*! @@ -302,113 +261,88 @@ AppendPagedKVCacheDecodeKernel(paged_kv_t paged_kv, */ template __global__ void AppendPagedKVCacheKernel(paged_kv_t paged_kv, - DType *__restrict__ append_key, - DType *__restrict__ append_value, - IdType *__restrict__ batch_indices, - IdType *__restrict__ positions, - uint32_t nnz, - size_t append_k_stride_n, - size_t append_k_stride_h, - size_t append_v_stride_n, - size_t append_v_stride_h) -{ - uint32_t tx = threadIdx.x, ty = threadIdx.y; - uint32_t num_heads = paged_kv.num_heads; - uint32_t head_idx = ty; - uint32_t cta_id = blockIdx.x; - uint32_t num_ctas = gridDim.x; + DType* __restrict__ append_key, + DType* __restrict__ append_value, + IdType* __restrict__ batch_indices, + IdType* __restrict__ positions, uint32_t nnz, + size_t append_k_stride_n, size_t append_k_stride_h, + size_t append_v_stride_n, size_t append_v_stride_h) { + uint32_t tx = threadIdx.x, ty = threadIdx.y; + uint32_t num_heads = paged_kv.num_heads; + uint32_t head_idx = ty; + uint32_t cta_id = blockIdx.x; + uint32_t num_ctas = gridDim.x; #pragma unroll 4 - for (uint32_t i = cta_id; i < nnz; i += num_ctas) { - uint32_t page_iter, entry_idx; - paged_kv.page_size.divmod(paged_kv.indptr[batch_indices[i]] * - paged_kv.page_size + - positions[i], - page_iter, entry_idx); - DType *k_ptr = - paged_kv.get_k_ptr(page_iter, head_idx, entry_idx, tx * vec_size); - DType *v_ptr = - paged_kv.get_v_ptr(page_iter, head_idx, entry_idx, tx * vec_size); - vec_t::memcpy( - k_ptr, append_key + i * append_k_stride_n + - head_idx * append_k_stride_h + tx * vec_size); - vec_t::memcpy( - v_ptr, append_value + i * append_v_stride_n + - head_idx * append_v_stride_h + tx * vec_size); - } + for (uint32_t i = cta_id; i < nnz; i += num_ctas) { + uint32_t page_iter, entry_idx; + paged_kv.page_size.divmod(paged_kv.indptr[batch_indices[i]] * paged_kv.page_size + positions[i], + page_iter, entry_idx); + DType* k_ptr = paged_kv.get_k_ptr(page_iter, head_idx, entry_idx, tx * vec_size); + 
DType* v_ptr = paged_kv.get_v_ptr(page_iter, head_idx, entry_idx, tx * vec_size); + vec_t::memcpy( + k_ptr, append_key + i * append_k_stride_n + head_idx * append_k_stride_h + tx * vec_size); + vec_t::memcpy( + v_ptr, append_value + i * append_v_stride_n + head_idx * append_v_stride_h + tx * vec_size); + } } template __global__ void BlockSparseIndicesToVectorSparseOffsetsKernel( - IdType *__restrict__ block_sparse_indices, - IdType *__restrict__ block_sparse_indptr, - IdType *__restrict__ vector_sparse_offsets, - IdType *__restrict__ vector_sparse_indptr, - IdType *__restrict__ kv_lens, - const uint32_t stride_block, - const uint32_t stride_n, - const uint32_t batch_size, - const uint_fastdiv block_size) -{ + IdType* __restrict__ block_sparse_indices, IdType* __restrict__ block_sparse_indptr, + IdType* __restrict__ vector_sparse_offsets, IdType* __restrict__ vector_sparse_indptr, + IdType* __restrict__ kv_lens, const uint32_t stride_block, const uint32_t stride_n, + const uint32_t batch_size, const uint_fastdiv block_size) { #pragma unroll 1 - for (int b = blockIdx.x; b < batch_size; ++b) { + for (int b = blockIdx.x; b < batch_size; ++b) { #pragma unroll 2 - for (int pos = threadIdx.x; pos < kv_lens[b]; pos += blockDim.x) { - uint32_t q, r; - block_size.divmod(pos, q, r); - vector_sparse_offsets[vector_sparse_indptr[b] + pos] = - block_sparse_indices[block_sparse_indptr[b] + q] * - stride_block + - r * stride_n; - } + for (int pos = threadIdx.x; pos < kv_lens[b]; pos += blockDim.x) { + uint32_t q, r; + block_size.divmod(pos, q, r); + vector_sparse_offsets[vector_sparse_indptr[b] + pos] = + block_sparse_indices[block_sparse_indptr[b] + q] * stride_block + r * stride_n; } + } } template -hipError_t BlockSparseIndicesToVectorSparseOffset(IdType *block_sparse_indices, - IdType *block_sparse_indptr, - IdType *vector_sparse_offsets, - IdType *vector_sparse_indptr, - IdType *kv_lens, - const int64_t stride_block, - const int64_t stride_n, - const int64_t batch_size, - 
const uint32_t block_size, - hipStream_t stream = nullptr) -{ - int dev_id = 0; - int num_sms = 0; - FLASHINFER_CUDA_CALL(hipGetDevice(&dev_id)); - FLASHINFER_CUDA_CALL(hipDeviceGetAttribute( - &num_sms, hipDeviceAttributeMultiprocessorCount, dev_id)); - - uint32_t num_threads = 512; - - uint_fastdiv block_size_fastdiv(block_size); - - auto kernel = BlockSparseIndicesToVectorSparseOffsetsKernel; - void *args[] = {(void *)&block_sparse_indices, - (void *)&block_sparse_indptr, - (void *)&vector_sparse_offsets, - (void *)&vector_sparse_indptr, - (void *)&kv_lens, - (void *)&stride_block, - (void *)&stride_n, - (void *)&batch_size, - (void *)&block_size_fastdiv}; - - // FLASHINFER_CUDA_CALL(cudaLaunchKernel((void *)kernel, num_sms, - // num_threads, - // args, 0, stream)); - - hipLaunchKernel((void *)kernel, dim3(num_sms), dim3(num_threads), args, 0, - stream); - // BlockSparseIndicesToVectorSparseOffsetsKernel<<>>( - // block_sparse_indices, block_sparse_indptr, vector_sparse_offsets, - // vector_sparse_indptr, kv_lens, stride_block, stride_n, batch_size, - // block_size); - return hipSuccess; +hipError_t BlockSparseIndicesToVectorSparseOffset( + IdType* block_sparse_indices, IdType* block_sparse_indptr, IdType* vector_sparse_offsets, + IdType* vector_sparse_indptr, IdType* kv_lens, const int64_t stride_block, + const int64_t stride_n, const int64_t batch_size, const uint32_t block_size, + hipStream_t stream = nullptr) { + int dev_id = 0; + int num_sms = 0; + FLASHINFER_CUDA_CALL(hipGetDevice(&dev_id)); + FLASHINFER_CUDA_CALL( + hipDeviceGetAttribute(&num_sms, hipDeviceAttributeMultiprocessorCount, dev_id)); + + uint32_t num_threads = 512; + + uint_fastdiv block_size_fastdiv(block_size); + + auto kernel = BlockSparseIndicesToVectorSparseOffsetsKernel; + void* args[] = {(void*)&block_sparse_indices, + (void*)&block_sparse_indptr, + (void*)&vector_sparse_offsets, + (void*)&vector_sparse_indptr, + (void*)&kv_lens, + (void*)&stride_block, + (void*)&stride_n, + 
(void*)&batch_size, + (void*)&block_size_fastdiv}; + + // FLASHINFER_CUDA_CALL(cudaLaunchKernel((void *)kernel, num_sms, + // num_threads, + // args, 0, stream)); + + hipLaunchKernel((void*)kernel, dim3(num_sms), dim3(num_threads), args, 0, stream); + // BlockSparseIndicesToVectorSparseOffsetsKernel<<>>( + // block_sparse_indices, block_sparse_indptr, vector_sparse_offsets, + // vector_sparse_indptr, kv_lens, stride_block, stride_n, batch_size, + // block_size); + return hipSuccess; } /*! @@ -423,29 +357,24 @@ hipError_t BlockSparseIndicesToVectorSparseOffset(IdType *block_sparse_indices, * \return status Indicates whether CUDA calls are successful */ template -hipError_t AppendPagedKVCacheDecode(paged_kv_t paged_kv, - DType *key, - DType *value, - hipStream_t stream = nullptr) -{ - uint32_t head_dim = paged_kv.head_dim; - uint32_t batch_size = paged_kv.batch_size; - uint32_t num_heads = paged_kv.num_heads; - DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { - constexpr uint32_t vec_size = - std::max(16 / sizeof(DType), HEAD_DIM / 32); - uint32_t bdx = HEAD_DIM / vec_size; - uint32_t bdy = num_heads; - // NOTE(Zihao): could be slow for small batch size, will optimize later - dim3 nblks(batch_size); - dim3 nthrs(bdx, bdy); - auto kernel = - AppendPagedKVCacheDecodeKernel; - void *args[] = {(void *)&paged_kv, (void *)&key, (void *)&value}; - - hipLaunchKernel((void *)kernel, nblks, nthrs, args, 0, stream); - }); - return hipSuccess; +hipError_t AppendPagedKVCacheDecode(paged_kv_t paged_kv, DType* key, DType* value, + hipStream_t stream = nullptr) { + uint32_t head_dim = paged_kv.head_dim; + uint32_t batch_size = paged_kv.batch_size; + uint32_t num_heads = paged_kv.num_heads; + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32); + uint32_t bdx = HEAD_DIM / vec_size; + uint32_t bdy = num_heads; + // NOTE(Zihao): could be slow for small batch size, will optimize later + dim3 nblks(batch_size); + dim3 nthrs(bdx, 
bdy); + auto kernel = AppendPagedKVCacheDecodeKernel; + void* args[] = {(void*)&paged_kv, (void*)&key, (void*)&value}; + + hipLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream); + }); + return hipSuccess; } /*! @@ -461,336 +390,287 @@ hipError_t AppendPagedKVCacheDecode(paged_kv_t paged_kv, * \return status Indicates whether CUDA calls are successful */ template -hipError_t AppendPagedKVCache(paged_kv_t paged_kv, - DType *append_key, - DType *append_value, - IdType *batch_indices, - IdType *positions, - uint32_t nnz, - size_t append_k_stride_n, - size_t append_k_stride_h, - size_t append_v_stride_n, - size_t append_v_stride_h, - hipStream_t stream = nullptr) -{ - uint32_t head_dim = paged_kv.head_dim; - uint32_t num_heads = paged_kv.num_heads; - int dev_id = 0; - int num_sms = 0; - int num_blocks_per_sm = 0; - FLASHINFER_CUDA_CALL(hipGetDevice(&dev_id)); - FLASHINFER_CUDA_CALL(hipDeviceGetAttribute( - &num_sms, hipDeviceAttributeMultiprocessorCount, dev_id)); - - DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { - constexpr uint32_t vec_size = - std::max(16 / sizeof(DType), HEAD_DIM / 32); - uint32_t bdx = HEAD_DIM / vec_size; - uint32_t bdy = num_heads; - uint32_t num_threads = bdx * bdy; - uint32_t smem_size = 0; - auto kernel = - AppendPagedKVCacheKernel; - assert(num_sms > 0); - FLASHINFER_CUDA_CALL(hipOccupancyMaxActiveBlocksPerMultiprocessor( - &num_blocks_per_sm, kernel, num_threads, smem_size)); - - num_blocks_per_sm = min(num_blocks_per_sm, ceil_div(int(nnz), num_sms)); - dim3 nblks(num_blocks_per_sm * num_sms); - dim3 nthrs(bdx, bdy); - - void *args[] = {(void *)&paged_kv, (void *)&append_key, - (void *)&append_value, (void *)&batch_indices, - (void *)&positions, (void *)&nnz, - (void *)&append_k_stride_n, (void *)&append_k_stride_h, - (void *)&append_v_stride_n, (void *)&append_v_stride_h}; - // FLASHINFER_CUDA_CALL( - // cudaLaunchKernel((void *)kernel, nblks, nthrs, args, 0, stream)); - hipLaunchKernel((void *)kernel, nblks, nthrs, args, 0, stream); - 
}); - return hipSuccess; -} - -template struct paged_kv_mla_t -{ - uint_fastdiv page_size; - uint32_t head_dim_ckv; - uint32_t head_dim_kpe; - uint32_t batch_size; - uint32_t stride_page_ckv; - uint32_t stride_page_kpe; - uint32_t stride_n_ckv; - uint32_t stride_n_kpe; - - // Internal layout: - // [max_num_pages, page_size, head_dim] - DType *ckv_data; - DType *kpe_data; - IdType *indices; - - // [batch_size + 1] The page indptr array, with the first element 0, the - // last element nnz_pages - IdType *indptr; - // [batch_size] The offset of the last page for each request in the batch - IdType *last_page_len; - // [batch_size] The start position of each request in the batch. - IdType *rope_pos_offset; - - /*! - * \brief Construct an empty paged key-value cache - */ - __host__ __device__ __forceinline__ paged_kv_mla_t() - : head_dim_ckv(0), head_dim_kpe(0), batch_size(0), stride_page_ckv(0), - stride_page_kpe(0), stride_n_ckv(0), stride_n_kpe(0), - ckv_data(nullptr), kpe_data(nullptr), indices(nullptr), - indptr(nullptr), last_page_len(nullptr), rope_pos_offset(nullptr) - { - } - - /*! - * \brief Construct a paged mla kv cache - * \param page_size The size of each page - * \param head_dim_compressed_kv The dimension of compressed-kv - * \param head_dim_kpe The dimension of k-pe - * \param batch_size The batch size - * \param compressed_kv_data The start pointer of compressed-kv cache, cache - * should be contiguous - * \param kpe_data The start pointer of k-pe cache, cache should be - * contiguous - * \param indices The page indices array - * \param indptr The page indptr array - * \param last_page_len The offset of the last page for each request in the - * batch - * \param rope_pos_offset The start position of each request in the batch. 
- */ - __host__ __forceinline__ paged_kv_mla_t(uint32_t page_size, - uint32_t head_dim_compressed_kv, - uint32_t head_dim_kpe, - uint32_t batch_size, - DType *compressed_kv_data, - DType *kpe_data, - IdType *indices, - IdType *indptr, - IdType *last_page_len, - IdType *rope_pos_offset = nullptr) - : page_size(page_size), head_dim_ckv(head_dim_compressed_kv), - head_dim_kpe(head_dim_kpe), batch_size(batch_size), - ckv_data(compressed_kv_data), kpe_data(kpe_data), indices(indices), - indptr(indptr), last_page_len(last_page_len), - rope_pos_offset(rope_pos_offset) - { - stride_page_ckv = page_size * head_dim_ckv; - stride_n_ckv = head_dim_ckv; - stride_page_kpe = page_size * head_dim_kpe; - stride_n_kpe = head_dim_kpe; - } - - /*! - * \brief Construct a paged key-value cache with custom kv-cache strides - * \param page_size The size of each page - * \param head_dim_compressed_kv The dimension of compressed-kv - * \param head_dim_kpe The dimension of k-pe - * \param batch_size The batch size - * \param compressed_kv_data The start pointer of compressed-kv cache, cache - * should be contiguous - * \param compressed_kv_strides custom strides of each dimensions of - * compressed-kv cache - * \param kpe_data The start pointer of k-pe cache, cache should be - * contiguous - * \param kpe_strides custom strides of each dimensions of k-pe cache - * \param indices The page indices array - * \param indptr The page indptr array - * \param last_page_len The offset of the last page for each request in the - * batch - * \param rope_pos_offset The start position of each request in the batch. 
- */ - __host__ __forceinline__ - paged_kv_mla_t(uint32_t page_size, - uint32_t head_dim_compressed_kv, - uint32_t head_dim_kpe, - uint32_t batch_size, - DType *compressed_kv_data, - const int64_t *compressed_kv_strides, - DType *kpe_data, - const int64_t *kpe_strides, - IdType *indices, - IdType *indptr, - IdType *last_page_len, - IdType *rope_pos_offset = nullptr) - : page_size(page_size), head_dim_ckv(head_dim_compressed_kv), - head_dim_kpe(head_dim_kpe), batch_size(batch_size), - ckv_data(compressed_kv_data), kpe_data(kpe_data), indices(indices), - indptr(indptr), last_page_len(last_page_len), - rope_pos_offset(rope_pos_offset) - { - stride_page_ckv = compressed_kv_strides[0]; - stride_n_ckv = compressed_kv_strides[1]; - stride_page_kpe = kpe_strides[0]; - stride_n_kpe = kpe_strides[1]; - } +hipError_t AppendPagedKVCache(paged_kv_t paged_kv, DType* append_key, + DType* append_value, IdType* batch_indices, IdType* positions, + uint32_t nnz, size_t append_k_stride_n, size_t append_k_stride_h, + size_t append_v_stride_n, size_t append_v_stride_h, + hipStream_t stream = nullptr) { + uint32_t head_dim = paged_kv.head_dim; + uint32_t num_heads = paged_kv.num_heads; + int dev_id = 0; + int num_sms = 0; + int num_blocks_per_sm = 0; + FLASHINFER_CUDA_CALL(hipGetDevice(&dev_id)); + FLASHINFER_CUDA_CALL( + hipDeviceGetAttribute(&num_sms, hipDeviceAttributeMultiprocessorCount, dev_id)); + + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32); + uint32_t bdx = HEAD_DIM / vec_size; + uint32_t bdy = num_heads; + uint32_t num_threads = bdx * bdy; + uint32_t smem_size = 0; + auto kernel = AppendPagedKVCacheKernel; + assert(num_sms > 0); + FLASHINFER_CUDA_CALL(hipOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel, + num_threads, smem_size)); - __host__ __device__ __forceinline__ uint32_t - get_length(uint32_t batch_idx) const - { - if (indptr[batch_idx + 1] == indptr[batch_idx]) { - return 0; - } - 
return (indptr[batch_idx + 1] - indptr[batch_idx] - 1) * page_size + - last_page_len[batch_idx]; - } + num_blocks_per_sm = min(num_blocks_per_sm, ceil_div(int(nnz), num_sms)); + dim3 nblks(num_blocks_per_sm * num_sms); + dim3 nthrs(bdx, bdy); - __host__ __device__ __forceinline__ size_t - get_elem_offset_ckv(size_t page_idx, - size_t entry_idx, - size_t feat_idx) const - { - return page_idx * stride_page_ckv + entry_idx * stride_n_ckv + feat_idx; - } + void* args[] = {(void*)&paged_kv, (void*)&append_key, (void*)&append_value, + (void*)&batch_indices, (void*)&positions, (void*)&nnz, + (void*)&append_k_stride_n, (void*)&append_k_stride_h, (void*)&append_v_stride_n, + (void*)&append_v_stride_h}; + // FLASHINFER_CUDA_CALL( + // cudaLaunchKernel((void *)kernel, nblks, nthrs, args, 0, stream)); + hipLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream); + }); + return hipSuccess; +} - __device__ __forceinline__ size_t - protective_get_offset_ckv(IdType page_iter, - uint32_t entry_idx, - uint32_t feat_idx, - IdType last_indptr) const - { - if (page_iter < last_indptr) { - return get_elem_offset_ckv(__ldg(indices + page_iter), entry_idx, - feat_idx); - } - else { - return 0; - } +template +struct paged_kv_mla_t { + uint_fastdiv page_size; + uint32_t head_dim_ckv; + uint32_t head_dim_kpe; + uint32_t batch_size; + uint32_t stride_page_ckv; + uint32_t stride_page_kpe; + uint32_t stride_n_ckv; + uint32_t stride_n_kpe; + + // Internal layout: + // [max_num_pages, page_size, head_dim] + DType* ckv_data; + DType* kpe_data; + IdType* indices; + + // [batch_size + 1] The page indptr array, with the first element 0, the + // last element nnz_pages + IdType* indptr; + // [batch_size] The offset of the last page for each request in the batch + IdType* last_page_len; + // [batch_size] The start position of each request in the batch. + IdType* rope_pos_offset; + + /*! 
+ * \brief Construct an empty paged key-value cache + */ + __host__ __device__ __forceinline__ paged_kv_mla_t() + : head_dim_ckv(0), + head_dim_kpe(0), + batch_size(0), + stride_page_ckv(0), + stride_page_kpe(0), + stride_n_ckv(0), + stride_n_kpe(0), + ckv_data(nullptr), + kpe_data(nullptr), + indices(nullptr), + indptr(nullptr), + last_page_len(nullptr), + rope_pos_offset(nullptr) {} + + /*! + * \brief Construct a paged mla kv cache + * \param page_size The size of each page + * \param head_dim_compressed_kv The dimension of compressed-kv + * \param head_dim_kpe The dimension of k-pe + * \param batch_size The batch size + * \param compressed_kv_data The start pointer of compressed-kv cache, cache + * should be contiguous + * \param kpe_data The start pointer of k-pe cache, cache should be + * contiguous + * \param indices The page indices array + * \param indptr The page indptr array + * \param last_page_len The offset of the last page for each request in the + * batch + * \param rope_pos_offset The start position of each request in the batch. + */ + __host__ __forceinline__ paged_kv_mla_t(uint32_t page_size, uint32_t head_dim_compressed_kv, + uint32_t head_dim_kpe, uint32_t batch_size, + DType* compressed_kv_data, DType* kpe_data, + IdType* indices, IdType* indptr, IdType* last_page_len, + IdType* rope_pos_offset = nullptr) + : page_size(page_size), + head_dim_ckv(head_dim_compressed_kv), + head_dim_kpe(head_dim_kpe), + batch_size(batch_size), + ckv_data(compressed_kv_data), + kpe_data(kpe_data), + indices(indices), + indptr(indptr), + last_page_len(last_page_len), + rope_pos_offset(rope_pos_offset) { + stride_page_ckv = page_size * head_dim_ckv; + stride_n_ckv = head_dim_ckv; + stride_page_kpe = page_size * head_dim_kpe; + stride_n_kpe = head_dim_kpe; + } + + /*! 
+ * \brief Construct a paged key-value cache with custom kv-cache strides + * \param page_size The size of each page + * \param head_dim_compressed_kv The dimension of compressed-kv + * \param head_dim_kpe The dimension of k-pe + * \param batch_size The batch size + * \param compressed_kv_data The start pointer of compressed-kv cache, cache + * should be contiguous + * \param compressed_kv_strides custom strides of each dimensions of + * compressed-kv cache + * \param kpe_data The start pointer of k-pe cache, cache should be + * contiguous + * \param kpe_strides custom strides of each dimensions of k-pe cache + * \param indices The page indices array + * \param indptr The page indptr array + * \param last_page_len The offset of the last page for each request in the + * batch + * \param rope_pos_offset The start position of each request in the batch. + */ + __host__ __forceinline__ paged_kv_mla_t(uint32_t page_size, uint32_t head_dim_compressed_kv, + uint32_t head_dim_kpe, uint32_t batch_size, + DType* compressed_kv_data, + const int64_t* compressed_kv_strides, DType* kpe_data, + const int64_t* kpe_strides, IdType* indices, + IdType* indptr, IdType* last_page_len, + IdType* rope_pos_offset = nullptr) + : page_size(page_size), + head_dim_ckv(head_dim_compressed_kv), + head_dim_kpe(head_dim_kpe), + batch_size(batch_size), + ckv_data(compressed_kv_data), + kpe_data(kpe_data), + indices(indices), + indptr(indptr), + last_page_len(last_page_len), + rope_pos_offset(rope_pos_offset) { + stride_page_ckv = compressed_kv_strides[0]; + stride_n_ckv = compressed_kv_strides[1]; + stride_page_kpe = kpe_strides[0]; + stride_n_kpe = kpe_strides[1]; + } + + __host__ __device__ __forceinline__ uint32_t get_length(uint32_t batch_idx) const { + if (indptr[batch_idx + 1] == indptr[batch_idx]) { + return 0; } - - __host__ __device__ __forceinline__ size_t - get_elem_offset_kpe(size_t page_idx, - size_t entry_idx, - size_t feat_idx) const - { - return page_idx * stride_page_kpe + 
entry_idx * stride_n_kpe + feat_idx; + return (indptr[batch_idx + 1] - indptr[batch_idx] - 1) * page_size + last_page_len[batch_idx]; + } + + __host__ __device__ __forceinline__ size_t get_elem_offset_ckv(size_t page_idx, size_t entry_idx, + size_t feat_idx) const { + return page_idx * stride_page_ckv + entry_idx * stride_n_ckv + feat_idx; + } + + __device__ __forceinline__ size_t protective_get_offset_ckv(IdType page_iter, uint32_t entry_idx, + uint32_t feat_idx, + IdType last_indptr) const { + if (page_iter < last_indptr) { + return get_elem_offset_ckv(__ldg(indices + page_iter), entry_idx, feat_idx); + } else { + return 0; } - - __device__ __forceinline__ size_t - protective_get_offset_kpe(IdType page_iter, - uint32_t entry_idx, - uint32_t feat_idx, - IdType last_indptr) const - { - if (page_iter < last_indptr) { - return get_elem_offset_kpe(__ldg(indices + page_iter), entry_idx, - feat_idx); - } - else { - return 0; - } + } + + __host__ __device__ __forceinline__ size_t get_elem_offset_kpe(size_t page_idx, size_t entry_idx, + size_t feat_idx) const { + return page_idx * stride_page_kpe + entry_idx * stride_n_kpe + feat_idx; + } + + __device__ __forceinline__ size_t protective_get_offset_kpe(IdType page_iter, uint32_t entry_idx, + uint32_t feat_idx, + IdType last_indptr) const { + if (page_iter < last_indptr) { + return get_elem_offset_kpe(__ldg(indices + page_iter), entry_idx, feat_idx); + } else { + return 0; } + } - __device__ __forceinline__ DType * - get_ckv_ptr(size_t page_idx, size_t entry_idx, size_t feat_idx) const - { - return ckv_data + get_elem_offset_ckv(__ldg(indices + page_idx), - entry_idx, feat_idx); - } + __device__ __forceinline__ DType* get_ckv_ptr(size_t page_idx, size_t entry_idx, + size_t feat_idx) const { + return ckv_data + get_elem_offset_ckv(__ldg(indices + page_idx), entry_idx, feat_idx); + } - __device__ __forceinline__ DType * - get_kpe_ptr(size_t page_idx, size_t entry_idx, size_t feat_idx) const - { - return kpe_data + 
get_elem_offset_kpe(__ldg(indices + page_idx), - entry_idx, feat_idx); - } + __device__ __forceinline__ DType* get_kpe_ptr(size_t page_idx, size_t entry_idx, + size_t feat_idx) const { + return kpe_data + get_elem_offset_kpe(__ldg(indices + page_idx), entry_idx, feat_idx); + } }; -template -__global__ void -AppendPagedKVMlaCacheKernel(paged_kv_mla_t paged_kv_mla, - DType *__restrict__ append_ckv, - DType *__restrict__ append_kpe, - IdType *__restrict__ batch_indices, - IdType *__restrict__ positions, - uint32_t nnz, - size_t append_ckv_stride_n, - size_t append_kpe_stride_n) -{ - uint32_t tx = threadIdx.x; - uint32_t cta_id = blockIdx.x; - uint32_t num_ctas = gridDim.x; +__global__ void AppendPagedKVMlaCacheKernel(paged_kv_mla_t paged_kv_mla, + DType* __restrict__ append_ckv, + DType* __restrict__ append_kpe, + IdType* __restrict__ batch_indices, + IdType* __restrict__ positions, uint32_t nnz, + size_t append_ckv_stride_n, + size_t append_kpe_stride_n) { + uint32_t tx = threadIdx.x; + uint32_t cta_id = blockIdx.x; + uint32_t num_ctas = gridDim.x; #pragma unroll 4 - for (uint32_t i = cta_id; i < nnz; i += num_ctas) { - uint32_t page_iter, entry_idx; - paged_kv_mla.page_size.divmod(paged_kv_mla.indptr[batch_indices[i]] * - paged_kv_mla.page_size + - positions[i], - page_iter, entry_idx); - DType *ckv_ptr = - paged_kv_mla.get_ckv_ptr(page_iter, entry_idx, tx * vec_size); - vec_t::memcpy( - ckv_ptr, append_ckv + i * append_ckv_stride_n + tx * vec_size); - - if (tx * vec_size < head_dim_kpe) { - DType *kpe_ptr = - paged_kv_mla.get_kpe_ptr(page_iter, entry_idx, tx * vec_size); - vec_t::memcpy( - kpe_ptr, append_kpe + i * append_kpe_stride_n + tx * vec_size); - } + for (uint32_t i = cta_id; i < nnz; i += num_ctas) { + uint32_t page_iter, entry_idx; + paged_kv_mla.page_size.divmod( + paged_kv_mla.indptr[batch_indices[i]] * paged_kv_mla.page_size + positions[i], page_iter, + entry_idx); + DType* ckv_ptr = paged_kv_mla.get_ckv_ptr(page_iter, entry_idx, tx * vec_size); + 
vec_t::memcpy(ckv_ptr, append_ckv + i * append_ckv_stride_n + tx * vec_size); + + if (tx * vec_size < head_dim_kpe) { + DType* kpe_ptr = paged_kv_mla.get_kpe_ptr(page_iter, entry_idx, tx * vec_size); + vec_t::memcpy(kpe_ptr, append_kpe + i * append_kpe_stride_n + tx * vec_size); } + } } template -hipError_t AppendPagedKVMlaCache(paged_kv_mla_t paged_kv, - DType *append_ckv, - DType *append_kpe, - IdType *batch_indices, - IdType *positions, - uint32_t nnz, - size_t append_ckv_stride_n, - size_t append_kpe_stride_n, - hipStream_t stream = nullptr) -{ - int dev_id = 0; - int num_sms = 0; - int num_blocks_per_sm = 0; - FLASHINFER_CUDA_CALL(hipGetDevice(&dev_id)); - FLASHINFER_CUDA_CALL(hipDeviceGetAttribute( - &num_sms, hipDeviceAttributeMultiprocessorCount, dev_id)); - - uint32_t head_dim_ckv = paged_kv.head_dim_ckv; - uint32_t head_dim_kpe = paged_kv.head_dim_kpe; - constexpr uint32_t HEAD_CKV_DIM = 512; - constexpr uint32_t HEAD_KPE_DIM = 64; - FLASHINFER_CHECK(head_dim_ckv == HEAD_CKV_DIM, - "head_dim_ckv must be equal to 512"); - FLASHINFER_CHECK(head_dim_kpe == HEAD_KPE_DIM, - "head_dim_kpe must be equal to 64"); - constexpr uint32_t vec_size = 2; - - uint32_t bdx = HEAD_CKV_DIM / vec_size; - uint32_t num_threads = bdx; - uint32_t smem_size = 0; - auto kernel = AppendPagedKVMlaCacheKernel; - assert(num_sms > 0); - FLASHINFER_CUDA_CALL(hipOccupancyMaxActiveBlocksPerMultiprocessor( - &num_blocks_per_sm, kernel, num_threads, smem_size)); - num_blocks_per_sm = min(num_blocks_per_sm, ceil_div(int(nnz), num_sms)); - dim3 nblks(num_blocks_per_sm * num_sms); - dim3 nthrs(bdx); - - // FLASHINFER_CUDA_CALL( - // cudaLaunchKernel((void *)kernel, nblks, nthrs, args, 0, stream)); - - AppendPagedKVMlaCacheKernel<<>>( - paged_kv, append_ckv, append_kpe, batch_indices, positions, nnz, - append_ckv_stride_n, append_kpe_stride_n); - - return hipSuccess; +hipError_t AppendPagedKVMlaCache(paged_kv_mla_t paged_kv, DType* append_ckv, + DType* append_kpe, IdType* batch_indices, IdType* 
positions, + uint32_t nnz, size_t append_ckv_stride_n, + size_t append_kpe_stride_n, hipStream_t stream = nullptr) { + int dev_id = 0; + int num_sms = 0; + int num_blocks_per_sm = 0; + FLASHINFER_CUDA_CALL(hipGetDevice(&dev_id)); + FLASHINFER_CUDA_CALL( + hipDeviceGetAttribute(&num_sms, hipDeviceAttributeMultiprocessorCount, dev_id)); + + uint32_t head_dim_ckv = paged_kv.head_dim_ckv; + uint32_t head_dim_kpe = paged_kv.head_dim_kpe; + constexpr uint32_t HEAD_CKV_DIM = 512; + constexpr uint32_t HEAD_KPE_DIM = 64; + FLASHINFER_CHECK(head_dim_ckv == HEAD_CKV_DIM, "head_dim_ckv must be equal to 512"); + FLASHINFER_CHECK(head_dim_kpe == HEAD_KPE_DIM, "head_dim_kpe must be equal to 64"); + constexpr uint32_t vec_size = 2; + + uint32_t bdx = HEAD_CKV_DIM / vec_size; + uint32_t num_threads = bdx; + uint32_t smem_size = 0; + auto kernel = AppendPagedKVMlaCacheKernel; + assert(num_sms > 0); + FLASHINFER_CUDA_CALL(hipOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel, + num_threads, smem_size)); + num_blocks_per_sm = min(num_blocks_per_sm, ceil_div(int(nnz), num_sms)); + dim3 nblks(num_blocks_per_sm * num_sms); + dim3 nthrs(bdx); + + // FLASHINFER_CUDA_CALL( + // cudaLaunchKernel((void *)kernel, nblks, nthrs, args, 0, stream)); + + AppendPagedKVMlaCacheKernel + <<>>(paged_kv, append_ckv, append_kpe, batch_indices, + positions, nnz, append_ckv_stride_n, + append_kpe_stride_n); + + return hipSuccess; } -} // namespace flashinfer +} // namespace flashinfer -#endif // FLAHSINFER_PAGE_CUH_ +#endif // FLAHSINFER_PAGE_CUH_ diff --git a/libflashinfer/include/flashinfer/hip/pos_enc.hip.h b/libflashinfer/include/flashinfer/hip/pos_enc.hip.h index d77e794a6f..76eac2693a 100644 --- a/libflashinfer/include/flashinfer/hip/pos_enc.hip.h +++ b/libflashinfer/include/flashinfer/hip/pos_enc.hip.h @@ -7,60 +7,52 @@ #ifndef FLASHINFER_POS_ENC_CUH_ #define FLASHINFER_POS_ENC_CUH_ -#include "layout.hip.h" -#include "math.hip.h" -#include "utils.hip.h" -#include "vec_dtypes.hip.h" 
- #include #include #include #include -namespace flashinfer -{ +#include "layout.hip.h" +#include "math.hip.h" +#include "utils.hip.h" +#include "vec_dtypes.hip.h" + +namespace flashinfer { /*! * \brief An enumeration class that defines different modes for applying RoPE * (Rotary Positional Embeddings). */ -enum class PosEncodingMode -{ - // No rotary positional embeddings - kNone = 0U, - // Apply Llama-style rope. - kRoPELlama = 1U, - // Apply ALiBi bias - kALiBi = 2U +enum class PosEncodingMode { + // No rotary positional embeddings + kNone = 0U, + // Apply Llama-style rope. + kRoPELlama = 1U, + // Apply ALiBi bias + kALiBi = 2U }; /*! * \brief Convert PosEncodingMode to string * \param pos_encoding_mode A PosEncodingMode value */ -inline std::string -PosEncodingModeToString(const PosEncodingMode &pos_encoding_mode) -{ - switch (pos_encoding_mode) { +inline std::string PosEncodingModeToString(const PosEncodingMode& pos_encoding_mode) { + switch (pos_encoding_mode) { case PosEncodingMode::kNone: - return "None"; + return "None"; case PosEncodingMode::kRoPELlama: - return "Llama"; + return "Llama"; case PosEncodingMode::kALiBi: - return "ALiBi"; + return "ALiBi"; default: - return "Unknown"; - } + return "Unknown"; + } } -__device__ __forceinline__ float get_alibi_slope(uint32_t head_idx, - uint32_t num_heads) -{ - int n = math::ptx_exp2((int)math::ptx_log2(num_heads)); - return head_idx < n - ? math::ptx_exp2(-8. * float(head_idx + 1) / float(n)) - : math::ptx_exp2(-4. * float((head_idx + 1 - n) * 2 - 1) / - float(n)); +__device__ __forceinline__ float get_alibi_slope(uint32_t head_idx, uint32_t num_heads) { + int n = math::ptx_exp2((int)math::ptx_log2(num_heads)); + return head_idx < n ? math::ptx_exp2(-8. * float(head_idx + 1) / float(n)) + : math::ptx_exp2(-4. * float((head_idx + 1 - n) * 2 - 1) / float(n)); } /*! 
@@ -75,59 +67,48 @@ __device__ __forceinline__ float get_alibi_slope(uint32_t head_idx, * \param offset A integer indicates the offset of the position in RoPE */ template -__device__ __forceinline__ vec_t -vec_apply_llama_rope(const T *x, - const vec_t &freq, - int32_t offset, - const uint32_t rotary_dim = vec_size * bdx) -{ - vec_t permuted_vec, vec; - vec.cast_load(x + threadIdx.x * vec_size); - - if (threadIdx.x * vec_size < rotary_dim) { - permuted_vec.cast_load(x + - ((threadIdx.x * vec_size < rotary_dim / 2) +__device__ __forceinline__ vec_t vec_apply_llama_rope( + const T* x, const vec_t& freq, int32_t offset, + const uint32_t rotary_dim = vec_size * bdx) { + vec_t permuted_vec, vec; + vec.cast_load(x + threadIdx.x * vec_size); + + if (threadIdx.x * vec_size < rotary_dim) { + permuted_vec.cast_load(x + ((threadIdx.x * vec_size < rotary_dim / 2) ? threadIdx.x * vec_size + rotary_dim / 2 : threadIdx.x * vec_size - rotary_dim / 2)); #pragma unroll - for (uint32_t i = 0; i < vec_size; ++i) { - float embed = float(offset) * freq[i]; - float cos, sin; - __sincosf(embed, &sin, &cos); - vec[i] = vec[i] * cos + ((threadIdx.x * vec_size < rotary_dim / 2) - ? -permuted_vec[i] - : permuted_vec[i]) * - sin; - } + for (uint32_t i = 0; i < vec_size; ++i) { + float embed = float(offset) * freq[i]; + float cos, sin; + __sincosf(embed, &sin, &cos); + vec[i] = + vec[i] * cos + + ((threadIdx.x * vec_size < rotary_dim / 2) ? 
-permuted_vec[i] : permuted_vec[i]) * sin; } - return vec; + } + return vec; } template -__device__ __forceinline__ vec_t -vec_apply_llama_rope_cos_sin(const T *x, - const vec_t &cos, - const vec_t &sin, - const uint32_t rotary_dim = vec_size * bdx) -{ - vec_t permuted_vec, vec; - vec.cast_load(x + threadIdx.x * vec_size); - - if (threadIdx.x * vec_size < rotary_dim) { - permuted_vec.cast_load(x + - ((threadIdx.x * vec_size < rotary_dim / 2) +__device__ __forceinline__ vec_t vec_apply_llama_rope_cos_sin( + const T* x, const vec_t& cos, const vec_t& sin, + const uint32_t rotary_dim = vec_size * bdx) { + vec_t permuted_vec, vec; + vec.cast_load(x + threadIdx.x * vec_size); + + if (threadIdx.x * vec_size < rotary_dim) { + permuted_vec.cast_load(x + ((threadIdx.x * vec_size < rotary_dim / 2) ? threadIdx.x * vec_size + rotary_dim / 2 : threadIdx.x * vec_size - rotary_dim / 2)); #pragma unroll - for (uint32_t i = 0; i < vec_size; ++i) { - vec[i] = - vec[i] * cos[i] + ((threadIdx.x * vec_size < rotary_dim / 2) - ? -permuted_vec[i] - : permuted_vec[i]) * - sin[i]; - } + for (uint32_t i = 0; i < vec_size; ++i) { + vec[i] = + vec[i] * cos[i] + + ((threadIdx.x * vec_size < rotary_dim / 2) ? -permuted_vec[i] : permuted_vec[i]) * sin[i]; } - return vec; + } + return vec; } /*! 
@@ -142,51 +123,40 @@ vec_apply_llama_rope_cos_sin(const T *x, * \param offset A integer indicates the offset of the position in RoPE */ template -__device__ __forceinline__ vec_t -vec_apply_llama_rope_interleave(const T *x, - const vec_t &freq, - int32_t offset, - const uint32_t rotary_dim = vec_size * bdx) -{ - vec_t vec, vec_before; - vec.cast_load(x + threadIdx.x * vec_size); - - if (threadIdx.x * vec_size < rotary_dim) { - vec_before = vec; +__device__ __forceinline__ vec_t vec_apply_llama_rope_interleave( + const T* x, const vec_t& freq, int32_t offset, + const uint32_t rotary_dim = vec_size * bdx) { + vec_t vec, vec_before; + vec.cast_load(x + threadIdx.x * vec_size); + + if (threadIdx.x * vec_size < rotary_dim) { + vec_before = vec; #pragma unroll - for (uint32_t i = 0; i < vec_size; ++i) { - float embed = float(offset) * freq[i]; - float cos, sin; - __sincosf(embed, &sin, &cos); - vec[i] = - vec[i] * cos + - ((i % 2 == 0) ? -vec_before[i ^ 1] : vec_before[i ^ 1]) * sin; - } + for (uint32_t i = 0; i < vec_size; ++i) { + float embed = float(offset) * freq[i]; + float cos, sin; + __sincosf(embed, &sin, &cos); + vec[i] = vec[i] * cos + ((i % 2 == 0) ? -vec_before[i ^ 1] : vec_before[i ^ 1]) * sin; } - return vec; + } + return vec; } template -__device__ __forceinline__ vec_t -vec_apply_llama_rope_cos_sin_interleave(const T *x, - const vec_t &cos, - const vec_t &sin, - const uint32_t rotary_dim = vec_size * - bdx) -{ - vec_t vec, vec_before; - vec.cast_load(x + threadIdx.x * vec_size); - - if (threadIdx.x * vec_size < rotary_dim) { - vec_before = vec; +__device__ __forceinline__ vec_t vec_apply_llama_rope_cos_sin_interleave( + const T* x, const vec_t& cos, const vec_t& sin, + const uint32_t rotary_dim = vec_size * bdx) { + vec_t vec, vec_before; + vec.cast_load(x + threadIdx.x * vec_size); + + if (threadIdx.x * vec_size < rotary_dim) { + vec_before = vec; #pragma unroll - for (uint32_t i = 0; i < vec_size; ++i) { - vec[i] = vec[i] * cos[i] + - ((i % 2 == 0) ? 
-vec_before[i ^ 1] : vec_before[i ^ 1]) * - sin[i]; - } + for (uint32_t i = 0; i < vec_size; ++i) { + vec[i] = vec[i] * cos[i] + ((i % 2 == 0) ? -vec_before[i ^ 1] : vec_before[i ^ 1]) * sin[i]; } - return vec; + } + return vec; } /* @@ -212,670 +182,471 @@ non-interleave mode. */ template __device__ __forceinline__ vec_t -vec_apply_llama_rope_cos_sin_interleave_reuse_half( - const T *x, - const vec_t &cos, - const vec_t &sin, - const uint32_t rotary_dim = vec_size * bdx) -{ - vec_t vec, vec_before; - vec.cast_load(x + threadIdx.x * vec_size); - - if (threadIdx.x * vec_size < rotary_dim) { - vec_before = vec; +vec_apply_llama_rope_cos_sin_interleave_reuse_half(const T* x, const vec_t& cos, + const vec_t& sin, + const uint32_t rotary_dim = vec_size * bdx) { + vec_t vec, vec_before; + vec.cast_load(x + threadIdx.x * vec_size); + + if (threadIdx.x * vec_size < rotary_dim) { + vec_before = vec; #pragma unroll - for (uint32_t i = 0; i < vec_size; ++i) { - // i / 2 is to get the index of the first half of cos and sin - vec[i] = vec[i] * cos[i / 2] + - ((i % 2 == 0) ? -vec_before[i ^ 1] : vec_before[i ^ 1]) * - sin[i / 2]; - } + for (uint32_t i = 0; i < vec_size; ++i) { + // i / 2 is to get the index of the first half of cos and sin + vec[i] = vec[i] * cos[i / 2] + + ((i % 2 == 0) ? 
-vec_before[i ^ 1] : vec_before[i ^ 1]) * sin[i / 2]; } - return vec; + } + return vec; } -template __global__ void BatchQKApplyRotaryPosIdsCosSinCacheHeadParallelismKernel( - DType *q, - DType *k, - DType *q_rope, - DType *k_rope, - float *__restrict__ cos_sin_cache, - IdType *__restrict__ pos_ids, - uint32_t nnz, - uint32_t num_qo_heads, - uint32_t num_kv_heads, - uint32_t rotary_dim, - size_t q_stride_n, - size_t q_stride_h, - size_t k_stride_n, - size_t k_stride_h, - size_t q_rope_stride_n, - size_t q_rope_stride_h, - size_t k_rope_stride_n, - size_t k_rope_stride_h) -{ - uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y; - uint32_t by = blockIdx.y; - const uint32_t bdy = blockDim.y; - - vec_t cos, sin; - if (bx * bdy + ty < nnz) { - const uint32_t idx = bx * bdy + ty; - const IdType pos = pos_ids[idx]; - - const int half_rotary_dim = rotary_dim / 2; - - // 1. if interleave: - // - cos = cos_sin_cache[pos_id][tx * vec_size // 2] - // - sin = cos_sin_cache[pos_id][(rot_dim // 2) + tx * vec_size // 2] - // 2. 
if not interleave - // - cos = cos_cache[pos_id][(tx * vec_size) % (rot_dim // 2)] - // - sin = sin_cache[pos_id][(rot_dim // 2) + (tx * vec_size) % - // (rot_dim // 2)] - if (tx * vec_size < rotary_dim) { - int sin_offset = rotary_dim / 2; - int vec_idx; - if constexpr (interleave) { - vec_idx = (tx * vec_size) / 2; // Force integer division - } - else { - vec_idx = - (tx * vec_size) % half_rotary_dim; // Use half_rotary_dim - } - cos.load(cos_sin_cache + (pos * rotary_dim) + vec_idx); - sin.load(cos_sin_cache + (pos * rotary_dim) + - (sin_offset + vec_idx)); - } + DType* q, DType* k, DType* q_rope, DType* k_rope, float* __restrict__ cos_sin_cache, + IdType* __restrict__ pos_ids, uint32_t nnz, uint32_t num_qo_heads, uint32_t num_kv_heads, + uint32_t rotary_dim, size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, + size_t q_rope_stride_n, size_t q_rope_stride_h, size_t k_rope_stride_n, + size_t k_rope_stride_h) { + uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y; + uint32_t by = blockIdx.y; + const uint32_t bdy = blockDim.y; + + vec_t cos, sin; + if (bx * bdy + ty < nnz) { + const uint32_t idx = bx * bdy + ty; + const IdType pos = pos_ids[idx]; + + const int half_rotary_dim = rotary_dim / 2; + + // 1. if interleave: + // - cos = cos_sin_cache[pos_id][tx * vec_size // 2] + // - sin = cos_sin_cache[pos_id][(rot_dim // 2) + tx * vec_size // 2] + // 2. 
if not interleave + // - cos = cos_cache[pos_id][(tx * vec_size) % (rot_dim // 2)] + // - sin = sin_cache[pos_id][(rot_dim // 2) + (tx * vec_size) % + // (rot_dim // 2)] + if (tx * vec_size < rotary_dim) { + int sin_offset = rotary_dim / 2; + int vec_idx; + if constexpr (interleave) { + vec_idx = (tx * vec_size) / 2; // Force integer division + } else { + vec_idx = (tx * vec_size) % half_rotary_dim; // Use half_rotary_dim + } + cos.load(cos_sin_cache + (pos * rotary_dim) + vec_idx); + sin.load(cos_sin_cache + (pos * rotary_dim) + (sin_offset + vec_idx)); + } - if (by < num_qo_heads) { - uint32_t qo_head_idx = by; - DType *q_ptr = q + get_elem_offset_impl(idx, qo_head_idx, 0, - q_stride_n, q_stride_h); - DType *q_rope_ptr = - q_rope + get_elem_offset_impl(idx, qo_head_idx, 0, - q_rope_stride_n, q_rope_stride_h); - vec_t q_vec; - if constexpr (interleave) { - q_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half< - vec_size, bdx>(q_ptr, cos, sin, rotary_dim); - } - else { - q_vec = vec_apply_llama_rope_cos_sin( - q_ptr, cos, sin, rotary_dim); - } - q_vec.cast_store(q_rope_ptr + tx * vec_size); - } - else { - uint32_t kv_head_idx = by - num_qo_heads; - DType *k_ptr = k + get_elem_offset_impl(idx, kv_head_idx, 0, - k_stride_n, k_stride_h); - DType *k_rope_ptr = - k_rope + get_elem_offset_impl(idx, kv_head_idx, 0, - k_rope_stride_n, k_rope_stride_h); - vec_t k_vec; - if constexpr (interleave) { - k_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half< - vec_size, bdx>(k_ptr, cos, sin, rotary_dim); - } - else { - k_vec = vec_apply_llama_rope_cos_sin( - k_ptr, cos, sin, rotary_dim); - } - k_vec.cast_store(k_rope_ptr + tx * vec_size); - } + if (by < num_qo_heads) { + uint32_t qo_head_idx = by; + DType* q_ptr = q + get_elem_offset_impl(idx, qo_head_idx, 0, q_stride_n, q_stride_h); + DType* q_rope_ptr = + q_rope + get_elem_offset_impl(idx, qo_head_idx, 0, q_rope_stride_n, q_rope_stride_h); + vec_t q_vec; + if constexpr (interleave) { + q_vec = 
vec_apply_llama_rope_cos_sin_interleave_reuse_half(q_ptr, cos, sin, + rotary_dim); + } else { + q_vec = vec_apply_llama_rope_cos_sin(q_ptr, cos, sin, rotary_dim); + } + q_vec.cast_store(q_rope_ptr + tx * vec_size); + } else { + uint32_t kv_head_idx = by - num_qo_heads; + DType* k_ptr = k + get_elem_offset_impl(idx, kv_head_idx, 0, k_stride_n, k_stride_h); + DType* k_rope_ptr = + k_rope + get_elem_offset_impl(idx, kv_head_idx, 0, k_rope_stride_n, k_rope_stride_h); + vec_t k_vec; + if constexpr (interleave) { + k_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half(k_ptr, cos, sin, + rotary_dim); + } else { + k_vec = vec_apply_llama_rope_cos_sin(k_ptr, cos, sin, rotary_dim); + } + k_vec.cast_store(k_rope_ptr + tx * vec_size); } + } } -template -__global__ void -BatchQKApplyRotaryPosIdsCosSinCacheKernel(DType *q, - DType *k, - DType *q_rope, - DType *k_rope, - float *__restrict__ cos_sin_cache, - IdType *__restrict__ pos_ids, - uint32_t nnz, - uint32_t num_qo_heads, - uint32_t num_kv_heads, - uint32_t rotary_dim, - size_t q_stride_n, - size_t q_stride_h, - size_t k_stride_n, - size_t k_stride_h, - size_t q_rope_stride_n, - size_t q_rope_stride_h, - size_t k_rope_stride_n, - size_t k_rope_stride_h) +__global__ void BatchQKApplyRotaryPosIdsCosSinCacheKernel( + DType* q, DType* k, DType* q_rope, DType* k_rope, float* __restrict__ cos_sin_cache, + IdType* __restrict__ pos_ids, uint32_t nnz, uint32_t num_qo_heads, uint32_t num_kv_heads, + uint32_t rotary_dim, size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, + size_t q_rope_stride_n, size_t q_rope_stride_h, size_t k_rope_stride_n, size_t k_rope_stride_h) { - uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y; - const uint32_t bdy = blockDim.y; - - vec_t cos, sin; - if (bx * bdy + ty < nnz) { - const uint32_t idx = bx * bdy + ty; - const IdType pos = pos_ids[idx]; - const int half_rotary_dim = rotary_dim / 2; - - // 1. 
if interleave: - // - cos = cos_sin_cache[pos_id][tx * vec_size // 2] - // - sin = cos_sin_cache[pos_id][(rot_dim // 2) + tx * vec_size // 2] - // 2. if not interleave - // - cos = cos_cache[pos_id][(tx * vec_size) % (rot_dim // 2)] - // - sin = sin_cache[pos_id][(rot_dim // 2) + (tx * vec_size) % - // (rot_dim // 2)] - if (tx * vec_size < rotary_dim) { - int sin_offset = rotary_dim / 2; - int vec_idx; - if constexpr (interleave) { - vec_idx = (tx * vec_size) / 2; // Force integer division - } - else { - vec_idx = - (tx * vec_size) % half_rotary_dim; // Use half_rotary_dim - } - cos.load(cos_sin_cache + (pos * rotary_dim) + vec_idx); - sin.load(cos_sin_cache + (pos * rotary_dim) + - (sin_offset + vec_idx)); - } + uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y; + const uint32_t bdy = blockDim.y; + + vec_t cos, sin; + if (bx * bdy + ty < nnz) { + const uint32_t idx = bx * bdy + ty; + const IdType pos = pos_ids[idx]; + const int half_rotary_dim = rotary_dim / 2; + + // 1. if interleave: + // - cos = cos_sin_cache[pos_id][tx * vec_size // 2] + // - sin = cos_sin_cache[pos_id][(rot_dim // 2) + tx * vec_size // 2] + // 2. 
if not interleave + // - cos = cos_cache[pos_id][(tx * vec_size) % (rot_dim // 2)] + // - sin = sin_cache[pos_id][(rot_dim // 2) + (tx * vec_size) % + // (rot_dim // 2)] + if (tx * vec_size < rotary_dim) { + int sin_offset = rotary_dim / 2; + int vec_idx; + if constexpr (interleave) { + vec_idx = (tx * vec_size) / 2; // Force integer division + } else { + vec_idx = (tx * vec_size) % half_rotary_dim; // Use half_rotary_dim + } + cos.load(cos_sin_cache + (pos * rotary_dim) + vec_idx); + sin.load(cos_sin_cache + (pos * rotary_dim) + (sin_offset + vec_idx)); + } - // not to unroll the loop, because num head might be large and might - // lead to worse performance + // not to unroll the loop, because num head might be large and might + // lead to worse performance #pragma unroll 1 - for (uint32_t qo_head_idx = 0; qo_head_idx < num_qo_heads; - ++qo_head_idx) - { - DType *q_ptr = q + get_elem_offset_impl(idx, qo_head_idx, 0, - q_stride_n, q_stride_h); - DType *q_rope_ptr = - q_rope + get_elem_offset_impl(idx, qo_head_idx, 0, - q_rope_stride_n, q_rope_stride_h); - vec_t q_vec; - if constexpr (interleave) { - q_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half< - vec_size, bdx>(q_ptr, cos, sin, rotary_dim); - } - else { - q_vec = vec_apply_llama_rope_cos_sin( - q_ptr, cos, sin, rotary_dim); - } - q_vec.cast_store(q_rope_ptr + tx * vec_size); - } + for (uint32_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) { + DType* q_ptr = q + get_elem_offset_impl(idx, qo_head_idx, 0, q_stride_n, q_stride_h); + DType* q_rope_ptr = + q_rope + get_elem_offset_impl(idx, qo_head_idx, 0, q_rope_stride_n, q_rope_stride_h); + vec_t q_vec; + if constexpr (interleave) { + q_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half(q_ptr, cos, sin, + rotary_dim); + } else { + q_vec = vec_apply_llama_rope_cos_sin(q_ptr, cos, sin, rotary_dim); + } + q_vec.cast_store(q_rope_ptr + tx * vec_size); + } #pragma unroll 1 - for (uint32_t kv_head_idx = 0; kv_head_idx < num_kv_heads; - 
++kv_head_idx) - { - DType *k_ptr = k + get_elem_offset_impl(idx, kv_head_idx, 0, - k_stride_n, k_stride_h); - DType *k_rope_ptr = - k_rope + get_elem_offset_impl(idx, kv_head_idx, 0, - k_rope_stride_n, k_rope_stride_h); - vec_t k_vec; - if constexpr (interleave) { - k_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half< - vec_size, bdx>(k_ptr, cos, sin, rotary_dim); - } - else { - k_vec = vec_apply_llama_rope_cos_sin( - k_ptr, cos, sin, rotary_dim); - } - k_vec.cast_store(k_rope_ptr + tx * vec_size); - } + for (uint32_t kv_head_idx = 0; kv_head_idx < num_kv_heads; ++kv_head_idx) { + DType* k_ptr = k + get_elem_offset_impl(idx, kv_head_idx, 0, k_stride_n, k_stride_h); + DType* k_rope_ptr = + k_rope + get_elem_offset_impl(idx, kv_head_idx, 0, k_rope_stride_n, k_rope_stride_h); + vec_t k_vec; + if constexpr (interleave) { + k_vec = vec_apply_llama_rope_cos_sin_interleave_reuse_half(k_ptr, cos, sin, + rotary_dim); + } else { + k_vec = vec_apply_llama_rope_cos_sin(k_ptr, cos, sin, rotary_dim); + } + k_vec.cast_store(k_rope_ptr + tx * vec_size); } + } } -template -__global__ void -BatchQKApplyRotaryPosIdsHeadParallelismKernel(DType *q, - DType *k, - DType *q_rope, - DType *k_rope, - IdType *__restrict__ pos_ids, - uint32_t nnz, - uint32_t num_qo_heads, - uint32_t num_kv_heads, - uint32_t rotary_dim, - size_t q_stride_n, - size_t q_stride_h, - size_t k_stride_n, - size_t k_stride_h, - size_t q_rope_stride_n, - size_t q_rope_stride_h, - size_t k_rope_stride_n, - size_t k_rope_stride_h, - float smooth_a, - float smooth_b, - float rope_rcp_scale, - float rope_rcp_theta) -{ - // NOTE: q and q_rope may be the same ptr, so do k and k_rope - uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y; - uint32_t by = blockIdx.y; - const uint32_t bdy = blockDim.y; - vec_t freq; - if (tx * vec_size < rotary_dim) { +__global__ void BatchQKApplyRotaryPosIdsHeadParallelismKernel( + DType* q, DType* k, DType* q_rope, DType* k_rope, IdType* __restrict__ pos_ids, uint32_t nnz, + 
uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t rotary_dim, size_t q_stride_n, + size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, size_t q_rope_stride_n, + size_t q_rope_stride_h, size_t k_rope_stride_n, size_t k_rope_stride_h, float smooth_a, + float smooth_b, float rope_rcp_scale, float rope_rcp_theta) { + // NOTE: q and q_rope may be the same ptr, so do k and k_rope + uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y; + uint32_t by = blockIdx.y; + const uint32_t bdy = blockDim.y; + vec_t freq; + if (tx * vec_size < rotary_dim) { #pragma unroll - for (uint32_t i = 0; i < vec_size; ++i) { - if constexpr (interleave) { - freq[i] = __powf(rope_rcp_theta, - float(2 * ((tx * vec_size + i) / 2)) / - float(rotary_dim)); - } - else { - freq[i] = - __powf(rope_rcp_theta, - float(2 * ((tx * vec_size + i) % (rotary_dim / 2))) / - float(rotary_dim)); - } - - float smooth = freq[i] * smooth_a + smooth_b; - smooth = max(0.0f, min(1.0f, smooth)); // clamp to [0, 1] - freq[i] = - (1 - smooth) * (freq[i] * rope_rcp_scale) + smooth * freq[i]; - } + for (uint32_t i = 0; i < vec_size; ++i) { + if constexpr (interleave) { + freq[i] = __powf(rope_rcp_theta, float(2 * ((tx * vec_size + i) / 2)) / float(rotary_dim)); + } else { + freq[i] = __powf(rope_rcp_theta, + float(2 * ((tx * vec_size + i) % (rotary_dim / 2))) / float(rotary_dim)); + } + + float smooth = freq[i] * smooth_a + smooth_b; + smooth = max(0.0f, min(1.0f, smooth)); // clamp to [0, 1] + freq[i] = (1 - smooth) * (freq[i] * rope_rcp_scale) + smooth * freq[i]; } + } - vec_t cos, sin; + vec_t cos, sin; - if (bx * bdy + ty < nnz) { - const uint32_t idx = bx * bdy + ty; - const IdType pos = pos_ids[idx]; + if (bx * bdy + ty < nnz) { + const uint32_t idx = bx * bdy + ty; + const IdType pos = pos_ids[idx]; - if (tx * vec_size < rotary_dim) { + if (tx * vec_size < rotary_dim) { #pragma unroll - for (uint32_t i = 0; i < vec_size; ++i) { - float embed = float(pos) * freq[i]; - __sincosf(embed, &sin[i], 
&cos[i]); - } - } + for (uint32_t i = 0; i < vec_size; ++i) { + float embed = float(pos) * freq[i]; + __sincosf(embed, &sin[i], &cos[i]); + } + } - if (by < num_qo_heads) { - uint32_t qo_head_idx = by; - DType *q_ptr = q + get_elem_offset_impl(idx, qo_head_idx, 0, - q_stride_n, q_stride_h); - DType *q_rope_ptr = - q_rope + get_elem_offset_impl(idx, qo_head_idx, 0, - q_rope_stride_n, q_rope_stride_h); - vec_t q_vec; - if constexpr (interleave) { - q_vec = vec_apply_llama_rope_cos_sin_interleave( - q_ptr, cos, sin, rotary_dim); - } - else { - q_vec = vec_apply_llama_rope_cos_sin( - q_ptr, cos, sin, rotary_dim); - } - q_vec.cast_store(q_rope_ptr + tx * vec_size); - } - else { - uint32_t kv_head_idx = by - num_qo_heads; - DType *k_ptr = k + get_elem_offset_impl(idx, kv_head_idx, 0, - k_stride_n, k_stride_h); - DType *k_rope_ptr = - k_rope + get_elem_offset_impl(idx, kv_head_idx, 0, - k_rope_stride_n, k_rope_stride_h); - vec_t k_vec; - if constexpr (interleave) { - k_vec = vec_apply_llama_rope_cos_sin_interleave( - k_ptr, cos, sin, rotary_dim); - } - else { - k_vec = vec_apply_llama_rope_cos_sin( - k_ptr, cos, sin, rotary_dim); - } - k_vec.cast_store(k_rope_ptr + tx * vec_size); - } + if (by < num_qo_heads) { + uint32_t qo_head_idx = by; + DType* q_ptr = q + get_elem_offset_impl(idx, qo_head_idx, 0, q_stride_n, q_stride_h); + DType* q_rope_ptr = + q_rope + get_elem_offset_impl(idx, qo_head_idx, 0, q_rope_stride_n, q_rope_stride_h); + vec_t q_vec; + if constexpr (interleave) { + q_vec = vec_apply_llama_rope_cos_sin_interleave(q_ptr, cos, sin, rotary_dim); + } else { + q_vec = vec_apply_llama_rope_cos_sin(q_ptr, cos, sin, rotary_dim); + } + q_vec.cast_store(q_rope_ptr + tx * vec_size); + } else { + uint32_t kv_head_idx = by - num_qo_heads; + DType* k_ptr = k + get_elem_offset_impl(idx, kv_head_idx, 0, k_stride_n, k_stride_h); + DType* k_rope_ptr = + k_rope + get_elem_offset_impl(idx, kv_head_idx, 0, k_rope_stride_n, k_rope_stride_h); + vec_t k_vec; + if constexpr 
(interleave) { + k_vec = vec_apply_llama_rope_cos_sin_interleave(k_ptr, cos, sin, rotary_dim); + } else { + k_vec = vec_apply_llama_rope_cos_sin(k_ptr, cos, sin, rotary_dim); + } + k_vec.cast_store(k_rope_ptr + tx * vec_size); } + } } -template -__global__ void BatchQKApplyRotaryPosIdsKernel(DType *q, - DType *k, - DType *q_rope, - DType *k_rope, - IdType *__restrict__ pos_ids, - uint32_t nnz, - uint32_t num_qo_heads, - uint32_t num_kv_heads, - uint32_t rotary_dim, - size_t q_stride_n, - size_t q_stride_h, - size_t k_stride_n, - size_t k_stride_h, - size_t q_rope_stride_n, - size_t q_rope_stride_h, - size_t k_rope_stride_n, - size_t k_rope_stride_h, - float smooth_a, - float smooth_b, - float rope_rcp_scale, - float rope_rcp_theta) -{ - // NOTE: q and q_rope may be the same ptr, so do k and k_rope - uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y; - const uint32_t bdy = blockDim.y; - vec_t freq; - if (tx * vec_size < rotary_dim) { +__global__ void BatchQKApplyRotaryPosIdsKernel( + DType* q, DType* k, DType* q_rope, DType* k_rope, IdType* __restrict__ pos_ids, uint32_t nnz, + uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t rotary_dim, size_t q_stride_n, + size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, size_t q_rope_stride_n, + size_t q_rope_stride_h, size_t k_rope_stride_n, size_t k_rope_stride_h, float smooth_a, + float smooth_b, float rope_rcp_scale, float rope_rcp_theta) { + // NOTE: q and q_rope may be the same ptr, so do k and k_rope + uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y; + const uint32_t bdy = blockDim.y; + vec_t freq; + if (tx * vec_size < rotary_dim) { #pragma unroll - for (uint32_t i = 0; i < vec_size; ++i) { - if constexpr (interleave) { - freq[i] = __powf(rope_rcp_theta, - float(2 * ((tx * vec_size + i) / 2)) / - float(rotary_dim)); - } - else { - freq[i] = - __powf(rope_rcp_theta, - float(2 * ((tx * vec_size + i) % (rotary_dim / 2))) / - float(rotary_dim)); - } - - float smooth = freq[i] * smooth_a + 
smooth_b; - smooth = max(0.0f, min(1.0f, smooth)); // clamp to [0, 1] - freq[i] = - (1 - smooth) * (freq[i] * rope_rcp_scale) + smooth * freq[i]; - } + for (uint32_t i = 0; i < vec_size; ++i) { + if constexpr (interleave) { + freq[i] = __powf(rope_rcp_theta, float(2 * ((tx * vec_size + i) / 2)) / float(rotary_dim)); + } else { + freq[i] = __powf(rope_rcp_theta, + float(2 * ((tx * vec_size + i) % (rotary_dim / 2))) / float(rotary_dim)); + } + + float smooth = freq[i] * smooth_a + smooth_b; + smooth = max(0.0f, min(1.0f, smooth)); // clamp to [0, 1] + freq[i] = (1 - smooth) * (freq[i] * rope_rcp_scale) + smooth * freq[i]; } + } - vec_t cos, sin; + vec_t cos, sin; - if (bx * bdy + ty < nnz) { - const uint32_t idx = bx * bdy + ty; - const IdType pos = pos_ids[idx]; + if (bx * bdy + ty < nnz) { + const uint32_t idx = bx * bdy + ty; + const IdType pos = pos_ids[idx]; - if (tx * vec_size < rotary_dim) { + if (tx * vec_size < rotary_dim) { #pragma unroll - for (uint32_t i = 0; i < vec_size; ++i) { - float embed = float(pos) * freq[i]; - __sincosf(embed, &sin[i], &cos[i]); - } - } + for (uint32_t i = 0; i < vec_size; ++i) { + float embed = float(pos) * freq[i]; + __sincosf(embed, &sin[i], &cos[i]); + } + } #pragma unroll 1 - for (uint32_t qo_head_idx = 0; qo_head_idx < num_qo_heads; - ++qo_head_idx) - { - DType *q_ptr = q + get_elem_offset_impl(idx, qo_head_idx, 0, - q_stride_n, q_stride_h); - DType *q_rope_ptr = - q_rope + get_elem_offset_impl(idx, qo_head_idx, 0, - q_rope_stride_n, q_rope_stride_h); - vec_t q_vec; - if constexpr (interleave) { - q_vec = vec_apply_llama_rope_cos_sin_interleave( - q_ptr, cos, sin, rotary_dim); - } - else { - q_vec = vec_apply_llama_rope_cos_sin( - q_ptr, cos, sin, rotary_dim); - } - q_vec.cast_store(q_rope_ptr + tx * vec_size); - } + for (uint32_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) { + DType* q_ptr = q + get_elem_offset_impl(idx, qo_head_idx, 0, q_stride_n, q_stride_h); + DType* q_rope_ptr = + q_rope + 
get_elem_offset_impl(idx, qo_head_idx, 0, q_rope_stride_n, q_rope_stride_h); + vec_t q_vec; + if constexpr (interleave) { + q_vec = vec_apply_llama_rope_cos_sin_interleave(q_ptr, cos, sin, rotary_dim); + } else { + q_vec = vec_apply_llama_rope_cos_sin(q_ptr, cos, sin, rotary_dim); + } + q_vec.cast_store(q_rope_ptr + tx * vec_size); + } #pragma unroll 1 - for (uint32_t kv_head_idx = 0; kv_head_idx < num_kv_heads; - ++kv_head_idx) - { - DType *k_ptr = k + get_elem_offset_impl(idx, kv_head_idx, 0, - k_stride_n, k_stride_h); - DType *k_rope_ptr = - k_rope + get_elem_offset_impl(idx, kv_head_idx, 0, - k_rope_stride_n, k_rope_stride_h); - vec_t k_vec; - if constexpr (interleave) { - k_vec = vec_apply_llama_rope_cos_sin_interleave( - k_ptr, cos, sin, rotary_dim); - } - else { - k_vec = vec_apply_llama_rope_cos_sin( - k_ptr, cos, sin, rotary_dim); - } - k_vec.cast_store(k_rope_ptr + tx * vec_size); - } + for (uint32_t kv_head_idx = 0; kv_head_idx < num_kv_heads; ++kv_head_idx) { + DType* k_ptr = k + get_elem_offset_impl(idx, kv_head_idx, 0, k_stride_n, k_stride_h); + DType* k_rope_ptr = + k_rope + get_elem_offset_impl(idx, kv_head_idx, 0, k_rope_stride_n, k_rope_stride_h); + vec_t k_vec; + if constexpr (interleave) { + k_vec = vec_apply_llama_rope_cos_sin_interleave(k_ptr, cos, sin, rotary_dim); + } else { + k_vec = vec_apply_llama_rope_cos_sin(k_ptr, cos, sin, rotary_dim); + } + k_vec.cast_store(k_rope_ptr + tx * vec_size); } + } } -template -__global__ void BatchQKApplyRotaryKernel(DType *q, - DType *k, - DType *q_rope, - DType *k_rope, - IdType *__restrict__ indptr, - IdType *__restrict__ offsets, - uint32_t batch_size, - uint32_t num_qo_heads, - uint32_t num_kv_heads, - uint32_t rotary_dim, - size_t q_stride_n, - size_t q_stride_h, - size_t k_stride_n, - size_t k_stride_h, - size_t q_rope_stride_n, - size_t q_rope_stride_h, - size_t k_rope_stride_n, - size_t k_rope_stride_h, - float smooth_a, - float smooth_b, - float rope_rcp_scale, - float rope_rcp_theta) -{ - 
uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y; - const uint32_t bdy = blockDim.y; - vec_t freq; - if (tx * vec_size < rotary_dim) { +__global__ void BatchQKApplyRotaryKernel( + DType* q, DType* k, DType* q_rope, DType* k_rope, IdType* __restrict__ indptr, + IdType* __restrict__ offsets, uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, + uint32_t rotary_dim, size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, + size_t q_rope_stride_n, size_t q_rope_stride_h, size_t k_rope_stride_n, size_t k_rope_stride_h, + float smooth_a, float smooth_b, float rope_rcp_scale, float rope_rcp_theta) { + uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y; + const uint32_t bdy = blockDim.y; + vec_t freq; + if (tx * vec_size < rotary_dim) { #pragma unroll - for (uint32_t i = 0; i < vec_size; ++i) { - if constexpr (interleave) { - freq[i] = __powf(rope_rcp_theta, - float(2 * ((tx * vec_size + i) / 2)) / - float(rotary_dim)); - } - else { - freq[i] = - __powf(rope_rcp_theta, - float(2 * ((tx * vec_size + i) % (rotary_dim / 2))) / - float(rotary_dim)); - } - - float smooth = freq[i] * smooth_a + smooth_b; - smooth = max(0.0f, min(1.0f, smooth)); // clamp to [0, 1] - freq[i] = - (1 - smooth) * (freq[i] * rope_rcp_scale) + smooth * freq[i]; - } - } + for (uint32_t i = 0; i < vec_size; ++i) { + if constexpr (interleave) { + freq[i] = __powf(rope_rcp_theta, float(2 * ((tx * vec_size + i) / 2)) / float(rotary_dim)); + } else { + freq[i] = __powf(rope_rcp_theta, + float(2 * ((tx * vec_size + i) % (rotary_dim / 2))) / float(rotary_dim)); + } - if (bx < batch_size * num_qo_heads) { - // apply rotary to q - const uint32_t batch_idx = bx / num_qo_heads; - const uint32_t qo_head_idx = bx % num_qo_heads; - const uint32_t seq_len = indptr[batch_idx + 1] - indptr[batch_idx]; - const uint32_t offset = offsets[batch_idx]; + float smooth = freq[i] * smooth_a + smooth_b; + smooth = max(0.0f, min(1.0f, smooth)); // clamp to [0, 1] + freq[i] = 
(1 - smooth) * (freq[i] * rope_rcp_scale) + smooth * freq[i]; + } + } + + if (bx < batch_size * num_qo_heads) { + // apply rotary to q + const uint32_t batch_idx = bx / num_qo_heads; + const uint32_t qo_head_idx = bx % num_qo_heads; + const uint32_t seq_len = indptr[batch_idx + 1] - indptr[batch_idx]; + const uint32_t offset = offsets[batch_idx]; #pragma unroll 2 - for (uint32_t i = 0; i < (seq_len + bdy - 1) / bdy; ++i) { - vec_t q_vec; - if (i * bdy + ty < seq_len) { - DType *q_ptr = q + get_elem_offset_impl( - indptr[batch_idx] + i * bdy + ty, - qo_head_idx, 0, q_stride_n, q_stride_h); - DType *q_rope_ptr = - q_rope + get_elem_offset_impl( - indptr[batch_idx] + i * bdy + ty, qo_head_idx, - 0, q_rope_stride_n, q_rope_stride_h); - if constexpr (interleave) { - q_vec = vec_apply_llama_rope_interleave( - q_ptr, freq, offset + i * bdy + ty, rotary_dim); - } - else { - q_vec = vec_apply_llama_rope( - q_ptr, freq, offset + i * bdy + ty, rotary_dim); - } - q_vec.cast_store(q_rope_ptr + tx * vec_size); - } + for (uint32_t i = 0; i < (seq_len + bdy - 1) / bdy; ++i) { + vec_t q_vec; + if (i * bdy + ty < seq_len) { + DType* q_ptr = q + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, qo_head_idx, 0, + q_stride_n, q_stride_h); + DType* q_rope_ptr = + q_rope + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, qo_head_idx, 0, + q_rope_stride_n, q_rope_stride_h); + if constexpr (interleave) { + q_vec = vec_apply_llama_rope_interleave(q_ptr, freq, offset + i * bdy + ty, + rotary_dim); + } else { + q_vec = + vec_apply_llama_rope(q_ptr, freq, offset + i * bdy + ty, rotary_dim); } + q_vec.cast_store(q_rope_ptr + tx * vec_size); + } } - else { - // apply rotary to k - uint32_t batch_idx = (bx - batch_size * num_qo_heads) / num_kv_heads; - uint32_t kv_head_idx = (bx - batch_size * num_qo_heads) % num_kv_heads; - const uint32_t seq_len = indptr[batch_idx + 1] - indptr[batch_idx]; - const uint32_t offset = offsets[batch_idx]; + } else { + // apply rotary to k + uint32_t 
batch_idx = (bx - batch_size * num_qo_heads) / num_kv_heads; + uint32_t kv_head_idx = (bx - batch_size * num_qo_heads) % num_kv_heads; + const uint32_t seq_len = indptr[batch_idx + 1] - indptr[batch_idx]; + const uint32_t offset = offsets[batch_idx]; #pragma unroll 2 - for (uint32_t i = 0; i < (seq_len + bdy - 1) / bdy; ++i) { - vec_t k_vec; - if (i * bdy + ty < seq_len) { - DType *k_ptr = k + get_elem_offset_impl( - indptr[batch_idx] + i * bdy + ty, - kv_head_idx, 0, k_stride_n, k_stride_h); - DType *k_rope_ptr = - k_rope + get_elem_offset_impl( - indptr[batch_idx] + i * bdy + ty, kv_head_idx, - 0, k_rope_stride_n, k_rope_stride_h); - if constexpr (interleave) { - k_vec = vec_apply_llama_rope_interleave( - k_ptr, freq, offset + i * bdy + ty, rotary_dim); - } - else { - k_vec = vec_apply_llama_rope( - k_ptr, freq, offset + i * bdy + ty, rotary_dim); - } - k_vec.cast_store(k_rope_ptr + tx * vec_size); - } + for (uint32_t i = 0; i < (seq_len + bdy - 1) / bdy; ++i) { + vec_t k_vec; + if (i * bdy + ty < seq_len) { + DType* k_ptr = k + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, kv_head_idx, 0, + k_stride_n, k_stride_h); + DType* k_rope_ptr = + k_rope + get_elem_offset_impl(indptr[batch_idx] + i * bdy + ty, kv_head_idx, 0, + k_rope_stride_n, k_rope_stride_h); + if constexpr (interleave) { + k_vec = vec_apply_llama_rope_interleave(k_ptr, freq, offset + i * bdy + ty, + rotary_dim); + } else { + k_vec = + vec_apply_llama_rope(k_ptr, freq, offset + i * bdy + ty, rotary_dim); } + k_vec.cast_store(k_rope_ptr + tx * vec_size); + } } + } } -#define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) \ - if (interleave) { \ - const bool INTERLEAVE = true; \ - __VA_ARGS__ \ - } \ - else { \ - const bool INTERLEAVE = false; \ - __VA_ARGS__ \ - } +#define DISPATCH_INTERLEAVE(interleave, INTERLEAVE, ...) 
\ + if (interleave) { \ + const bool INTERLEAVE = true; \ + __VA_ARGS__ \ + } else { \ + const bool INTERLEAVE = false; \ + __VA_ARGS__ \ + } template -hipError_t BatchQKApplyRotaryPosIdsCosSinCache(DType *q, - DType *k, - DType *q_rope, - DType *k_rope, - float *cos_sin_cache, - IdType *pos_ids, - uint32_t nnz, - uint32_t num_qo_heads, - uint32_t num_kv_heads, - uint32_t rotary_dim, - uint32_t head_dim, - size_t q_stride_n, - size_t q_stride_h, - size_t k_stride_n, - size_t k_stride_h, - size_t q_rope_stride_n, - size_t q_rope_stride_h, - size_t k_rope_stride_n, - size_t k_rope_stride_h, - bool interleave, - hipStream_t stream = nullptr) -{ - int dev_id = 0; - int num_sms = 0; - FLASHINFER_CUDA_CALL(hipGetDevice(&dev_id)); - FLASHINFER_CUDA_CALL(hipDeviceGetAttribute( - &num_sms, hipDeviceAttributeMultiprocessorCount, dev_id)); - - DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { - DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { - // operate on 16 Bytes at a time - constexpr uint32_t vec_size = - std::max(16 / sizeof(DType), HEAD_DIM / 32); - // how many threads needed per head_dim - constexpr uint32_t bdx = HEAD_DIM / vec_size; - // how many threads needed per block - uint32_t num_threads = std::max(128U, bdx); - // how many tokens can we process in a block - uint32_t bdy = num_threads / bdx; - // how many blocks needed to process all tokens - uint32_t nblks_x = (nnz + bdy - 1) / bdy; - void *args[] = {(void *)&q, - (void *)&k, - (void *)&q_rope, - (void *)&k_rope, - (void *)&cos_sin_cache, - (void *)&pos_ids, - (void *)&nnz, - (void *)&num_qo_heads, - (void *)&num_kv_heads, - (void *)&rotary_dim, - (void *)&q_stride_n, - (void *)&q_stride_h, - (void *)&k_stride_n, - (void *)&k_stride_h, - (void *)&q_rope_stride_n, - (void *)&q_rope_stride_h, - (void *)&k_rope_stride_n, - (void *)&k_rope_stride_h}; - auto kernel_0 = BatchQKApplyRotaryPosIdsCosSinCacheKernel< - INTERLEAVE, HEAD_DIM, vec_size, bdx, DType, IdType>; - - int num_blocks_per_sm_0 = 0; - 
FLASHINFER_CUDA_CALL(hipOccupancyMaxActiveBlocksPerMultiprocessor( - &num_blocks_per_sm_0, kernel_0, num_threads, /*smem_size=*/0)); - uint32_t num_ctas_0 = num_blocks_per_sm_0 * num_sms; - - if ((nnz + bdy - 1) / bdy >= num_ctas_0) { - dim3 nblks(nblks_x); - dim3 nthrs(bdx, bdy); - BatchQKApplyRotaryPosIdsCosSinCacheKernel< - INTERLEAVE, HEAD_DIM, vec_size, bdx, DType, IdType><<>>( - q, k, q_rope, k_rope, cos_sin_cache, pos_ids, nnz, - num_qo_heads, num_kv_heads, rotary_dim, q_stride_n, - q_stride_h, k_stride_n, k_stride_h, q_rope_stride_n, - q_rope_stride_h, k_rope_stride_n, k_rope_stride_h); - } - else { - dim3 nblks(nblks_x, num_qo_heads + num_kv_heads); - dim3 nthrs(bdx, bdy); - auto kernel_1 = - BatchQKApplyRotaryPosIdsCosSinCacheHeadParallelismKernel< - INTERLEAVE, HEAD_DIM, vec_size, bdx, DType, IdType>; - BatchQKApplyRotaryPosIdsCosSinCacheHeadParallelismKernel< - INTERLEAVE, HEAD_DIM, vec_size, bdx, DType, IdType><<>>( - q, k, q_rope, k_rope, cos_sin_cache, pos_ids, nnz, - num_qo_heads, num_kv_heads, rotary_dim, q_stride_n, - q_stride_h, k_stride_n, k_stride_h, q_rope_stride_n, - q_rope_stride_h, k_rope_stride_n, k_rope_stride_h); - } - }); +hipError_t BatchQKApplyRotaryPosIdsCosSinCache( + DType* q, DType* k, DType* q_rope, DType* k_rope, float* cos_sin_cache, IdType* pos_ids, + uint32_t nnz, uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t rotary_dim, + uint32_t head_dim, size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, + size_t q_rope_stride_n, size_t q_rope_stride_h, size_t k_rope_stride_n, size_t k_rope_stride_h, + bool interleave, hipStream_t stream = nullptr) { + int dev_id = 0; + int num_sms = 0; + FLASHINFER_CUDA_CALL(hipGetDevice(&dev_id)); + FLASHINFER_CUDA_CALL( + hipDeviceGetAttribute(&num_sms, hipDeviceAttributeMultiprocessorCount, dev_id)); + + DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + // operate on 16 Bytes at a time + constexpr uint32_t vec_size = std::max(16 
/ sizeof(DType), HEAD_DIM / 32); + // how many threads needed per head_dim + constexpr uint32_t bdx = HEAD_DIM / vec_size; + // how many threads needed per block + uint32_t num_threads = std::max(128U, bdx); + // how many tokens can we process in a block + uint32_t bdy = num_threads / bdx; + // how many blocks needed to process all tokens + uint32_t nblks_x = (nnz + bdy - 1) / bdy; + void* args[] = {(void*)&q, + (void*)&k, + (void*)&q_rope, + (void*)&k_rope, + (void*)&cos_sin_cache, + (void*)&pos_ids, + (void*)&nnz, + (void*)&num_qo_heads, + (void*)&num_kv_heads, + (void*)&rotary_dim, + (void*)&q_stride_n, + (void*)&q_stride_h, + (void*)&k_stride_n, + (void*)&k_stride_h, + (void*)&q_rope_stride_n, + (void*)&q_rope_stride_h, + (void*)&k_rope_stride_n, + (void*)&k_rope_stride_h}; + auto kernel_0 = BatchQKApplyRotaryPosIdsCosSinCacheKernel; + + int num_blocks_per_sm_0 = 0; + FLASHINFER_CUDA_CALL(hipOccupancyMaxActiveBlocksPerMultiprocessor( + &num_blocks_per_sm_0, kernel_0, num_threads, /*smem_size=*/0)); + uint32_t num_ctas_0 = num_blocks_per_sm_0 * num_sms; + + if ((nnz + bdy - 1) / bdy >= num_ctas_0) { + dim3 nblks(nblks_x); + dim3 nthrs(bdx, bdy); + BatchQKApplyRotaryPosIdsCosSinCacheKernel<<>>( + q, k, q_rope, k_rope, cos_sin_cache, pos_ids, nnz, num_qo_heads, num_kv_heads, + rotary_dim, q_stride_n, q_stride_h, k_stride_n, k_stride_h, q_rope_stride_n, + q_rope_stride_h, k_rope_stride_n, k_rope_stride_h); + } else { + dim3 nblks(nblks_x, num_qo_heads + num_kv_heads); + dim3 nthrs(bdx, bdy); + auto kernel_1 = + BatchQKApplyRotaryPosIdsCosSinCacheHeadParallelismKernel; + BatchQKApplyRotaryPosIdsCosSinCacheHeadParallelismKernel + <<>>(q, k, q_rope, k_rope, cos_sin_cache, pos_ids, nnz, + num_qo_heads, num_kv_heads, rotary_dim, q_stride_n, + q_stride_h, k_stride_n, k_stride_h, q_rope_stride_n, + q_rope_stride_h, k_rope_stride_n, k_rope_stride_h); + } }); + }); - return hipSuccess; + return hipSuccess; } template @@ -884,8 +655,7 @@ hipError_t BatchQKApplyRotaryPosIds( 
uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t rotary_dim, uint32_t head_dim, size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, size_t q_rope_stride_n, size_t q_rope_stride_h, size_t k_rope_stride_n, size_t k_rope_stride_h, - bool interleave, float rope_scale, float rope_theta, hipStream_t stream = nullptr) - { + bool interleave, float rope_scale, float rope_theta, hipStream_t stream = nullptr) { float rope_rcp_scale = 1.0f / rope_scale; float rope_rcp_theta = 1.0f / rope_theta; float smooth_a = 0.f; @@ -893,7 +663,8 @@ hipError_t BatchQKApplyRotaryPosIds( int dev_id = 0; int num_sms = 0; FLASHINFER_CUDA_CALL(hipGetDevice(&dev_id)); - FLASHINFER_CUDA_CALL(hipDeviceGetAttribute(&num_sms, hipDeviceAttributeMultiprocessorCount, dev_id)); + FLASHINFER_CUDA_CALL( + hipDeviceGetAttribute(&num_sms, hipDeviceAttributeMultiprocessorCount, dev_id)); DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { @@ -936,8 +707,12 @@ hipError_t BatchQKApplyRotaryPosIds( dim3 nthrs(bdx, bdy); // FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel_0, nblks, nthrs, args, 0, stream)); - BatchQKApplyRotaryPosIdsKernel<<>>( - q, k, q_rope, k_rope, pos_ids, nnz, num_qo_heads, num_kv_heads, rotary_dim, q_stride_n, q_stride_h, k_stride_n, k_stride_h, q_rope_stride_n, q_rope_stride_h, k_rope_stride_n, k_rope_stride_h, smooth_a, smooth_b, rope_rcp_scale, rope_rcp_theta); + BatchQKApplyRotaryPosIdsKernel + <<>>(q, k, q_rope, k_rope, pos_ids, nnz, num_qo_heads, + num_kv_heads, rotary_dim, q_stride_n, q_stride_h, + k_stride_n, k_stride_h, q_rope_stride_n, q_rope_stride_h, + k_rope_stride_n, k_rope_stride_h, smooth_a, smooth_b, + rope_rcp_scale, rope_rcp_theta); } else { dim3 nblks(nblks_x, num_qo_heads + num_kv_heads); dim3 nthrs(bdx, bdy); @@ -945,8 +720,11 @@ hipError_t BatchQKApplyRotaryPosIds( vec_size, bdx, DType, IdType>; // FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel_1, nblks, nthrs, args, 0, stream)); - 
BatchQKApplyRotaryPosIdsHeadParallelismKernel<<>>(q, k, q_rope, k_rope, pos_ids, nnz, num_qo_heads, num_kv_heads, rotary_dim, q_stride_n, q_stride_h, k_stride_n, k_stride_h, q_rope_stride_n, q_rope_stride_h, k_rope_stride_n, k_rope_stride_h, smooth_a, smooth_b, rope_rcp_scale, rope_rcp_theta); + BatchQKApplyRotaryPosIdsHeadParallelismKernel<<>>( + q, k, q_rope, k_rope, pos_ids, nnz, num_qo_heads, num_kv_heads, rotary_dim, q_stride_n, + q_stride_h, k_stride_n, k_stride_h, q_rope_stride_n, q_rope_stride_h, k_rope_stride_n, + k_rope_stride_h, smooth_a, smooth_b, rope_rcp_scale, rope_rcp_theta); } }); }); @@ -955,271 +733,189 @@ hipError_t BatchQKApplyRotaryPosIds( } template -hipError_t BatchQKApplyRotary(DType *q, - DType *k, - DType *q_rope, - DType *k_rope, - IdType *__restrict__ indptr, - IdType *__restrict__ offsets, - uint32_t batch_size, - uint32_t num_qo_heads, - uint32_t num_kv_heads, - uint32_t rotary_dim, - uint32_t head_dim, - size_t q_stride_n, - size_t q_stride_h, - size_t k_stride_n, - size_t k_stride_h, - size_t q_rope_stride_n, - size_t q_rope_stride_h, - size_t k_rope_stride_n, - size_t k_rope_stride_h, - bool interleave, - float rope_scale, - float rope_theta, - hipStream_t stream = nullptr) -{ - float rope_rcp_scale = 1.0f / rope_scale; - float rope_rcp_theta = 1.0f / rope_theta; - float smooth_a = 0.f; - float smooth_b = 0.f; - - DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { - DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { - constexpr uint32_t vec_size = - std::max(16 / sizeof(DType), HEAD_DIM / 32); - constexpr uint32_t bdx = HEAD_DIM / vec_size; - uint32_t num_threads = std::max(128U, bdx); - uint32_t bdy = num_threads / bdx; - dim3 nblks(batch_size * (num_qo_heads + num_kv_heads)); - dim3 nthrs(bdx, bdy); - auto kernel = - BatchQKApplyRotaryKernel; - void *args[] = {(void *)&q, - (void *)&k, - (void *)&q_rope, - (void *)&k_rope, - (void *)&indptr, - (void *)&offsets, - (void *)&batch_size, - (void *)&num_qo_heads, - (void *)&num_kv_heads, - (void 
*)&rotary_dim, - (void *)&q_stride_n, - (void *)&q_stride_h, - (void *)&k_stride_n, - (void *)&k_stride_h, - (void *)&q_rope_stride_n, - (void *)&q_rope_stride_h, - (void *)&k_rope_stride_n, - (void *)&k_rope_stride_h, - (void *)&smooth_a, - (void *)&smooth_b, - (void *)&rope_rcp_scale, - (void *)&rope_rcp_theta}; - // FLASHINFER_CUDA_CALL(hipLaunchKernelGGL((void*)kernel, nblks, - // nthrs, args, 0, stream)); - kernel<<>>( - q, k, q_rope, k_rope, indptr, offsets, batch_size, num_qo_heads, - num_kv_heads, rotary_dim, q_stride_n, q_stride_h, k_stride_n, - k_stride_h, q_rope_stride_n, q_rope_stride_h, k_rope_stride_n, - k_rope_stride_h, smooth_a, smooth_b, rope_rcp_scale, - rope_rcp_theta); - }); +hipError_t BatchQKApplyRotary(DType* q, DType* k, DType* q_rope, DType* k_rope, + IdType* __restrict__ indptr, IdType* __restrict__ offsets, + uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, + uint32_t rotary_dim, uint32_t head_dim, size_t q_stride_n, + size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, + size_t q_rope_stride_n, size_t q_rope_stride_h, + size_t k_rope_stride_n, size_t k_rope_stride_h, bool interleave, + float rope_scale, float rope_theta, hipStream_t stream = nullptr) { + float rope_rcp_scale = 1.0f / rope_scale; + float rope_rcp_theta = 1.0f / rope_theta; + float smooth_a = 0.f; + float smooth_b = 0.f; + + DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32); + constexpr uint32_t bdx = HEAD_DIM / vec_size; + uint32_t num_threads = std::max(128U, bdx); + uint32_t bdy = num_threads / bdx; + dim3 nblks(batch_size * (num_qo_heads + num_kv_heads)); + dim3 nthrs(bdx, bdy); + auto kernel = BatchQKApplyRotaryKernel; + void* args[] = {(void*)&q, + (void*)&k, + (void*)&q_rope, + (void*)&k_rope, + (void*)&indptr, + (void*)&offsets, + (void*)&batch_size, + (void*)&num_qo_heads, + (void*)&num_kv_heads, + (void*)&rotary_dim, + 
(void*)&q_stride_n, + (void*)&q_stride_h, + (void*)&k_stride_n, + (void*)&k_stride_h, + (void*)&q_rope_stride_n, + (void*)&q_rope_stride_h, + (void*)&k_rope_stride_n, + (void*)&k_rope_stride_h, + (void*)&smooth_a, + (void*)&smooth_b, + (void*)&rope_rcp_scale, + (void*)&rope_rcp_theta}; + // FLASHINFER_CUDA_CALL(hipLaunchKernelGGL((void*)kernel, nblks, + // nthrs, args, 0, stream)); + kernel<<>>( + q, k, q_rope, k_rope, indptr, offsets, batch_size, num_qo_heads, num_kv_heads, rotary_dim, + q_stride_n, q_stride_h, k_stride_n, k_stride_h, q_rope_stride_n, q_rope_stride_h, + k_rope_stride_n, k_rope_stride_h, smooth_a, smooth_b, rope_rcp_scale, rope_rcp_theta); }); + }); - return hipSuccess; + return hipSuccess; } template -hipError_t BatchQKApplyRotaryInPlace(DType *__restrict__ q, - DType *__restrict__ k, - IdType *__restrict__ indptr, - IdType *__restrict__ offsets, - uint32_t batch_size, - uint32_t num_qo_heads, - uint32_t num_kv_heads, - uint32_t rotary_dim, - uint32_t head_dim, - size_t q_stride_n, - size_t q_stride_h, - size_t k_stride_n, - size_t k_stride_h, - bool interleave, - float rope_scale, - float rope_theta, - hipStream_t stream = nullptr) -{ - return BatchQKApplyRotary( - q, k, q, k, indptr, offsets, batch_size, num_qo_heads, num_kv_heads, - rotary_dim, head_dim, q_stride_n, q_stride_h, k_stride_n, k_stride_h, - q_stride_n, q_stride_h, k_stride_n, k_stride_h, interleave, rope_scale, - rope_theta, stream); +hipError_t BatchQKApplyRotaryInPlace(DType* __restrict__ q, DType* __restrict__ k, + IdType* __restrict__ indptr, IdType* __restrict__ offsets, + uint32_t batch_size, uint32_t num_qo_heads, + uint32_t num_kv_heads, uint32_t rotary_dim, uint32_t head_dim, + size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, + size_t k_stride_h, bool interleave, float rope_scale, + float rope_theta, hipStream_t stream = nullptr) { + return BatchQKApplyRotary( + q, k, q, k, indptr, offsets, batch_size, num_qo_heads, num_kv_heads, rotary_dim, head_dim, + q_stride_n, 
q_stride_h, k_stride_n, k_stride_h, q_stride_n, q_stride_h, k_stride_n, + k_stride_h, interleave, rope_scale, rope_theta, stream); } template -hipError_t BatchQKApplyLlama31Rotary(DType *q, - DType *k, - DType *q_rope, - DType *k_rope, - IdType *__restrict__ indptr, - IdType *__restrict__ offsets, - uint32_t batch_size, - uint32_t num_qo_heads, - uint32_t num_kv_heads, - uint32_t rotary_dim, - uint32_t head_dim, - size_t q_stride_n, - size_t q_stride_h, - size_t k_stride_n, - size_t k_stride_h, - size_t q_rope_stride_n, - size_t q_rope_stride_h, - size_t k_rope_stride_n, - size_t k_rope_stride_h, - bool interleave, - float rope_scale, - float rope_theta, - float low_freq_factor, - float high_freq_factor, - float old_context_length, - hipStream_t stream = nullptr) -{ - float rope_rcp_scale = 1.0f / rope_scale; - float rope_rcp_theta = 1.0f / rope_theta; - float smooth_a = old_context_length / - (2 * M_PI * high_freq_factor - 2 * M_PI * low_freq_factor); - float smooth_b = -1.0f / (high_freq_factor / low_freq_factor - 1.0f); - - DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { - DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { - constexpr uint32_t vec_size = - std::max(16 / sizeof(DType), HEAD_DIM / 32); - constexpr uint32_t bdx = HEAD_DIM / vec_size; - uint32_t num_threads = std::max(128U, bdx); - uint32_t bdy = num_threads / bdx; - dim3 nblks(batch_size * (num_qo_heads + num_kv_heads)); - dim3 nthrs(bdx, bdy); - auto kernel = - BatchQKApplyRotaryKernel; - void *args[] = {(void *)&q, - (void *)&k, - (void *)&q_rope, - (void *)&k_rope, - (void *)&indptr, - (void *)&offsets, - (void *)&batch_size, - (void *)&num_qo_heads, - (void *)&num_kv_heads, - (void *)&rotary_dim, - (void *)&q_stride_n, - (void *)&q_stride_h, - (void *)&k_stride_n, - (void *)&k_stride_h, - (void *)&q_rope_stride_n, - (void *)&q_rope_stride_h, - (void *)&k_rope_stride_n, - (void *)&k_rope_stride_h, - (void *)&smooth_a, - (void *)&smooth_b, - (void *)&rope_rcp_scale, - (void *)&rope_rcp_theta}; - // 
FLASHINFER_CUDA_CALL(hipLaunchKernelGGL((void*)kernel, nblks, - // nthrs, args, 0, stream)); - kernel<<>>( - q, k, q_rope, k_rope, indptr, offsets, batch_size, num_qo_heads, - num_kv_heads, rotary_dim, q_stride_n, q_stride_h, k_stride_n, - k_stride_h, q_rope_stride_n, q_rope_stride_h, k_rope_stride_n, - k_rope_stride_h, smooth_a, smooth_b, rope_rcp_scale, - rope_rcp_theta); - }); +hipError_t BatchQKApplyLlama31Rotary( + DType* q, DType* k, DType* q_rope, DType* k_rope, IdType* __restrict__ indptr, + IdType* __restrict__ offsets, uint32_t batch_size, uint32_t num_qo_heads, uint32_t num_kv_heads, + uint32_t rotary_dim, uint32_t head_dim, size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, + size_t k_stride_h, size_t q_rope_stride_n, size_t q_rope_stride_h, size_t k_rope_stride_n, + size_t k_rope_stride_h, bool interleave, float rope_scale, float rope_theta, + float low_freq_factor, float high_freq_factor, float old_context_length, + hipStream_t stream = nullptr) { + float rope_rcp_scale = 1.0f / rope_scale; + float rope_rcp_theta = 1.0f / rope_theta; + float smooth_a = old_context_length / (2 * M_PI * high_freq_factor - 2 * M_PI * low_freq_factor); + float smooth_b = -1.0f / (high_freq_factor / low_freq_factor - 1.0f); + + DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32); + constexpr uint32_t bdx = HEAD_DIM / vec_size; + uint32_t num_threads = std::max(128U, bdx); + uint32_t bdy = num_threads / bdx; + dim3 nblks(batch_size * (num_qo_heads + num_kv_heads)); + dim3 nthrs(bdx, bdy); + auto kernel = BatchQKApplyRotaryKernel; + void* args[] = {(void*)&q, + (void*)&k, + (void*)&q_rope, + (void*)&k_rope, + (void*)&indptr, + (void*)&offsets, + (void*)&batch_size, + (void*)&num_qo_heads, + (void*)&num_kv_heads, + (void*)&rotary_dim, + (void*)&q_stride_n, + (void*)&q_stride_h, + (void*)&k_stride_n, + (void*)&k_stride_h, + (void*)&q_rope_stride_n, + 
(void*)&q_rope_stride_h, + (void*)&k_rope_stride_n, + (void*)&k_rope_stride_h, + (void*)&smooth_a, + (void*)&smooth_b, + (void*)&rope_rcp_scale, + (void*)&rope_rcp_theta}; + // FLASHINFER_CUDA_CALL(hipLaunchKernelGGL((void*)kernel, nblks, + // nthrs, args, 0, stream)); + kernel<<>>( + q, k, q_rope, k_rope, indptr, offsets, batch_size, num_qo_heads, num_kv_heads, rotary_dim, + q_stride_n, q_stride_h, k_stride_n, k_stride_h, q_rope_stride_n, q_rope_stride_h, + k_rope_stride_n, k_rope_stride_h, smooth_a, smooth_b, rope_rcp_scale, rope_rcp_theta); }); + }); - return hipSuccess; + return hipSuccess; } template -hipError_t BatchQKApplyLlama31RotaryPosIds(DType *q, - DType *k, - DType *q_rope, - DType *k_rope, - IdType *pos_ids, - uint32_t nnz, - uint32_t num_qo_heads, - uint32_t num_kv_heads, - uint32_t rotary_dim, - uint32_t head_dim, - size_t q_stride_n, - size_t q_stride_h, - size_t k_stride_n, - size_t k_stride_h, - size_t q_rope_stride_n, - size_t q_rope_stride_h, - size_t k_rope_stride_n, - size_t k_rope_stride_h, - bool interleave, - float rope_scale, - float rope_theta, - float low_freq_factor, - float high_freq_factor, - float old_context_length, - hipStream_t stream = nullptr) -{ - float rope_rcp_scale = 1.0f / rope_scale; - float rope_rcp_theta = 1.0f / rope_theta; - float smooth_a = old_context_length / - (2 * M_PI * high_freq_factor - 2 * M_PI * low_freq_factor); - float smooth_b = -1.0f / (high_freq_factor / low_freq_factor - 1.0f); - - DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { - DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { - constexpr uint32_t vec_size = - std::max(16 / sizeof(DType), HEAD_DIM / 32); - constexpr uint32_t bdx = HEAD_DIM / vec_size; - uint32_t num_threads = std::max(128U, bdx); - uint32_t bdy = num_threads / bdx; - dim3 nblks((nnz + bdy - 1) / bdy); - dim3 nthrs(bdx, bdy); - auto kernel = - BatchQKApplyRotaryPosIdsKernel; - void *args[] = {(void *)&q, - (void *)&k, - (void *)&q_rope, - (void *)&k_rope, - (void *)&pos_ids, - (void *)&nnz, - 
(void *)&num_qo_heads, - (void *)&num_kv_heads, - (void *)&rotary_dim, - (void *)&q_stride_n, - (void *)&q_stride_h, - (void *)&k_stride_n, - (void *)&k_stride_h, - (void *)&q_rope_stride_n, - (void *)&q_rope_stride_h, - (void *)&k_rope_stride_n, - (void *)&k_rope_stride_h, - (void *)&smooth_a, - (void *)&smooth_b, - (void *)&rope_rcp_scale, - (void *)&rope_rcp_theta}; - // FLASHINFER_CUDA_CALL(hipLaunchKernelGGL((void*)kernel, nblks, - // nthrs, args, 0, stream)); - kernel<<>>( - q, k, q_rope, k_rope, pos_ids, nnz, num_qo_heads, num_kv_heads, - rotary_dim, q_stride_n, q_stride_h, k_stride_n, k_stride_h, - q_rope_stride_n, q_rope_stride_h, k_rope_stride_n, - k_rope_stride_h, smooth_a, smooth_b, rope_rcp_scale, - rope_rcp_theta); - }); +hipError_t BatchQKApplyLlama31RotaryPosIds( + DType* q, DType* k, DType* q_rope, DType* k_rope, IdType* pos_ids, uint32_t nnz, + uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t rotary_dim, uint32_t head_dim, + size_t q_stride_n, size_t q_stride_h, size_t k_stride_n, size_t k_stride_h, + size_t q_rope_stride_n, size_t q_rope_stride_h, size_t k_rope_stride_n, size_t k_rope_stride_h, + bool interleave, float rope_scale, float rope_theta, float low_freq_factor, + float high_freq_factor, float old_context_length, hipStream_t stream = nullptr) { + float rope_rcp_scale = 1.0f / rope_scale; + float rope_rcp_theta = 1.0f / rope_theta; + float smooth_a = old_context_length / (2 * M_PI * high_freq_factor - 2 * M_PI * low_freq_factor); + float smooth_b = -1.0f / (high_freq_factor / low_freq_factor - 1.0f); + + DISPATCH_INTERLEAVE(interleave, INTERLEAVE, { + DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, { + constexpr uint32_t vec_size = std::max(16 / sizeof(DType), HEAD_DIM / 32); + constexpr uint32_t bdx = HEAD_DIM / vec_size; + uint32_t num_threads = std::max(128U, bdx); + uint32_t bdy = num_threads / bdx; + dim3 nblks((nnz + bdy - 1) / bdy); + dim3 nthrs(bdx, bdy); + auto kernel = + BatchQKApplyRotaryPosIdsKernel; + void* args[] = {(void*)&q, + 
(void*)&k, + (void*)&q_rope, + (void*)&k_rope, + (void*)&pos_ids, + (void*)&nnz, + (void*)&num_qo_heads, + (void*)&num_kv_heads, + (void*)&rotary_dim, + (void*)&q_stride_n, + (void*)&q_stride_h, + (void*)&k_stride_n, + (void*)&k_stride_h, + (void*)&q_rope_stride_n, + (void*)&q_rope_stride_h, + (void*)&k_rope_stride_n, + (void*)&k_rope_stride_h, + (void*)&smooth_a, + (void*)&smooth_b, + (void*)&rope_rcp_scale, + (void*)&rope_rcp_theta}; + // FLASHINFER_CUDA_CALL(hipLaunchKernelGGL((void*)kernel, nblks, + // nthrs, args, 0, stream)); + kernel<<>>( + q, k, q_rope, k_rope, pos_ids, nnz, num_qo_heads, num_kv_heads, rotary_dim, q_stride_n, + q_stride_h, k_stride_n, k_stride_h, q_rope_stride_n, q_rope_stride_h, k_rope_stride_n, + k_rope_stride_h, smooth_a, smooth_b, rope_rcp_scale, rope_rcp_theta); }); + }); - return hipSuccess; + return hipSuccess; } -} // namespace flashinfer +} // namespace flashinfer -#endif // FLASHINFER_POS_ENC_CUH_ +#endif // FLASHINFER_POS_ENC_CUH_ diff --git a/libflashinfer/include/flashinfer/hip/utils.hip.h b/libflashinfer/include/flashinfer/hip/utils.hip.h index 392dfa1a95..0e68bf974a 100644 --- a/libflashinfer/include/flashinfer/hip/utils.hip.h +++ b/libflashinfer/include/flashinfer/hip/utils.hip.h @@ -7,15 +7,14 @@ #ifndef FLASHINFER_UTILS_CUH_ #define FLASHINFER_UTILS_CUH_ -#include "../exception.h" - -#include "hip_platform.h" - #include #include #include #include +#include "../exception.h" +#include "hip_platform.h" + #define STR_HELPER(x) #x #define STR(x) STR_HELPER(x) @@ -25,406 +24,335 @@ #endif #ifndef NDEBUG -#define FLASHINFER_CUDA_CALL(func, ...) \ - { \ - hipError_t e = (func); \ - if (e != hipSuccess) { \ - std::cerr << "HIP Error: " << hipGetErrorString(e) << " (" << e \ - << ") " << __FILE__ << ": line " << __LINE__ \ - << " at function " << STR(func) << std::endl; \ - return e; \ - } \ - } +#define FLASHINFER_CUDA_CALL(func, ...) 
\ + { \ + hipError_t e = (func); \ + if (e != hipSuccess) { \ + std::cerr << "HIP Error: " << hipGetErrorString(e) << " (" << e << ") " << __FILE__ \ + << ": line " << __LINE__ << " at function " << STR(func) << std::endl; \ + return e; \ + } \ + } #else -#define FLASHINFER_CUDA_CALL(func, ...) \ - { \ - hipError_t e = (func); \ - if (e != hipSuccess) { \ - return e; \ - } \ - } +#define FLASHINFER_CUDA_CALL(func, ...) \ + { \ + hipError_t e = (func); \ + if (e != hipSuccess) { \ + return e; \ + } \ + } #endif -#define CHECK_HIP_ERROR(call) \ - { \ - hipError_t err = call; \ - if (err != hipSuccess) { \ - std::cerr << "HIP error at " << __FILE__ << " : " << __LINE__ \ - << " -> " << hipGetErrorString(err) << std::endl; \ - exit(1); \ - } \ - } +#define CHECK_HIP_ERROR(call) \ + { \ + hipError_t err = call; \ + if (err != hipSuccess) { \ + std::cerr << "HIP error at " << __FILE__ << " : " << __LINE__ << " -> " \ + << hipGetErrorString(err) << std::endl; \ + exit(1); \ + } \ + } -#define DISPATCH_USE_FP16_QK_REDUCTION(use_fp16_qk_reduction, \ - USE_FP16_QK_REDUCTION, ...) \ - if (use_fp16_qk_reduction) { \ - FLASHINFER_ERROR("FP16_QK_REDUCTION disabled at compile time"); \ - } \ - else { \ - constexpr bool USE_FP16_QK_REDUCTION = false; \ - __VA_ARGS__ \ - } +#define DISPATCH_USE_FP16_QK_REDUCTION(use_fp16_qk_reduction, USE_FP16_QK_REDUCTION, ...) \ + if (use_fp16_qk_reduction) { \ + FLASHINFER_ERROR("FP16_QK_REDUCTION disabled at compile time"); \ + } else { \ + constexpr bool USE_FP16_QK_REDUCTION = false; \ + __VA_ARGS__ \ + } -#define DISPATCH_NUM_MMA_Q(num_mma_q, NUM_MMA_Q, ...) \ - if (num_mma_q == 1) { \ - constexpr size_t NUM_MMA_Q = 1; \ - __VA_ARGS__ \ - } \ - else if (num_mma_q == 2) { \ - constexpr size_t NUM_MMA_Q = 2; \ - __VA_ARGS__ \ - } \ - else { \ - std::ostringstream err_msg; \ - err_msg << "Unsupported num_mma_q: " << num_mma_q; \ - FLASHINFER_ERROR(err_msg.str()); \ - } +#define DISPATCH_NUM_MMA_Q(num_mma_q, NUM_MMA_Q, ...) 
\ + if (num_mma_q == 1) { \ + constexpr size_t NUM_MMA_Q = 1; \ + __VA_ARGS__ \ + } else if (num_mma_q == 2) { \ + constexpr size_t NUM_MMA_Q = 2; \ + __VA_ARGS__ \ + } else { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported num_mma_q: " << num_mma_q; \ + FLASHINFER_ERROR(err_msg.str()); \ + } -#define DISPATCH_NUM_MMA_KV(max_mma_kv, NUM_MMA_KV, ...) \ - if (max_mma_kv >= 8) { \ - constexpr size_t NUM_MMA_KV = 8; \ - __VA_ARGS__ \ - } \ - else if (max_mma_kv >= 4) { \ - constexpr size_t NUM_MMA_KV = 4; \ - __VA_ARGS__ \ - } \ - else if (max_mma_kv >= 2) { \ - constexpr size_t NUM_MMA_KV = 2; \ - __VA_ARGS__ \ - } \ - else if (max_mma_kv >= 1) { \ - constexpr size_t NUM_MMA_KV = 1; \ - __VA_ARGS__ \ - } \ - else { \ - std::ostringstream err_msg; \ - err_msg << "Unsupported max_mma_kv: " << max_mma_kv; \ - FLASHINFER_ERROR(err_msg.str()); \ - } +#define DISPATCH_NUM_MMA_KV(max_mma_kv, NUM_MMA_KV, ...) \ + if (max_mma_kv >= 8) { \ + constexpr size_t NUM_MMA_KV = 8; \ + __VA_ARGS__ \ + } else if (max_mma_kv >= 4) { \ + constexpr size_t NUM_MMA_KV = 4; \ + __VA_ARGS__ \ + } else if (max_mma_kv >= 2) { \ + constexpr size_t NUM_MMA_KV = 2; \ + __VA_ARGS__ \ + } else if (max_mma_kv >= 1) { \ + constexpr size_t NUM_MMA_KV = 1; \ + __VA_ARGS__ \ + } else { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported max_mma_kv: " << max_mma_kv; \ + FLASHINFER_ERROR(err_msg.str()); \ + } -#define DISPATCH_CTA_TILE_Q(cta_tile_q, CTA_TILE_Q, ...) \ - switch (cta_tile_q) { \ - case 128: \ - { \ - constexpr uint32_t CTA_TILE_Q = 128; \ - __VA_ARGS__ \ - break; \ - } \ - case 64: \ - { \ - constexpr uint32_t CTA_TILE_Q = 64; \ - __VA_ARGS__ \ - break; \ - } \ - case 16: \ - { \ - constexpr uint32_t CTA_TILE_Q = 16; \ - __VA_ARGS__ \ - break; \ - } \ - default: \ - { \ - std::ostringstream err_msg; \ - err_msg << "Unsupported cta_tile_q: " << cta_tile_q; \ - FLASHINFER_ERROR(err_msg.str()); \ - } \ - } +#define DISPATCH_CTA_TILE_Q(cta_tile_q, CTA_TILE_Q, ...) 
\ + switch (cta_tile_q) { \ + case 128: { \ + constexpr uint32_t CTA_TILE_Q = 128; \ + __VA_ARGS__ \ + break; \ + } \ + case 64: { \ + constexpr uint32_t CTA_TILE_Q = 64; \ + __VA_ARGS__ \ + break; \ + } \ + case 16: { \ + constexpr uint32_t CTA_TILE_Q = 16; \ + __VA_ARGS__ \ + break; \ + } \ + default: { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported cta_tile_q: " << cta_tile_q; \ + FLASHINFER_ERROR(err_msg.str()); \ + } \ + } -#define DISPATCH_GQA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \ - if (group_size == 1) { \ - constexpr size_t GROUP_SIZE = 1; \ - __VA_ARGS__ \ - } \ - else if (group_size == 2) { \ - constexpr size_t GROUP_SIZE = 2; \ - __VA_ARGS__ \ - } \ - else if (group_size == 3) { \ - constexpr size_t GROUP_SIZE = 3; \ - __VA_ARGS__ \ - } \ - else if (group_size == 4) { \ - constexpr size_t GROUP_SIZE = 4; \ - __VA_ARGS__ \ - } \ - else if (group_size == 8) { \ - constexpr size_t GROUP_SIZE = 8; \ - __VA_ARGS__ \ - } \ - else { \ - std::ostringstream err_msg; \ - err_msg << "Unsupported group_size: " << group_size; \ - FLASHINFER_ERROR(err_msg.str()); \ - } +#define DISPATCH_GQA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \ + if (group_size == 1) { \ + constexpr size_t GROUP_SIZE = 1; \ + __VA_ARGS__ \ + } else if (group_size == 2) { \ + constexpr size_t GROUP_SIZE = 2; \ + __VA_ARGS__ \ + } else if (group_size == 3) { \ + constexpr size_t GROUP_SIZE = 3; \ + __VA_ARGS__ \ + } else if (group_size == 4) { \ + constexpr size_t GROUP_SIZE = 4; \ + __VA_ARGS__ \ + } else if (group_size == 8) { \ + constexpr size_t GROUP_SIZE = 8; \ + __VA_ARGS__ \ + } else { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported group_size: " << group_size; \ + FLASHINFER_ERROR(err_msg.str()); \ + } -#define DISPATCH_MASK_MODE(mask_mode, MASK_MODE, ...) 
\ - switch (mask_mode) { \ - case MaskMode::kNone: \ - { \ - constexpr MaskMode MASK_MODE = MaskMode::kNone; \ - __VA_ARGS__ \ - break; \ - } \ - case MaskMode::kCausal: \ - { \ - constexpr MaskMode MASK_MODE = MaskMode::kCausal; \ - __VA_ARGS__ \ - break; \ - } \ - case MaskMode::kCustom: \ - { \ - constexpr MaskMode MASK_MODE = MaskMode::kCustom; \ - __VA_ARGS__ \ - break; \ - } \ - default: \ - { \ - std::ostringstream err_msg; \ - err_msg << "Unsupported mask_mode: " << int(mask_mode); \ - FLASHINFER_ERROR(err_msg.str()); \ - } \ - } +#define DISPATCH_MASK_MODE(mask_mode, MASK_MODE, ...) \ + switch (mask_mode) { \ + case MaskMode::kNone: { \ + constexpr MaskMode MASK_MODE = MaskMode::kNone; \ + __VA_ARGS__ \ + break; \ + } \ + case MaskMode::kCausal: { \ + constexpr MaskMode MASK_MODE = MaskMode::kCausal; \ + __VA_ARGS__ \ + break; \ + } \ + case MaskMode::kCustom: { \ + constexpr MaskMode MASK_MODE = MaskMode::kCustom; \ + __VA_ARGS__ \ + break; \ + } \ + default: { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported mask_mode: " << int(mask_mode); \ + FLASHINFER_ERROR(err_msg.str()); \ + } \ + } // convert head_dim to compile-time constant -#define DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...) \ - switch (head_dim) { \ - case 64: \ - { \ - constexpr size_t HEAD_DIM = 64; \ - __VA_ARGS__ \ - break; \ - } \ - case 128: \ - { \ - constexpr size_t HEAD_DIM = 128; \ - __VA_ARGS__ \ - break; \ - } \ - case 256: \ - { \ - constexpr size_t HEAD_DIM = 256; \ - __VA_ARGS__ \ - break; \ - } \ - case 512: \ - { \ - constexpr size_t HEAD_DIM = 512; \ - __VA_ARGS__ \ - break; \ - } \ - default: \ - { \ - std::ostringstream err_msg; \ - err_msg << "Unsupported head_dim: " << head_dim; \ - FLASHINFER_ERROR(err_msg.str()); \ - } \ - } +#define DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...) 
\ + switch (head_dim) { \ + case 64: { \ + constexpr size_t HEAD_DIM = 64; \ + __VA_ARGS__ \ + break; \ + } \ + case 128: { \ + constexpr size_t HEAD_DIM = 128; \ + __VA_ARGS__ \ + break; \ + } \ + case 256: { \ + constexpr size_t HEAD_DIM = 256; \ + __VA_ARGS__ \ + break; \ + } \ + case 512: { \ + constexpr size_t HEAD_DIM = 512; \ + __VA_ARGS__ \ + break; \ + } \ + default: { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported head_dim: " << head_dim; \ + FLASHINFER_ERROR(err_msg.str()); \ + } \ + } -#define DISPATCH_POS_ENCODING_MODE(pos_encoding_mode, POS_ENCODING_MODE, ...) \ - switch (pos_encoding_mode) { \ - case PosEncodingMode::kNone: \ - { \ - constexpr PosEncodingMode POS_ENCODING_MODE = PosEncodingMode::kNone; \ - __VA_ARGS__ \ - break; \ - } \ - case PosEncodingMode::kRoPELlama: \ - { \ - constexpr PosEncodingMode POS_ENCODING_MODE = \ - PosEncodingMode::kRoPELlama; \ - __VA_ARGS__ \ - break; \ - } \ - case PosEncodingMode::kALiBi: \ - { \ - constexpr PosEncodingMode POS_ENCODING_MODE = PosEncodingMode::kALiBi; \ - __VA_ARGS__ \ - break; \ - } \ - default: \ - { \ - std::ostringstream err_msg; \ - err_msg << "Unsupported pos_encoding_mode: " \ - << int(pos_encoding_mode); \ - FLASHINFER_ERROR(err_msg.str()); \ - } \ - } +#define DISPATCH_POS_ENCODING_MODE(pos_encoding_mode, POS_ENCODING_MODE, ...) 
\ + switch (pos_encoding_mode) { \ + case PosEncodingMode::kNone: { \ + constexpr PosEncodingMode POS_ENCODING_MODE = PosEncodingMode::kNone; \ + __VA_ARGS__ \ + break; \ + } \ + case PosEncodingMode::kRoPELlama: { \ + constexpr PosEncodingMode POS_ENCODING_MODE = PosEncodingMode::kRoPELlama; \ + __VA_ARGS__ \ + break; \ + } \ + case PosEncodingMode::kALiBi: { \ + constexpr PosEncodingMode POS_ENCODING_MODE = PosEncodingMode::kALiBi; \ + __VA_ARGS__ \ + break; \ + } \ + default: { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported pos_encoding_mode: " << int(pos_encoding_mode); \ + FLASHINFER_ERROR(err_msg.str()); \ + } \ + } -#define DISPATCH_ALIGNED_VEC_SIZE(aligned_vec_size, ALIGNED_VEC_SIZE, ...) \ - switch (aligned_vec_size) { \ - case 16: \ - { \ - constexpr size_t ALIGNED_VEC_SIZE = 16; \ - __VA_ARGS__ \ - break; \ - } \ - case 8: \ - { \ - constexpr size_t ALIGNED_VEC_SIZE = 8; \ - __VA_ARGS__ \ - break; \ - } \ - case 4: \ - { \ - constexpr size_t ALIGNED_VEC_SIZE = 4; \ - __VA_ARGS__ \ - break; \ - } \ - case 2: \ - { \ - constexpr size_t ALIGNED_VEC_SIZE = 2; \ - __VA_ARGS__ \ - break; \ - } \ - case 1: \ - { \ - constexpr size_t ALIGNED_VEC_SIZE = 1; \ - __VA_ARGS__ \ - break; \ - } \ - default: \ - { \ - std::ostringstream err_msg; \ - err_msg << "Unsupported aligned_vec_size: " << aligned_vec_size; \ - FLASHINFER_ERROR(err_msg.str()); \ - } \ - } +#define DISPATCH_ALIGNED_VEC_SIZE(aligned_vec_size, ALIGNED_VEC_SIZE, ...) 
\ + switch (aligned_vec_size) { \ + case 16: { \ + constexpr size_t ALIGNED_VEC_SIZE = 16; \ + __VA_ARGS__ \ + break; \ + } \ + case 8: { \ + constexpr size_t ALIGNED_VEC_SIZE = 8; \ + __VA_ARGS__ \ + break; \ + } \ + case 4: { \ + constexpr size_t ALIGNED_VEC_SIZE = 4; \ + __VA_ARGS__ \ + break; \ + } \ + case 2: { \ + constexpr size_t ALIGNED_VEC_SIZE = 2; \ + __VA_ARGS__ \ + break; \ + } \ + case 1: { \ + constexpr size_t ALIGNED_VEC_SIZE = 1; \ + __VA_ARGS__ \ + break; \ + } \ + default: { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported aligned_vec_size: " << aligned_vec_size; \ + FLASHINFER_ERROR(err_msg.str()); \ + } \ + } -#define DISPATCH_COMPUTE_CAP_DECODE_NUM_STAGES_SMEM(compute_capacity, \ - NUM_STAGES_SMEM, ...) \ - if (compute_capacity.first >= 8) { \ - constexpr uint32_t NUM_STAGES_SMEM = 2; \ - __VA_ARGS__ \ - } \ - else { \ - constexpr uint32_t NUM_STAGES_SMEM = 1; \ - __VA_ARGS__ \ - } +#define DISPATCH_COMPUTE_CAP_DECODE_NUM_STAGES_SMEM(compute_capacity, NUM_STAGES_SMEM, ...) 
\ + if (compute_capacity.first >= 8) { \ + constexpr uint32_t NUM_STAGES_SMEM = 2; \ + __VA_ARGS__ \ + } else { \ + constexpr uint32_t NUM_STAGES_SMEM = 1; \ + __VA_ARGS__ \ + } -namespace flashinfer -{ +namespace flashinfer { template -__forceinline__ __device__ __host__ T1 ceil_div(const T1 x, const T2 y) -{ - T2 y2 = y; - // if(y2 == 0){ - // y2 = 1; - // } - return (x + y2 - 1) / y2; +__forceinline__ __device__ __host__ T1 ceil_div(const T1 x, const T2 y) { + T2 y2 = y; + // if(y2 == 0){ + // y2 = 1; + // } + return (x + y2 - 1) / y2; } -inline std::pair GetCudaComputeCapability() -{ - int device_id = 0; - FI_GPU_CALL(hipGetDevice(&device_id)); - int major = 0, minor = 0; - FI_GPU_CALL(hipDeviceGetAttribute( - &major, hipDeviceAttributeComputeCapabilityMajor, device_id)); - FI_GPU_CALL(hipDeviceGetAttribute( - &minor, hipDeviceAttributeComputeCapabilityMinor, device_id)); - return std::make_pair(major, minor); +inline std::pair GetCudaComputeCapability() { + int device_id = 0; + FI_GPU_CALL(hipGetDevice(&device_id)); + int major = 0, minor = 0; + FI_GPU_CALL(hipDeviceGetAttribute(&major, hipDeviceAttributeComputeCapabilityMajor, device_id)); + FI_GPU_CALL(hipDeviceGetAttribute(&minor, hipDeviceAttributeComputeCapabilityMinor, device_id)); + return std::make_pair(major, minor); } template -inline void -DebugPrintCUDAArray(T *device_ptr, size_t size, std::string prefix = "") -{ - std::vector host_array(size); - std::cout << prefix; - hipMemcpy(host_array.data(), device_ptr, size * sizeof(T), - hipMemcpyDeviceToHost); - for (size_t i = 0; i < size; ++i) { - std::cout << host_array[i] << " "; - } - std::cout << std::endl; +inline void DebugPrintCUDAArray(T* device_ptr, size_t size, std::string prefix = "") { + std::vector host_array(size); + std::cout << prefix; + hipMemcpy(host_array.data(), device_ptr, size * sizeof(T), hipMemcpyDeviceToHost); + for (size_t i = 0; i < size; ++i) { + std::cout << host_array[i] << " "; + } + std::cout << std::endl; } -inline 
uint32_t FA2DetermineCtaTileQ(int64_t avg_packed_qo_len, - uint32_t head_dim) -{ - if (avg_packed_qo_len > 64 && head_dim < 256) { - return 128; - } - else { - auto compute_capacity = GetCudaComputeCapability(); - if (compute_capacity.first >= 8) { - // Ampere or newer - if (avg_packed_qo_len > 16) { - // avg_packed_qo_len <= 64 - return 64; - } - else { - // avg_packed_qo_len <= 16 - return 16; - } - } - else { - // NOTE(Zihao): not enough shared memory on Turing for 1x4 warp - // layout - return 64; - } +inline uint32_t FA2DetermineCtaTileQ(int64_t avg_packed_qo_len, uint32_t head_dim) { + if (avg_packed_qo_len > 64 && head_dim < 256) { + return 128; + } else { + auto compute_capacity = GetCudaComputeCapability(); + if (compute_capacity.first >= 8) { + // Ampere or newer + if (avg_packed_qo_len > 16) { + // avg_packed_qo_len <= 64 + return 64; + } else { + // avg_packed_qo_len <= 16 + return 16; + } + } else { + // NOTE(Zihao): not enough shared memory on Turing for 1x4 warp + // layout + return 64; } + } } /// @brief Perform Simple Subtraction /// @param x Input param X /// @param y Input param y /// @return Returns x - y if x > y; else 0; -__device__ __forceinline__ uint32_t sub_if_greater_or_zero(uint32_t x, - uint32_t y) -{ - return (x > y) ? x - y : 0U; +__device__ __forceinline__ uint32_t sub_if_greater_or_zero(uint32_t x, uint32_t y) { + return (x > y) ? 
x - y : 0U; } -__device__ __forceinline__ void swap(uint32_t &a, uint32_t &b) -{ - uint32_t tmp = a; - a = b; - b = tmp; +__device__ __forceinline__ void swap(uint32_t& a, uint32_t& b) { + uint32_t tmp = a; + a = b; + b = tmp; } -__device__ __forceinline__ uint32_t dim2_offset(const uint32_t &dim_a, - const uint32_t &idx_b, - const uint32_t &idx_a) -{ - return idx_b * dim_a + idx_a; +__device__ __forceinline__ uint32_t dim2_offset(const uint32_t& dim_a, const uint32_t& idx_b, + const uint32_t& idx_a) { + return idx_b * dim_a + idx_a; } -__device__ __forceinline__ uint32_t dim3_offset(const uint32_t &dim_b, - const uint32_t &dim_a, - const uint32_t &idx_c, - const uint32_t &idx_b, - const uint32_t &idx_a) -{ - return (idx_c * dim_b + idx_b) * dim_a + idx_a; +__device__ __forceinline__ uint32_t dim3_offset(const uint32_t& dim_b, const uint32_t& dim_a, + const uint32_t& idx_c, const uint32_t& idx_b, + const uint32_t& idx_a) { + return (idx_c * dim_b + idx_b) * dim_a + idx_a; } -__device__ __forceinline__ uint32_t dim4_offset(const uint32_t &dim_c, - const uint32_t &dim_b, - const uint32_t &dim_a, - const uint32_t &idx_d, - const uint32_t &idx_c, - const uint32_t &idx_b, - const uint32_t &idx_a) -{ - return ((idx_d * dim_c + idx_c) * dim_b + idx_b) * dim_a + idx_a; +__device__ __forceinline__ uint32_t dim4_offset(const uint32_t& dim_c, const uint32_t& dim_b, + const uint32_t& dim_a, const uint32_t& idx_d, + const uint32_t& idx_c, const uint32_t& idx_b, + const uint32_t& idx_a) { + return ((idx_d * dim_c + idx_c) * dim_b + idx_b) * dim_a + idx_a; } -#define DEFINE_HAS_MEMBER(member) \ - template \ - struct has_##member : std::false_type \ - { \ - }; \ - template \ - struct has_##member().member)>> \ - : std::true_type \ - { \ - }; \ - template \ - inline constexpr bool has_##member##_v = has_##member::value; +#define DEFINE_HAS_MEMBER(member) \ + template \ + struct has_##member : std::false_type {}; \ + template \ + struct has_##member().member)>> : std::true_type {}; 
\ + template \ + inline constexpr bool has_##member##_v = has_##member::value; -} // namespace flashinfer +} // namespace flashinfer -#endif // FLASHINFER_UTILS_CUH_ +#endif // FLASHINFER_UTILS_CUH_ diff --git a/libflashinfer/include/flashinfer/hip/vec_dtypes.hip.h b/libflashinfer/include/flashinfer/hip/vec_dtypes.hip.h index 3bc9aa0393..41fb76baf4 100644 --- a/libflashinfer/include/flashinfer/hip/vec_dtypes.hip.h +++ b/libflashinfer/include/flashinfer/hip/vec_dtypes.hip.h @@ -20,155 +20,132 @@ #define FLASHINFER_INLINE inline __attribute__((always_inline)) __device__ -__host__ __device__ inline __hip_bfloat162 __float2bfloat162_rn(const float a) -{ - return __hip_bfloat162{__float2bfloat16(a), __float2bfloat16(a)}; +__host__ __device__ inline __hip_bfloat162 __float2bfloat162_rn(const float a) { + return __hip_bfloat162{__float2bfloat16(a), __float2bfloat16(a)}; } -FLASHINFER_INLINE __hip_bfloat162 make_bfloat162(const __hip_bfloat16 x, - const __hip_bfloat16 y) -{ - __hip_bfloat162 t; - t.x = x; - t.y = y; - return t; +FLASHINFER_INLINE __hip_bfloat162 make_bfloat162(const __hip_bfloat16 x, const __hip_bfloat16 y) { + __hip_bfloat162 t; + t.x = x; + t.y = y; + return t; } -namespace flashinfer -{ +namespace flashinfer { #define FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED #define FLASHINFER_INLINE inline __attribute__((always_inline)) __device__ -#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 < 120400) && \ +#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 < 120400) && \ (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)) // CUDA version < 12.4 and GPU architecture < 80 -FLASHINFER_INLINE __hip_bfloat16 __hmul(const __hip_bfloat16 a, - const __hip_bfloat16 b) -{ - __hip_bfloat16 val; - const float fa = __bfloat162float(a); - const float fb = __bfloat162float(b); - // avoid ftz in device code - val = __float2bfloat16(__fmaf_ieee_rn(fa, fb, -0.0f)); - return val; +FLASHINFER_INLINE __hip_bfloat16 __hmul(const __hip_bfloat16 a, const 
__hip_bfloat16 b) { + __hip_bfloat16 val; + const float fa = __bfloat162float(a); + const float fb = __bfloat162float(b); + // avoid ftz in device code + val = __float2bfloat16(__fmaf_ieee_rn(fa, fb, -0.0f)); + return val; } -FLASHINFER_INLINE __hip_bfloat162 __hmul2(const __hip_bfloat162 a, - const __hip_bfloat162 b) -{ - __hip_bfloat162 val; - val.x = __hmul(a.x, b.x); - val.y = __hmul(a.y, b.y); - return val; +FLASHINFER_INLINE __hip_bfloat162 __hmul2(const __hip_bfloat162 a, const __hip_bfloat162 b) { + __hip_bfloat162 val; + val.x = __hmul(a.x, b.x); + val.y = __hmul(a.y, b.y); + return val; } -FLASHINFER_INLINE __hip_bfloat162 __floats2bfloat162_rn(const float a, - const float b) -{ - __hip_bfloat162 val; - val = __hip_bfloat162(__float2bfloat16(a), __float2bfloat16(b)); - return val; +FLASHINFER_INLINE __hip_bfloat162 __floats2bfloat162_rn(const float a, const float b) { + __hip_bfloat162 val; + val = __hip_bfloat162(__float2bfloat16(a), __float2bfloat16(b)); + return val; } -FLASHINFER_INLINE __hip_bfloat162 __float22bfloat162_rn(const float2 a) -{ - __hip_bfloat162 val = __float22bfloat162_rn(a.x, a.y); - return val; +FLASHINFER_INLINE __hip_bfloat162 __float22bfloat162_rn(const float2 a) { + __hip_bfloat162 val = __float22bfloat162_rn(a.x, a.y); + return val; } -FLASHINFER_INLINE float2 __bfloat1622float2(const __hip_bfloat162 a) -{ - float hi_float; - float lo_float; - // lo_float = __internal_bfloat162float(((__gpu_bfloat162_raw)a).x); - // hi_float = __internal_bfloat162float(((__gpu_bfloat162_raw)a).y); - lo_float = __bfloat1622float2(a.x); - hi_float = __bfloat1622float2(a.y); - return make_float2(lo_float, hi_float); +FLASHINFER_INLINE float2 __bfloat1622float2(const __hip_bfloat162 a) { + float hi_float; + float lo_float; + // lo_float = __internal_bfloat162float(((__gpu_bfloat162_raw)a).x); + // hi_float = __internal_bfloat162float(((__gpu_bfloat162_raw)a).y); + lo_float = __bfloat1622float2(a.x); + hi_float = __bfloat1622float2(a.y); + return 
make_float2(lo_float, hi_float); } #endif /******************* vec_t type cast *******************/ -template struct vec_cast -{ - template - FLASHINFER_INLINE static void cast(dst_t *dst, const src_t *src) - { +template +struct vec_cast { + template + FLASHINFER_INLINE static void cast(dst_t* dst, const src_t* src) { #pragma unroll - for (size_t i = 0; i < vec_size; ++i) { - dst[i] = (dst_t)src[i]; - } + for (size_t i = 0; i < vec_size; ++i) { + dst[i] = (dst_t)src[i]; } + } }; -template <> struct vec_cast -{ - template - FLASHINFER_INLINE static void cast(float *dst, const half *src) - { - if constexpr (vec_size == 1) { - // dst[0] = (float)src[0]; - dst[0] = __half2float(src[0]); - } - else { +template <> +struct vec_cast { + template + FLASHINFER_INLINE static void cast(float* dst, const half* src) { + if constexpr (vec_size == 1) { + // dst[0] = (float)src[0]; + dst[0] = __half2float(src[0]); + } else { #pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - ((float2 *)dst)[i] = __half22float2(((half2 *)src)[i]); - } - } + for (size_t i = 0; i < vec_size / 2; ++i) { + ((float2*)dst)[i] = __half22float2(((half2*)src)[i]); + } } + } }; -template <> struct vec_cast -{ - template - FLASHINFER_INLINE static void cast(half *dst, const float *src) - { - if constexpr (vec_size == 1) { - dst[0] = __float2half(src[0]); - } - else { +template <> +struct vec_cast { + template + FLASHINFER_INLINE static void cast(half* dst, const float* src) { + if constexpr (vec_size == 1) { + dst[0] = __float2half(src[0]); + } else { #pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - ((half2 *)dst)[i] = __float22half2_rn(((float2 *)src)[i]); - } - } + for (size_t i = 0; i < vec_size / 2; ++i) { + ((half2*)dst)[i] = __float22half2_rn(((float2*)src)[i]); + } } + } }; -template constexpr FLASHINFER_INLINE int get_exponent_bits() -{ - if constexpr (std::is_same_v) { - return 4; - } - else if constexpr (std::is_same_v) { - return 5; - } - else if constexpr (std::is_same_v) 
{ - return 5; - } - else if constexpr (std::is_same_v) { - return 8; - } -} - -template constexpr FLASHINFER_INLINE int get_mantissa_bits() -{ - if constexpr (std::is_same_v) { - return 3; - } - else if constexpr (std::is_same_v) { - return 2; - } - else if constexpr (std::is_same_v) { - return 11; - } - else if constexpr (std::is_same_v) { - return 7; - } +template +constexpr FLASHINFER_INLINE int get_exponent_bits() { + if constexpr (std::is_same_v) { + return 4; + } else if constexpr (std::is_same_v) { + return 5; + } else if constexpr (std::is_same_v) { + return 5; + } else if constexpr (std::is_same_v) { + return 8; + } +} + +template +constexpr FLASHINFER_INLINE int get_mantissa_bits() { + if constexpr (std::is_same_v) { + return 3; + } else if constexpr (std::is_same_v) { + return 2; + } else if constexpr (std::is_same_v) { + return 11; + } else if constexpr (std::is_same_v) { + return 7; + } } /*! @@ -180,207 +157,180 @@ template constexpr FLASHINFER_INLINE int get_mantissa_bits() * https://github.com/vllm-project/vllm/blob/6dffa4b0a6120159ef2fe44d695a46817aff65bc/csrc/quantization/fp8/fp8_marlin.cu#L120 */ template -__device__ void fast_dequant_f8f16x4(uint32_t *input, uint2 *output) -{ - uint32_t q = *input; - if constexpr (std::is_same_v && - std::is_same_v) - { - output->x = __byte_perm(0U, q, 0x5140); - output->y = __byte_perm(0U, q, 0x7362); - } - else { - constexpr int FP8_EXPONENT = get_exponent_bits(); - constexpr int FP8_MANTISSA = get_mantissa_bits(); - constexpr int FP16_EXPONENT = get_exponent_bits(); - - constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT; - // Calculate MASK for extracting mantissa and exponent - // XXX: duplicate defs of `MASK1` and `MASK2`, - // in the HIP file "include/hip/amd_detail/amd_device_functions.h". 
- constexpr int MASK1_orig = 0x80000000; - constexpr int MASK2_orig = MASK1_orig >> (FP8_EXPONENT + FP8_MANTISSA); - constexpr int MASK3 = MASK2_orig & 0x7fffffff; - constexpr int MASK = MASK3 | (MASK3 >> 16); - q = __byte_perm(q, q, 0x1302); - - // Extract and shift FP8 values to FP16 format - uint32_t Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); - uint32_t Out2 = - ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT); - - constexpr int BIAS_OFFSET = - (1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); - // Construct and apply exponent bias - if constexpr (std::is_same_v) { - const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET)); - - // Convert to half2 and apply bias - *(half2 *)&(output->x) = - __hmul2(*reinterpret_cast(&Out1), bias_reg); - *(half2 *)&(output->y) = - __hmul2(*reinterpret_cast(&Out2), bias_reg); - } - else { - constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23; - const __hip_bfloat162 bias_reg = - __float2bfloat162_rn(*reinterpret_cast(&BIAS)); - // Convert to bfloat162 and apply bias - *(__hip_bfloat162 *)&(output->x) = __hmul2( - *reinterpret_cast(&Out1), bias_reg); - *(__hip_bfloat162 *)&(output->y) = __hmul2( - *reinterpret_cast(&Out2), bias_reg); - } - } -} - -template <> struct vec_cast<__hip_bfloat16, __hip_fp8_e4m3_fnuz> -{ - template - FLASHINFER_INLINE static void cast(__hip_bfloat16 *dst, - const __hip_fp8_e4m3_fnuz *src) - { - if constexpr (vec_size == 1) { - dst[0] = __hip_bfloat16(src[0]); - } - else if constexpr (vec_size == 2) { - dst[0] = __hip_bfloat16(src[0]); - dst[1] = __hip_bfloat16(src[1]); - } - else { - static_assert(vec_size % 4 == 0, - "vec_size must be a multiple of 4"); +__device__ void fast_dequant_f8f16x4(uint32_t* input, uint2* output) { + uint32_t q = *input; + if constexpr (std::is_same_v && + std::is_same_v) { + output->x = __byte_perm(0U, q, 0x5140); + output->y = __byte_perm(0U, q, 0x7362); + } else { + constexpr int FP8_EXPONENT = get_exponent_bits(); + constexpr int 
FP8_MANTISSA = get_mantissa_bits(); + constexpr int FP16_EXPONENT = get_exponent_bits(); + + constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT; + // Calculate MASK for extracting mantissa and exponent + // XXX: duplicate defs of `MASK1` and `MASK2`, + // in the HIP file "include/hip/amd_detail/amd_device_functions.h". + constexpr int MASK1_orig = 0x80000000; + constexpr int MASK2_orig = MASK1_orig >> (FP8_EXPONENT + FP8_MANTISSA); + constexpr int MASK3 = MASK2_orig & 0x7fffffff; + constexpr int MASK = MASK3 | (MASK3 >> 16); + q = __byte_perm(q, q, 0x1302); + + // Extract and shift FP8 values to FP16 format + uint32_t Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + uint32_t Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT); + + constexpr int BIAS_OFFSET = (1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); + // Construct and apply exponent bias + if constexpr (std::is_same_v) { + const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET)); + + // Convert to half2 and apply bias + *(half2*)&(output->x) = __hmul2(*reinterpret_cast(&Out1), bias_reg); + *(half2*)&(output->y) = __hmul2(*reinterpret_cast(&Out2), bias_reg); + } else { + constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23; + const __hip_bfloat162 bias_reg = __float2bfloat162_rn(*reinterpret_cast(&BIAS)); + // Convert to bfloat162 and apply bias + *(__hip_bfloat162*)&(output->x) = + __hmul2(*reinterpret_cast(&Out1), bias_reg); + *(__hip_bfloat162*)&(output->y) = + __hmul2(*reinterpret_cast(&Out2), bias_reg); + } + } +} + +template <> +struct vec_cast<__hip_bfloat16, __hip_fp8_e4m3_fnuz> { + template + FLASHINFER_INLINE static void cast(__hip_bfloat16* dst, const __hip_fp8_e4m3_fnuz* src) { + if constexpr (vec_size == 1) { + dst[0] = __hip_bfloat16(src[0]); + } else if constexpr (vec_size == 2) { + dst[0] = __hip_bfloat16(src[0]); + dst[1] = __hip_bfloat16(src[1]); + } else { + static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4"); #pragma unroll - for 
(uint32_t i = 0; i < vec_size / 4; ++i) { - fast_dequant_f8f16x4<__hip_fp8_e4m3_fnuz, __hip_bfloat16>( - (uint32_t *)&src[i * 4], (uint2 *)&dst[i * 4]); - } - } + for (uint32_t i = 0; i < vec_size / 4; ++i) { + fast_dequant_f8f16x4<__hip_fp8_e4m3_fnuz, __hip_bfloat16>((uint32_t*)&src[i * 4], + (uint2*)&dst[i * 4]); + } } + } }; -template <> struct vec_cast<__hip_bfloat16, __hip_fp8_e5m2_fnuz> -{ - template - FLASHINFER_INLINE static void cast(__hip_bfloat16 *dst, - const __hip_fp8_e5m2_fnuz *src) - { - if constexpr (vec_size == 1) { - dst[0] = __hip_bfloat16(src[0]); - } - else if constexpr (vec_size == 2) { - dst[0] = __hip_bfloat16(src[0]); - dst[1] = __hip_bfloat16(src[1]); - } - else { - static_assert(vec_size % 4 == 0, - "vec_size must be a multiple of 4"); +template <> +struct vec_cast<__hip_bfloat16, __hip_fp8_e5m2_fnuz> { + template + FLASHINFER_INLINE static void cast(__hip_bfloat16* dst, const __hip_fp8_e5m2_fnuz* src) { + if constexpr (vec_size == 1) { + dst[0] = __hip_bfloat16(src[0]); + } else if constexpr (vec_size == 2) { + dst[0] = __hip_bfloat16(src[0]); + dst[1] = __hip_bfloat16(src[1]); + } else { + static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4"); #pragma unroll - for (uint32_t i = 0; i < vec_size / 4; ++i) { - fast_dequant_f8f16x4<__hip_fp8_e5m2_fnuz, __hip_bfloat16>( - (uint32_t *)&src[i * 4], (uint2 *)&dst[i * 4]); - } - } + for (uint32_t i = 0; i < vec_size / 4; ++i) { + fast_dequant_f8f16x4<__hip_fp8_e5m2_fnuz, __hip_bfloat16>((uint32_t*)&src[i * 4], + (uint2*)&dst[i * 4]); + } } + } }; // Function to convert float to e4m3 -__device__ uint8_t convert_f32_to_e4m3(float val) -{ - // Define the range of e4m3 - // 1. Minimum representable value for e4m3 - // 2. Binary 1000.000 in e4m3 - // 3. FLT_MIN is not suitable for e4m3 because e4m3 has a much smaller - // dynamic range. - float min_e4m3 = -8.0f; - // 1. Maximum representable value for e4m3 - // 2. 
Binary 0111.111 in e4m3 - // FLT_MAX far exceeds the maximum value representable in e4m3. - float max_e4m3 = 7.875f; - - // Saturate the value to the e4m3 range - val = fminf(fmaxf(val, min_e4m3), max_e4m3); - - // Perform conversion - // Decompose into mantissa and exponent - int exp; - float mantissa = frexpf(val, &exp); - - // Encode sign bit - uint8_t sign = (mantissa < 0) ? 0x80 : 0x00; - - // Normalize mantissa and encode exponent - mantissa = - fabsf(mantissa) * 16.0f; // Scale mantissa for e4m3's 3-bit precision - uint8_t exponent = static_cast(exp + 7); // Bias of 7 for e4m3 - - // Quantize mantissa - // Apply round-to-nearest-even to the mantissa - uint8_t quant_mantissa = static_cast(roundf(mantissa)) & 0x07; - - // Combine into 8 bits: [sign][exponent][mantissa] - return sign | (exponent << 3) | quant_mantissa; -} - -__device__ __half2 convert_uint32_to_half2(uint32_t input) -{ - // Extract the low and high 16 bits - uint16_t low_val = input & 0xFFFF; - uint16_t high_val = (input >> 16) & 0xFFFF; - // Convert to __half - __half low_half = __float2half(static_cast(low_val)); - __half high_half = __float2half(static_cast(high_val)); - // Pack into __half2 - return __halves2half2(low_half, high_half); +__device__ uint8_t convert_f32_to_e4m3(float val) { + // Define the range of e4m3 + // 1. Minimum representable value for e4m3 + // 2. Binary 1000.000 in e4m3 + // 3. FLT_MIN is not suitable for e4m3 because e4m3 has a much smaller + // dynamic range. + float min_e4m3 = -8.0f; + // 1. Maximum representable value for e4m3 + // 2. Binary 0111.111 in e4m3 + // FLT_MAX far exceeds the maximum value representable in e4m3. + float max_e4m3 = 7.875f; + + // Saturate the value to the e4m3 range + val = fminf(fmaxf(val, min_e4m3), max_e4m3); + + // Perform conversion + // Decompose into mantissa and exponent + int exp; + float mantissa = frexpf(val, &exp); + + // Encode sign bit + uint8_t sign = (mantissa < 0) ? 
0x80 : 0x00; + + // Normalize mantissa and encode exponent + mantissa = fabsf(mantissa) * 16.0f; // Scale mantissa for e4m3's 3-bit precision + uint8_t exponent = static_cast(exp + 7); // Bias of 7 for e4m3 + + // Quantize mantissa + // Apply round-to-nearest-even to the mantissa + uint8_t quant_mantissa = static_cast(roundf(mantissa)) & 0x07; + + // Combine into 8 bits: [sign][exponent][mantissa] + return sign | (exponent << 3) | quant_mantissa; +} + +__device__ __half2 convert_uint32_to_half2(uint32_t input) { + // Extract the low and high 16 bits + uint16_t low_val = input & 0xFFFF; + uint16_t high_val = (input >> 16) & 0xFFFF; + // Convert to __half + __half low_half = __float2half(static_cast(low_val)); + __half high_half = __float2half(static_cast(high_val)); + // Pack into __half2 + return __halves2half2(low_half, high_half); } // Convert f16x2 (__half2) to e4m3x2 (packed 16-bit) -__device__ uint16_t convert_f16x2_to_e4m3x2(__half2 x) -{ - float f32_0 = __half2float(__low2half(x)); - float f32_1 = __half2float(__high2half(x)); - uint8_t e4m3_0 = convert_f32_to_e4m3(f32_0); - uint8_t e4m3_1 = convert_f32_to_e4m3(f32_1); - return (static_cast(e4m3_1) << 8) | e4m3_0; -} - -template <> struct vec_cast<__hip_fp8_e4m3_fnuz, half> -{ - template - FLASHINFER_INLINE static void cast(__hip_fp8_e4m3_fnuz *dst, - const half *src) - { +__device__ uint16_t convert_f16x2_to_e4m3x2(__half2 x) { + float f32_0 = __half2float(__low2half(x)); + float f32_1 = __half2float(__high2half(x)); + uint8_t e4m3_0 = convert_f32_to_e4m3(f32_0); + uint8_t e4m3_1 = convert_f32_to_e4m3(f32_1); + return (static_cast(e4m3_1) << 8) | e4m3_0; +} + +template <> +struct vec_cast<__hip_fp8_e4m3_fnuz, half> { + template + FLASHINFER_INLINE static void cast(__hip_fp8_e4m3_fnuz* dst, const half* src) { #ifdef FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED - if constexpr (vec_size == 1) { - dst[0] = __hip_fp8_e4m3_fnuz(src[0]); - } - else { + if constexpr (vec_size == 1) { + dst[0] = 
__hip_fp8_e4m3_fnuz(src[0]); + } else { #pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - uint16_t y; - uint32_t x = *(uint32_t *)&src[i * 2]; - __half2 x_h2 = convert_uint32_to_half2(x); - y = convert_f16x2_to_e4m3x2(x_h2); - - *(uint16_t *)&dst[i * 2] = y; - } - } + for (size_t i = 0; i < vec_size / 2; ++i) { + uint16_t y; + uint32_t x = *(uint32_t*)&src[i * 2]; + __half2 x_h2 = convert_uint32_to_half2(x); + y = convert_f16x2_to_e4m3x2(x_h2); + + *(uint16_t*)&dst[i * 2] = y; + } + } #else #pragma unroll - for (size_t i = 0; i < vec_size; ++i) { - dst[i] = __hip_fp8_e4m3_fnuz(src[i]); - } -#endif // FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + for (size_t i = 0; i < vec_size; ++i) { + dst[i] = __hip_fp8_e4m3_fnuz(src[i]); } +#endif // FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + } }; -__device__ uint16_t convert_f16x2_to_e5m2x2(uint32_t x) -{ - // Unpack the two 16-bit half-precision floats from the input - // Extract lower 16 bits - __half h1 = __ushort_as_half(x & 0xFFFF); - // Extract upper 16 bits - __half h2 = __ushort_as_half((x >> 16) & 0xFFFF); +__device__ uint16_t convert_f16x2_to_e5m2x2(uint32_t x) { + // Unpack the two 16-bit half-precision floats from the input + // Extract lower 16 bits + __half h1 = __ushort_as_half(x & 0xFFFF); + // Extract upper 16 bits + __half h2 = __ushort_as_half((x >> 16) & 0xFFFF); #if 0 // Alternative with `__uint2half_rn` @@ -390,1620 +340,1292 @@ __device__ uint16_t convert_f16x2_to_e5m2x2(uint32_t x) __half h2 = __uint2half_rn(val2); #endif - // Define the range of e5m2 - // Minimum representable value for e5m2 - const float min_e5m2 = -8.0f; - // Maximum representable value for e5m2 - const float max_e5m2 = 7.75f; + // Define the range of e5m2 + // Minimum representable value for e5m2 + const float min_e5m2 = -8.0f; + // Maximum representable value for e5m2 + const float max_e5m2 = 7.75f; - // Helper lambda for conversion - auto f32_to_e5m2 = [min_e5m2, max_e5m2](float val) -> uint8_t { - // Saturate the val 
- val = fminf(fmaxf(val, min_e5m2), max_e5m2); + // Helper lambda for conversion + auto f32_to_e5m2 = [min_e5m2, max_e5m2](float val) -> uint8_t { + // Saturate the val + val = fminf(fmaxf(val, min_e5m2), max_e5m2); - // Decompose into mantissa and exponent - int exp; - float mantissa = frexpf(val, &exp); + // Decompose into mantissa and exponent + int exp; + float mantissa = frexpf(val, &exp); - // Encode sign bit - uint8_t sign = (mantissa < 0) ? 0x10 : 0x00; // Sign in bit 4 - mantissa = fabsf(mantissa); + // Encode sign bit + uint8_t sign = (mantissa < 0) ? 0x10 : 0x00; // Sign in bit 4 + mantissa = fabsf(mantissa); - // Normalize mantissa and encode exponent - mantissa *= 4.0f; // Scale for 2-bit mantissa - uint8_t exponent = static_cast(exp + 7); // Apply bias for e5m2 + // Normalize mantissa and encode exponent + mantissa *= 4.0f; // Scale for 2-bit mantissa + uint8_t exponent = static_cast(exp + 7); // Apply bias for e5m2 - // Apply round-to-nearest-even - uint8_t quant_mantissa = static_cast(roundf(mantissa)) & 0x03; + // Apply round-to-nearest-even + uint8_t quant_mantissa = static_cast(roundf(mantissa)) & 0x03; - // Combine into 5 bits: [sign][exponent][mantissa] - return sign | (exponent << 2) | quant_mantissa; - }; + // Combine into 5 bits: [sign][exponent][mantissa] + return sign | (exponent << 2) | quant_mantissa; + }; - // Convert the two __half values to e5m2 - uint8_t e5m2_1 = f32_to_e5m2(__half2float(h1)); - uint8_t e5m2_2 = f32_to_e5m2(__half2float(h2)); + // Convert the two __half values to e5m2 + uint8_t e5m2_1 = f32_to_e5m2(__half2float(h1)); + uint8_t e5m2_2 = f32_to_e5m2(__half2float(h2)); - // Pack the two e5m2 values into a single 16-bit output - return (e5m2_2 << 8) | e5m2_1; + // Pack the two e5m2 values into a single 16-bit output + return (e5m2_2 << 8) | e5m2_1; } #endif -template <> struct vec_cast<__hip_fp8_e5m2_fnuz, half> -{ - template - FLASHINFER_INLINE static void cast(__hip_fp8_e5m2_fnuz *dst, - const half *src) - { +template 
<> +struct vec_cast<__hip_fp8_e5m2_fnuz, half> { + template + FLASHINFER_INLINE static void cast(__hip_fp8_e5m2_fnuz* dst, const half* src) { #ifdef FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED - if constexpr (vec_size == 1) { - dst[0] = __hip_fp8_e5m2_fnuz(src[0]); - } - else { + if constexpr (vec_size == 1) { + dst[0] = __hip_fp8_e5m2_fnuz(src[0]); + } else { #pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - uint16_t y; - uint32_t x = *(uint32_t *)&src[i * 2]; - y = convert_f16x2_to_e5m2x2(x); - *(uint16_t *)&dst[i * 2] = y; - } - } + for (size_t i = 0; i < vec_size / 2; ++i) { + uint16_t y; + uint32_t x = *(uint32_t*)&src[i * 2]; + y = convert_f16x2_to_e5m2x2(x); + *(uint16_t*)&dst[i * 2] = y; + } + } #else #pragma unroll - for (size_t i = 0; i < vec_size; ++i) { - dst[i] = __hip_fp8_e5m2_fnuz(src[i]); - } -#endif // FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + for (size_t i = 0; i < vec_size; ++i) { + dst[i] = __hip_fp8_e5m2_fnuz(src[i]); } +#endif // FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + } }; -__device__ uint32_t convert_e4m3x2_to_f16x2(uint16_t x) -{ - // Extract two e4m3 values from the 16-bit input - uint8_t e4m3_1 = x & 0xFF; // Lower 8 bits - uint8_t e4m3_2 = (x >> 8) & 0xFF; // Upper 8 bits - - // Decode e4m3 to float - auto e4m3_to_f32 = [](uint8_t e4m3) -> float { - // Extract sign, exponent, and mantissa - int sign = (e4m3 & 0x80) ? 
-1 : 1; - int exponent = ((e4m3 >> 3) & 0x0F) - 7; // 4-bit exponent with bias 7 - int mantissa = e4m3 & 0x07; // 3-bit mantissa - - // Handle special case: zero - if (exponent == -7 && mantissa == 0) { - return 0.0f; - } - - // Convert to float - float f32_val = sign * ldexpf(1.0f + mantissa / 8.0f, exponent); - return f32_val; - }; - - float f1 = e4m3_to_f32(e4m3_1); - float f2 = e4m3_to_f32(e4m3_2); - - // Convert float to IEEE f16 - __half h1 = __float2half_rn(f1); - __half h2 = __float2half_rn(f2); - - // Pack the two f16 values into a single uint32_t - uint32_t f16x2 = (__half_as_ushort(h2) << 16) | __half_as_ushort(h1); - return f16x2; +__device__ uint32_t convert_e4m3x2_to_f16x2(uint16_t x) { + // Extract two e4m3 values from the 16-bit input + uint8_t e4m3_1 = x & 0xFF; // Lower 8 bits + uint8_t e4m3_2 = (x >> 8) & 0xFF; // Upper 8 bits + + // Decode e4m3 to float + auto e4m3_to_f32 = [](uint8_t e4m3) -> float { + // Extract sign, exponent, and mantissa + int sign = (e4m3 & 0x80) ? 
-1 : 1; + int exponent = ((e4m3 >> 3) & 0x0F) - 7; // 4-bit exponent with bias 7 + int mantissa = e4m3 & 0x07; // 3-bit mantissa + + // Handle special case: zero + if (exponent == -7 && mantissa == 0) { + return 0.0f; + } + + // Convert to float + float f32_val = sign * ldexpf(1.0f + mantissa / 8.0f, exponent); + return f32_val; + }; + + float f1 = e4m3_to_f32(e4m3_1); + float f2 = e4m3_to_f32(e4m3_2); + + // Convert float to IEEE f16 + __half h1 = __float2half_rn(f1); + __half h2 = __float2half_rn(f2); + + // Pack the two f16 values into a single uint32_t + uint32_t f16x2 = (__half_as_ushort(h2) << 16) | __half_as_ushort(h1); + return f16x2; } -template <> struct vec_cast -{ - template - FLASHINFER_INLINE static void cast(half *dst, - const __hip_fp8_e4m3_fnuz *src) - { +template <> +struct vec_cast { + template + FLASHINFER_INLINE static void cast(half* dst, const __hip_fp8_e4m3_fnuz* src) { #ifdef FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED - if constexpr (vec_size == 1) { - dst[0] = half(src[0]); - } - else { + if constexpr (vec_size == 1) { + dst[0] = half(src[0]); + } else { #pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - uint32_t y; - uint16_t x = *(uint16_t *)&src[i * 2]; - y = convert_e4m3x2_to_f16x2(x); - - *(uint32_t *)&dst[i * 2] = y; - } - } + for (size_t i = 0; i < vec_size / 2; ++i) { + uint32_t y; + uint16_t x = *(uint16_t*)&src[i * 2]; + y = convert_e4m3x2_to_f16x2(x); + + *(uint32_t*)&dst[i * 2] = y; + } + } #else - if constexpr (vec_size == 1) { - dst[0] = half(src[0]); - } - else if constexpr (vec_size == 2) { - dst[0] = half(src[0]); - dst[1] = half(src[1]); - } - else { - static_assert(vec_size % 4 == 0, - "vec_size must be a multiple of 4"); + if constexpr (vec_size == 1) { + dst[0] = half(src[0]); + } else if constexpr (vec_size == 2) { + dst[0] = half(src[0]); + dst[1] = half(src[1]); + } else { + static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4"); #pragma unroll - for (uint32_t i = 0; i < vec_size / 4; ++i) 
{ - fast_dequant_f8f16x4<__hip_fp8_e4m3_fnuz, half>( - (uint32_t *)&src[i * 4], (uint2 *)&dst[i * 4]); - } - } -#endif // FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + for (uint32_t i = 0; i < vec_size / 4; ++i) { + fast_dequant_f8f16x4<__hip_fp8_e4m3_fnuz, half>((uint32_t*)&src[i * 4], + (uint2*)&dst[i * 4]); + } } +#endif // FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + } }; -__device__ uint32_t convert_e5m2x2_to_f16x2(uint16_t x) -{ - // Extract two e5m2 values from the 16-bit input - uint8_t e5m2_1 = x & 0xFF; // Lower 8 bits - uint8_t e5m2_2 = (x >> 8) & 0xFF; // Upper 8 bits - - // Decode e5m2 to float - auto e5m2_to_f32 = [](uint8_t e5m2) -> float { - // Extract sign, exponent, and mantissa - int sign = (e5m2 & 0x80) ? -1 : 1; // Sign bit - int exponent = ((e5m2 >> 2) & 0x1F) - 15; // 5-bit exponent with bias 15 - int mantissa = e5m2 & 0x03; // 2-bit mantissa - - // Handle special case: zero - if (exponent == -15 && mantissa == 0) { - return 0.0f; - } - - // Convert to float - float value = sign * ldexpf(1.0f + mantissa / 4.0f, exponent); - return value; - }; - - float f1 = e5m2_to_f32(e5m2_1); - float f2 = e5m2_to_f32(e5m2_2); - - // Convert float to IEEE f16 - __half h1 = __float2half_rn(f1); - __half h2 = __float2half_rn(f2); - - // Pack the two f16 values into a single uint32_t - uint32_t f16x2 = (__half_as_ushort(h2) << 16) | __half_as_ushort(h1); - return f16x2; +__device__ uint32_t convert_e5m2x2_to_f16x2(uint16_t x) { + // Extract two e5m2 values from the 16-bit input + uint8_t e5m2_1 = x & 0xFF; // Lower 8 bits + uint8_t e5m2_2 = (x >> 8) & 0xFF; // Upper 8 bits + + // Decode e5m2 to float + auto e5m2_to_f32 = [](uint8_t e5m2) -> float { + // Extract sign, exponent, and mantissa + int sign = (e5m2 & 0x80) ? 
-1 : 1; // Sign bit + int exponent = ((e5m2 >> 2) & 0x1F) - 15; // 5-bit exponent with bias 15 + int mantissa = e5m2 & 0x03; // 2-bit mantissa + + // Handle special case: zero + if (exponent == -15 && mantissa == 0) { + return 0.0f; + } + + // Convert to float + float value = sign * ldexpf(1.0f + mantissa / 4.0f, exponent); + return value; + }; + + float f1 = e5m2_to_f32(e5m2_1); + float f2 = e5m2_to_f32(e5m2_2); + + // Convert float to IEEE f16 + __half h1 = __float2half_rn(f1); + __half h2 = __float2half_rn(f2); + + // Pack the two f16 values into a single uint32_t + uint32_t f16x2 = (__half_as_ushort(h2) << 16) | __half_as_ushort(h1); + return f16x2; } -template <> struct vec_cast -{ - template - FLASHINFER_INLINE static void cast(half *dst, - const __hip_fp8_e5m2_fnuz *src) - { +template <> +struct vec_cast { + template + FLASHINFER_INLINE static void cast(half* dst, const __hip_fp8_e5m2_fnuz* src) { #ifdef FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED - if constexpr (vec_size == 1) { - dst[0] = half(src[0]); - } - else { + if constexpr (vec_size == 1) { + dst[0] = half(src[0]); + } else { #pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - uint32_t y; - uint16_t x = *(uint16_t *)&src[i * 2]; - y = convert_e5m2x2_to_f16x2(x); - *(uint32_t *)&dst[i * 2] = y; - } - } + for (size_t i = 0; i < vec_size / 2; ++i) { + uint32_t y; + uint16_t x = *(uint16_t*)&src[i * 2]; + y = convert_e5m2x2_to_f16x2(x); + *(uint32_t*)&dst[i * 2] = y; + } + } #else - if constexpr (vec_size == 1) { - dst[0] = half(src[0]); - } - else if constexpr (vec_size == 2) { - dst[0] = half(src[0]); - dst[1] = half(src[1]); - } - else { - static_assert(vec_size % 4 == 0, - "vec_size must be a multiple of 4"); + if constexpr (vec_size == 1) { + dst[0] = half(src[0]); + } else if constexpr (vec_size == 2) { + dst[0] = half(src[0]); + dst[1] = half(src[1]); + } else { + static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4"); #pragma unroll - for (uint32_t i = 0; i < vec_size / 
4; ++i) { - fast_dequant_f8f16x4<__hip_fp8_e5m2_fnuz, half>( - (uint32_t *)&src[i * 4], (uint2 *)&dst[i * 4]); - } - } -#endif // FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + for (uint32_t i = 0; i < vec_size / 4; ++i) { + fast_dequant_f8f16x4<__hip_fp8_e5m2_fnuz, half>((uint32_t*)&src[i * 4], + (uint2*)&dst[i * 4]); + } } +#endif // FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + } }; -template <> struct vec_cast -{ - template - FLASHINFER_INLINE static void cast(float *dst, const __hip_bfloat16 *src) - { - if constexpr (vec_size == 1) { - dst[0] = (float)src[0]; - } - else { +template <> +struct vec_cast { + template + FLASHINFER_INLINE static void cast(float* dst, const __hip_bfloat16* src) { + if constexpr (vec_size == 1) { + dst[0] = (float)src[0]; + } else { #pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - ((float2 *)dst)[i] = - __bfloat1622float2(((__hip_bfloat162 *)src)[i]); - } - } + for (size_t i = 0; i < vec_size / 2; ++i) { + ((float2*)dst)[i] = __bfloat1622float2(((__hip_bfloat162*)src)[i]); + } } + } }; -template <> struct vec_cast<__hip_bfloat16, float> -{ - template - FLASHINFER_INLINE static void cast(__hip_bfloat16 *dst, const float *src) - { - if constexpr (vec_size == 1) { - dst[0] = __hip_bfloat16(src[0]); - } - else { +template <> +struct vec_cast<__hip_bfloat16, float> { + template + FLASHINFER_INLINE static void cast(__hip_bfloat16* dst, const float* src) { + if constexpr (vec_size == 1) { + dst[0] = __hip_bfloat16(src[0]); + } else { #pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - ((__hip_bfloat162 *)dst)[i] = - __float22bfloat162_rn(((float2 *)src)[i]); - } - } + for (size_t i = 0; i < vec_size / 2; ++i) { + ((__hip_bfloat162*)dst)[i] = __float22bfloat162_rn(((float2*)src)[i]); + } } + } }; -template struct vec_t -{ - FLASHINFER_INLINE float_t &operator[](size_t i); - FLASHINFER_INLINE const float_t &operator[](size_t i) const; - FLASHINFER_INLINE void fill(float_t val); - FLASHINFER_INLINE void load(const float_t 
*ptr); - FLASHINFER_INLINE void store(float_t *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src); - template FLASHINFER_INLINE void cast_load(const T *ptr); - template FLASHINFER_INLINE void cast_store(T *ptr) const; - FLASHINFER_INLINE static void memcpy(float_t *dst, const float_t *src); - FLASHINFER_INLINE float_t *ptr(); +template +struct vec_t { + FLASHINFER_INLINE float_t& operator[](size_t i); + FLASHINFER_INLINE const float_t& operator[](size_t i) const; + FLASHINFER_INLINE void fill(float_t val); + FLASHINFER_INLINE void load(const float_t* ptr); + FLASHINFER_INLINE void store(float_t* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src); + template + FLASHINFER_INLINE void cast_load(const T* ptr); + template + FLASHINFER_INLINE void cast_store(T* ptr) const; + FLASHINFER_INLINE static void memcpy(float_t* dst, const float_t* src); + FLASHINFER_INLINE float_t* ptr(); }; template -FLASHINFER_INLINE void cast_from_impl(vec_t &dst, - const vec_t &src) -{ - - vec_cast::template cast( - dst.ptr(), const_cast *>(&src)->ptr()); +FLASHINFER_INLINE void cast_from_impl(vec_t& dst, + const vec_t& src) { + vec_cast::template cast( + dst.ptr(), const_cast*>(&src)->ptr()); } template -FLASHINFER_INLINE void cast_load_impl(vec_t &dst, - const src_float_t *src_ptr) -{ - if constexpr (std::is_same_v) { - dst.load(src_ptr); - } - else { - vec_t tmp; - tmp.load(src_ptr); - dst.cast_from(tmp); - } +FLASHINFER_INLINE void cast_load_impl(vec_t& dst, + const src_float_t* src_ptr) { + if constexpr (std::is_same_v) { + dst.load(src_ptr); + } else { + vec_t tmp; + tmp.load(src_ptr); + dst.cast_from(tmp); + } } template -FLASHINFER_INLINE void cast_store_impl(tgt_float_t *dst_ptr, - const vec_t &src) -{ - if constexpr (std::is_same_v) { - src.store(dst_ptr); - } - else { - vec_t tmp; - tmp.cast_from(src); - tmp.store(dst_ptr); - } +FLASHINFER_INLINE void cast_store_impl(tgt_float_t* dst_ptr, + const vec_t& src) { + if constexpr 
(std::is_same_v) { + src.store(dst_ptr); + } else { + vec_t tmp; + tmp.cast_from(src); + tmp.store(dst_ptr); + } } /******************* vec_t<__hip_fp8_e4m3_fnuz> *******************/ // __hip_fp8_e4m3_fnuz x 1 -template <> struct vec_t<__hip_fp8_e4m3_fnuz, 1> -{ - __hip_fp8_e4m3_fnuz data; - - FLASHINFER_INLINE __hip_fp8_e4m3_fnuz &operator[](size_t i) - { - return ((__hip_fp8_e4m3_fnuz *)(&data))[i]; - } - FLASHINFER_INLINE const __hip_fp8_e4m3_fnuz &operator[](size_t i) const - { - return ((const __hip_fp8_e4m3_fnuz *)(&data))[i]; - } - FLASHINFER_INLINE __hip_fp8_e4m3_fnuz *ptr() - { - return reinterpret_cast<__hip_fp8_e4m3_fnuz *>(&data); - } - FLASHINFER_INLINE void fill(__hip_fp8_e4m3_fnuz val); - FLASHINFER_INLINE void load(const __hip_fp8_e4m3_fnuz *ptr); - FLASHINFER_INLINE void store(__hip_fp8_e4m3_fnuz *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) - { - cast_from_impl(*this, src); - } - template FLASHINFER_INLINE void cast_load(const T *ptr) - { - cast_load_impl(*this, ptr); - } - template FLASHINFER_INLINE void cast_store(T *ptr) const - { - cast_store_impl(ptr, *this); - } - - FLASHINFER_INLINE static void memcpy(__hip_fp8_e4m3_fnuz *dst, - const __hip_fp8_e4m3_fnuz *src); +template <> +struct vec_t<__hip_fp8_e4m3_fnuz, 1> { + __hip_fp8_e4m3_fnuz data; + + FLASHINFER_INLINE __hip_fp8_e4m3_fnuz& operator[](size_t i) { + return ((__hip_fp8_e4m3_fnuz*)(&data))[i]; + } + FLASHINFER_INLINE const __hip_fp8_e4m3_fnuz& operator[](size_t i) const { + return ((const __hip_fp8_e4m3_fnuz*)(&data))[i]; + } + FLASHINFER_INLINE __hip_fp8_e4m3_fnuz* ptr() { + return reinterpret_cast<__hip_fp8_e4m3_fnuz*>(&data); + } + FLASHINFER_INLINE void fill(__hip_fp8_e4m3_fnuz val); + FLASHINFER_INLINE void load(const __hip_fp8_e4m3_fnuz* ptr); + FLASHINFER_INLINE void store(__hip_fp8_e4m3_fnuz* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void 
cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + + FLASHINFER_INLINE static void memcpy(__hip_fp8_e4m3_fnuz* dst, const __hip_fp8_e4m3_fnuz* src); }; -FLASHINFER_INLINE void -vec_t<__hip_fp8_e4m3_fnuz, 1>::fill(__hip_fp8_e4m3_fnuz val) -{ - data = val; -} +FLASHINFER_INLINE void vec_t<__hip_fp8_e4m3_fnuz, 1>::fill(__hip_fp8_e4m3_fnuz val) { data = val; } -FLASHINFER_INLINE void -vec_t<__hip_fp8_e4m3_fnuz, 1>::load(const __hip_fp8_e4m3_fnuz *ptr) -{ - data = *ptr; +FLASHINFER_INLINE void vec_t<__hip_fp8_e4m3_fnuz, 1>::load(const __hip_fp8_e4m3_fnuz* ptr) { + data = *ptr; } -FLASHINFER_INLINE void -vec_t<__hip_fp8_e4m3_fnuz, 1>::store(__hip_fp8_e4m3_fnuz *ptr) const -{ - *ptr = data; +FLASHINFER_INLINE void vec_t<__hip_fp8_e4m3_fnuz, 1>::store(__hip_fp8_e4m3_fnuz* ptr) const { + *ptr = data; } -FLASHINFER_INLINE void -vec_t<__hip_fp8_e4m3_fnuz, 1>::memcpy(__hip_fp8_e4m3_fnuz *dst, - const __hip_fp8_e4m3_fnuz *src) -{ - *dst = *src; +FLASHINFER_INLINE void vec_t<__hip_fp8_e4m3_fnuz, 1>::memcpy(__hip_fp8_e4m3_fnuz* dst, + const __hip_fp8_e4m3_fnuz* src) { + *dst = *src; } // __hip_fp8_e4m3_fnuz x 2 -template <> struct vec_t<__hip_fp8_e4m3_fnuz, 2> -{ - __hip_fp8x2_e4m3_fnuz data; - - FLASHINFER_INLINE __hip_fp8_e4m3_fnuz &operator[](size_t i) - { - return ((__hip_fp8_e4m3_fnuz *)(&data))[i]; - } - FLASHINFER_INLINE const __hip_fp8_e4m3_fnuz &operator[](size_t i) const - { - return ((const __hip_fp8_e4m3_fnuz *)(&data))[i]; - } - FLASHINFER_INLINE __hip_fp8_e4m3_fnuz *ptr() - { - return reinterpret_cast<__hip_fp8_e4m3_fnuz *>(&data); - } - FLASHINFER_INLINE void fill(__hip_fp8_e4m3_fnuz val); - FLASHINFER_INLINE void load(const __hip_fp8_e4m3_fnuz *ptr); - FLASHINFER_INLINE void store(__hip_fp8_e4m3_fnuz *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) - { - cast_from_impl(*this, src); - } - template FLASHINFER_INLINE void 
cast_load(const T *ptr) - { - cast_load_impl(*this, ptr); - } - template FLASHINFER_INLINE void cast_store(T *ptr) const - { - cast_store_impl(ptr, *this); - } - FLASHINFER_INLINE static void memcpy(__hip_fp8_e4m3_fnuz *dst, - const __hip_fp8_e4m3_fnuz *src); +template <> +struct vec_t<__hip_fp8_e4m3_fnuz, 2> { + __hip_fp8x2_e4m3_fnuz data; + + FLASHINFER_INLINE __hip_fp8_e4m3_fnuz& operator[](size_t i) { + return ((__hip_fp8_e4m3_fnuz*)(&data))[i]; + } + FLASHINFER_INLINE const __hip_fp8_e4m3_fnuz& operator[](size_t i) const { + return ((const __hip_fp8_e4m3_fnuz*)(&data))[i]; + } + FLASHINFER_INLINE __hip_fp8_e4m3_fnuz* ptr() { + return reinterpret_cast<__hip_fp8_e4m3_fnuz*>(&data); + } + FLASHINFER_INLINE void fill(__hip_fp8_e4m3_fnuz val); + FLASHINFER_INLINE void load(const __hip_fp8_e4m3_fnuz* ptr); + FLASHINFER_INLINE void store(__hip_fp8_e4m3_fnuz* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + FLASHINFER_INLINE static void memcpy(__hip_fp8_e4m3_fnuz* dst, const __hip_fp8_e4m3_fnuz* src); }; -FLASHINFER_INLINE void -vec_t<__hip_fp8_e4m3_fnuz, 2>::fill(__hip_fp8_e4m3_fnuz val) -{ - data.__x = - (__hip_fp8x2_storage_t(val.__x) << 8) | __hip_fp8x2_storage_t(val.__x); +FLASHINFER_INLINE void vec_t<__hip_fp8_e4m3_fnuz, 2>::fill(__hip_fp8_e4m3_fnuz val) { + data.__x = (__hip_fp8x2_storage_t(val.__x) << 8) | __hip_fp8x2_storage_t(val.__x); } -FLASHINFER_INLINE void -vec_t<__hip_fp8_e4m3_fnuz, 2>::load(const __hip_fp8_e4m3_fnuz *ptr) -{ - data = *((__hip_fp8x2_e4m3_fnuz *)ptr); +FLASHINFER_INLINE void vec_t<__hip_fp8_e4m3_fnuz, 2>::load(const __hip_fp8_e4m3_fnuz* ptr) { + data = *((__hip_fp8x2_e4m3_fnuz*)ptr); } -FLASHINFER_INLINE void -vec_t<__hip_fp8_e4m3_fnuz, 2>::store(__hip_fp8_e4m3_fnuz *ptr) const -{ - 
*((__hip_fp8x2_e4m3_fnuz *)ptr) = data; +FLASHINFER_INLINE void vec_t<__hip_fp8_e4m3_fnuz, 2>::store(__hip_fp8_e4m3_fnuz* ptr) const { + *((__hip_fp8x2_e4m3_fnuz*)ptr) = data; } -FLASHINFER_INLINE void -vec_t<__hip_fp8_e4m3_fnuz, 2>::memcpy(__hip_fp8_e4m3_fnuz *dst, - const __hip_fp8_e4m3_fnuz *src) -{ - *((__hip_fp8x2_e4m3_fnuz *)dst) = *((__hip_fp8x2_e4m3_fnuz *)src); +FLASHINFER_INLINE void vec_t<__hip_fp8_e4m3_fnuz, 2>::memcpy(__hip_fp8_e4m3_fnuz* dst, + const __hip_fp8_e4m3_fnuz* src) { + *((__hip_fp8x2_e4m3_fnuz*)dst) = *((__hip_fp8x2_e4m3_fnuz*)src); } // __hip_fp8_e4m3_fnuz x 4 -template <> struct vec_t<__hip_fp8_e4m3_fnuz, 4> -{ - __hip_fp8x4_e4m3_fnuz data; - - FLASHINFER_INLINE __hip_fp8_e4m3_fnuz &operator[](size_t i) - { - return ((__hip_fp8_e4m3_fnuz *)(&data))[i]; - } - FLASHINFER_INLINE const __hip_fp8_e4m3_fnuz &operator[](size_t i) const - { - return ((const __hip_fp8_e4m3_fnuz *)(&data))[i]; - } - FLASHINFER_INLINE __hip_fp8_e4m3_fnuz *ptr() - { - return reinterpret_cast<__hip_fp8_e4m3_fnuz *>(&data); - } - FLASHINFER_INLINE void fill(__hip_fp8_e4m3_fnuz val); - FLASHINFER_INLINE void load(const __hip_fp8_e4m3_fnuz *ptr); - FLASHINFER_INLINE void store(__hip_fp8_e4m3_fnuz *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) - { - cast_from_impl(*this, src); - } - template FLASHINFER_INLINE void cast_load(const T *ptr) - { - cast_load_impl(*this, ptr); - } - template FLASHINFER_INLINE void cast_store(T *ptr) const - { - cast_store_impl(ptr, *this); - } - - FLASHINFER_INLINE static void memcpy(__hip_fp8_e4m3_fnuz *dst, - const __hip_fp8_e4m3_fnuz *src); +template <> +struct vec_t<__hip_fp8_e4m3_fnuz, 4> { + __hip_fp8x4_e4m3_fnuz data; + + FLASHINFER_INLINE __hip_fp8_e4m3_fnuz& operator[](size_t i) { + return ((__hip_fp8_e4m3_fnuz*)(&data))[i]; + } + FLASHINFER_INLINE const __hip_fp8_e4m3_fnuz& operator[](size_t i) const { + return ((const __hip_fp8_e4m3_fnuz*)(&data))[i]; + } + FLASHINFER_INLINE __hip_fp8_e4m3_fnuz* ptr() { 
+ return reinterpret_cast<__hip_fp8_e4m3_fnuz*>(&data); + } + FLASHINFER_INLINE void fill(__hip_fp8_e4m3_fnuz val); + FLASHINFER_INLINE void load(const __hip_fp8_e4m3_fnuz* ptr); + FLASHINFER_INLINE void store(__hip_fp8_e4m3_fnuz* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + + FLASHINFER_INLINE static void memcpy(__hip_fp8_e4m3_fnuz* dst, const __hip_fp8_e4m3_fnuz* src); }; -FLASHINFER_INLINE void -vec_t<__hip_fp8_e4m3_fnuz, 4>::fill(__hip_fp8_e4m3_fnuz val) -{ - data.__x = (__hip_fp8x4_storage_t(val.__x) << 24) | - (__hip_fp8x4_storage_t(val.__x) << 16) | - (__hip_fp8x4_storage_t(val.__x) << 8) | - __hip_fp8x4_storage_t(val.__x); +FLASHINFER_INLINE void vec_t<__hip_fp8_e4m3_fnuz, 4>::fill(__hip_fp8_e4m3_fnuz val) { + data.__x = (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); } -FLASHINFER_INLINE void -vec_t<__hip_fp8_e4m3_fnuz, 4>::load(const __hip_fp8_e4m3_fnuz *ptr) -{ - data = *((__hip_fp8x4_e4m3_fnuz *)ptr); +FLASHINFER_INLINE void vec_t<__hip_fp8_e4m3_fnuz, 4>::load(const __hip_fp8_e4m3_fnuz* ptr) { + data = *((__hip_fp8x4_e4m3_fnuz*)ptr); } -FLASHINFER_INLINE void -vec_t<__hip_fp8_e4m3_fnuz, 4>::store(__hip_fp8_e4m3_fnuz *ptr) const -{ - *((__hip_fp8x4_e4m3_fnuz *)ptr) = data; +FLASHINFER_INLINE void vec_t<__hip_fp8_e4m3_fnuz, 4>::store(__hip_fp8_e4m3_fnuz* ptr) const { + *((__hip_fp8x4_e4m3_fnuz*)ptr) = data; } -FLASHINFER_INLINE void -vec_t<__hip_fp8_e4m3_fnuz, 4>::memcpy(__hip_fp8_e4m3_fnuz *dst, - const __hip_fp8_e4m3_fnuz *src) -{ - *((__hip_fp8x4_e4m3_fnuz *)dst) = *((__hip_fp8x4_e4m3_fnuz *)src); +FLASHINFER_INLINE void vec_t<__hip_fp8_e4m3_fnuz, 4>::memcpy(__hip_fp8_e4m3_fnuz* dst, + 
const __hip_fp8_e4m3_fnuz* src) { + *((__hip_fp8x4_e4m3_fnuz*)dst) = *((__hip_fp8x4_e4m3_fnuz*)src); } // __hip_fp8_e4m3_fnuz x 8 -template <> struct vec_t<__hip_fp8_e4m3_fnuz, 8> -{ - uint2 data; - - FLASHINFER_INLINE __hip_fp8_e4m3_fnuz &operator[](size_t i) - { - return ((__hip_fp8_e4m3_fnuz *)(&data))[i]; - } - FLASHINFER_INLINE const __hip_fp8_e4m3_fnuz &operator[](size_t i) const - { - return ((const __hip_fp8_e4m3_fnuz *)(&data))[i]; - } - FLASHINFER_INLINE __hip_fp8_e4m3_fnuz *ptr() - { - return reinterpret_cast<__hip_fp8_e4m3_fnuz *>(&data); - } - FLASHINFER_INLINE void fill(__hip_fp8_e4m3_fnuz val); - FLASHINFER_INLINE void load(const __hip_fp8_e4m3_fnuz *ptr); - FLASHINFER_INLINE void store(__hip_fp8_e4m3_fnuz *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) - { - cast_from_impl(*this, src); - } - template FLASHINFER_INLINE void cast_load(const T *ptr) - { - cast_load_impl(*this, ptr); - } - template FLASHINFER_INLINE void cast_store(T *ptr) const - { - cast_store_impl(ptr, *this); - } - - FLASHINFER_INLINE static void memcpy(__hip_fp8_e4m3_fnuz *dst, - const __hip_fp8_e4m3_fnuz *src); +template <> +struct vec_t<__hip_fp8_e4m3_fnuz, 8> { + uint2 data; + + FLASHINFER_INLINE __hip_fp8_e4m3_fnuz& operator[](size_t i) { + return ((__hip_fp8_e4m3_fnuz*)(&data))[i]; + } + FLASHINFER_INLINE const __hip_fp8_e4m3_fnuz& operator[](size_t i) const { + return ((const __hip_fp8_e4m3_fnuz*)(&data))[i]; + } + FLASHINFER_INLINE __hip_fp8_e4m3_fnuz* ptr() { + return reinterpret_cast<__hip_fp8_e4m3_fnuz*>(&data); + } + FLASHINFER_INLINE void fill(__hip_fp8_e4m3_fnuz val); + FLASHINFER_INLINE void load(const __hip_fp8_e4m3_fnuz* ptr); + FLASHINFER_INLINE void store(__hip_fp8_e4m3_fnuz* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) 
const { + cast_store_impl(ptr, *this); + } + + FLASHINFER_INLINE static void memcpy(__hip_fp8_e4m3_fnuz* dst, const __hip_fp8_e4m3_fnuz* src); }; -FLASHINFER_INLINE void -vec_t<__hip_fp8_e4m3_fnuz, 8>::fill(__hip_fp8_e4m3_fnuz val) -{ - ((__hip_fp8x4_e4m3_fnuz *)(&data.x))->__x = - (__hip_fp8x4_storage_t(val.__x) << 24) | - (__hip_fp8x4_storage_t(val.__x) << 16) | - (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); - ((__hip_fp8x4_e4m3_fnuz *)(&data.y))->__x = - (__hip_fp8x4_storage_t(val.__x) << 24) | - (__hip_fp8x4_storage_t(val.__x) << 16) | - (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); +FLASHINFER_INLINE void vec_t<__hip_fp8_e4m3_fnuz, 8>::fill(__hip_fp8_e4m3_fnuz val) { + ((__hip_fp8x4_e4m3_fnuz*)(&data.x))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + ((__hip_fp8x4_e4m3_fnuz*)(&data.y))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); } -FLASHINFER_INLINE void -vec_t<__hip_fp8_e4m3_fnuz, 8>::load(const __hip_fp8_e4m3_fnuz *ptr) -{ - data = *((uint2 *)ptr); +FLASHINFER_INLINE void vec_t<__hip_fp8_e4m3_fnuz, 8>::load(const __hip_fp8_e4m3_fnuz* ptr) { + data = *((uint2*)ptr); } -FLASHINFER_INLINE void -vec_t<__hip_fp8_e4m3_fnuz, 8>::store(__hip_fp8_e4m3_fnuz *ptr) const -{ - *((uint2 *)ptr) = data; +FLASHINFER_INLINE void vec_t<__hip_fp8_e4m3_fnuz, 8>::store(__hip_fp8_e4m3_fnuz* ptr) const { + *((uint2*)ptr) = data; } -FLASHINFER_INLINE void -vec_t<__hip_fp8_e4m3_fnuz, 8>::memcpy(__hip_fp8_e4m3_fnuz *dst, - const __hip_fp8_e4m3_fnuz *src) -{ - *((uint2 *)dst) = *((uint2 *)src); +FLASHINFER_INLINE void vec_t<__hip_fp8_e4m3_fnuz, 8>::memcpy(__hip_fp8_e4m3_fnuz* dst, + const __hip_fp8_e4m3_fnuz* src) { + *((uint2*)dst) = *((uint2*)src); } // __hip_fp8_e4m3_fnuz x 16 or more -template 
struct vec_t<__hip_fp8_e4m3_fnuz, vec_size> -{ - uint4 data[vec_size / 16]; - - FLASHINFER_INLINE __hip_fp8_e4m3_fnuz &operator[](size_t i) - { - return ((__hip_fp8_e4m3_fnuz *)data)[i]; - } - FLASHINFER_INLINE const __hip_fp8_e4m3_fnuz &operator[](size_t i) const - { - return ((const __hip_fp8_e4m3_fnuz *)data)[i]; - } - FLASHINFER_INLINE __hip_fp8_e4m3_fnuz *ptr() - { - return reinterpret_cast<__hip_fp8_e4m3_fnuz *>(&data); - } - FLASHINFER_INLINE void fill(__hip_fp8_e4m3_fnuz val) - { +template +struct vec_t<__hip_fp8_e4m3_fnuz, vec_size> { + uint4 data[vec_size / 16]; + + FLASHINFER_INLINE __hip_fp8_e4m3_fnuz& operator[](size_t i) { + return ((__hip_fp8_e4m3_fnuz*)data)[i]; + } + FLASHINFER_INLINE const __hip_fp8_e4m3_fnuz& operator[](size_t i) const { + return ((const __hip_fp8_e4m3_fnuz*)data)[i]; + } + FLASHINFER_INLINE __hip_fp8_e4m3_fnuz* ptr() { + return reinterpret_cast<__hip_fp8_e4m3_fnuz*>(&data); + } + FLASHINFER_INLINE void fill(__hip_fp8_e4m3_fnuz val) { #pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - ((__hip_fp8x4_e4m3_fnuz *)(&(data[i].x)))->__x = - (__hip_fp8x4_storage_t(val.__x) << 24) | - (__hip_fp8x4_storage_t(val.__x) << 16) | - (__hip_fp8x4_storage_t(val.__x) << 8) | - __hip_fp8x4_storage_t(val.__x); - ((__hip_fp8x4_e4m3_fnuz *)(&(data[i].y)))->__x = - (__hip_fp8x4_storage_t(val.__x) << 24) | - (__hip_fp8x4_storage_t(val.__x) << 16) | - (__hip_fp8x4_storage_t(val.__x) << 8) | - __hip_fp8x4_storage_t(val.__x); - ((__hip_fp8x4_e4m3_fnuz *)(&(data[i].z)))->__x = - (__hip_fp8x4_storage_t(val.__x) << 24) | - (__hip_fp8x4_storage_t(val.__x) << 16) | - (__hip_fp8x4_storage_t(val.__x) << 8) | - __hip_fp8x4_storage_t(val.__x); - ((__hip_fp8x4_e4m3_fnuz *)(&(data[i].w)))->__x = - (__hip_fp8x4_storage_t(val.__x) << 24) | - (__hip_fp8x4_storage_t(val.__x) << 16) | - (__hip_fp8x4_storage_t(val.__x) << 8) | - __hip_fp8x4_storage_t(val.__x); - } - } - FLASHINFER_INLINE void load(const __hip_fp8_e4m3_fnuz *ptr) - { + for (size_t i = 0; i < 
vec_size / 16; ++i) { + ((__hip_fp8x4_e4m3_fnuz*)(&(data[i].x)))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + ((__hip_fp8x4_e4m3_fnuz*)(&(data[i].y)))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + ((__hip_fp8x4_e4m3_fnuz*)(&(data[i].z)))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + ((__hip_fp8x4_e4m3_fnuz*)(&(data[i].w)))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + } + } + FLASHINFER_INLINE void load(const __hip_fp8_e4m3_fnuz* ptr) { #pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - data[i] = ((uint4 *)ptr)[i]; - } + for (size_t i = 0; i < vec_size / 16; ++i) { + data[i] = ((uint4*)ptr)[i]; } - FLASHINFER_INLINE void store(__hip_fp8_e4m3_fnuz *ptr) const - { + } + FLASHINFER_INLINE void store(__hip_fp8_e4m3_fnuz* ptr) const { #pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - ((uint4 *)ptr)[i] = data[i]; - } - } - template - FLASHINFER_INLINE void cast_from(const vec_t &src) - { - cast_from_impl(*this, src); - } - template FLASHINFER_INLINE void cast_load(const T *ptr) - { - cast_load_impl(*this, ptr); - } - template FLASHINFER_INLINE void cast_store(T *ptr) const - { - cast_store_impl(ptr, *this); - } - - FLASHINFER_INLINE static void memcpy(__hip_fp8_e4m3_fnuz *dst, - const __hip_fp8_e4m3_fnuz *src) - { + for (size_t i = 0; i < vec_size / 16; ++i) { + ((uint4*)ptr)[i] = data[i]; + } + } + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + 
cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + + FLASHINFER_INLINE static void memcpy(__hip_fp8_e4m3_fnuz* dst, const __hip_fp8_e4m3_fnuz* src) { #pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - ((uint4 *)dst)[i] = ((uint4 *)src)[i]; - } + for (size_t i = 0; i < vec_size / 16; ++i) { + ((uint4*)dst)[i] = ((uint4*)src)[i]; } + } }; /******************* vec_t<__hip_fp8_e5m2_fnuz> *******************/ // __hip_fp8_e5m2_fnuz x 1 -template <> struct vec_t<__hip_fp8_e5m2_fnuz, 1> -{ - __hip_fp8_e5m2_fnuz data; - - FLASHINFER_INLINE __hip_fp8_e5m2_fnuz &operator[](size_t i) - { - return ((__hip_fp8_e5m2_fnuz *)(&data))[i]; - } - FLASHINFER_INLINE const __hip_fp8_e5m2_fnuz &operator[](size_t i) const - { - return ((const __hip_fp8_e5m2_fnuz *)(&data))[i]; - } - FLASHINFER_INLINE __hip_fp8_e5m2_fnuz *ptr() - { - return reinterpret_cast<__hip_fp8_e5m2_fnuz *>(&data); - } - FLASHINFER_INLINE void fill(__hip_fp8_e5m2_fnuz val); - FLASHINFER_INLINE void load(const __hip_fp8_e5m2_fnuz *ptr); - FLASHINFER_INLINE void store(__hip_fp8_e5m2_fnuz *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) - { - cast_from_impl(*this, src); - } - template FLASHINFER_INLINE void cast_load(const T *ptr) - { - cast_load_impl(*this, ptr); - } - template FLASHINFER_INLINE void cast_store(T *ptr) const - { - cast_store_impl(ptr, *this); - } - - FLASHINFER_INLINE static void memcpy(__hip_fp8_e5m2_fnuz *dst, - const __hip_fp8_e5m2_fnuz *src); +template <> +struct vec_t<__hip_fp8_e5m2_fnuz, 1> { + __hip_fp8_e5m2_fnuz data; + + FLASHINFER_INLINE __hip_fp8_e5m2_fnuz& operator[](size_t i) { + return ((__hip_fp8_e5m2_fnuz*)(&data))[i]; + } + FLASHINFER_INLINE const __hip_fp8_e5m2_fnuz& operator[](size_t i) const { + return ((const __hip_fp8_e5m2_fnuz*)(&data))[i]; + } + FLASHINFER_INLINE __hip_fp8_e5m2_fnuz* ptr() { + return reinterpret_cast<__hip_fp8_e5m2_fnuz*>(&data); + } + 
FLASHINFER_INLINE void fill(__hip_fp8_e5m2_fnuz val); + FLASHINFER_INLINE void load(const __hip_fp8_e5m2_fnuz* ptr); + FLASHINFER_INLINE void store(__hip_fp8_e5m2_fnuz* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + + FLASHINFER_INLINE static void memcpy(__hip_fp8_e5m2_fnuz* dst, const __hip_fp8_e5m2_fnuz* src); }; -FLASHINFER_INLINE void -vec_t<__hip_fp8_e5m2_fnuz, 1>::fill(__hip_fp8_e5m2_fnuz val) -{ - data = val; -} +FLASHINFER_INLINE void vec_t<__hip_fp8_e5m2_fnuz, 1>::fill(__hip_fp8_e5m2_fnuz val) { data = val; } -FLASHINFER_INLINE void -vec_t<__hip_fp8_e5m2_fnuz, 1>::load(const __hip_fp8_e5m2_fnuz *ptr) -{ - data = *ptr; +FLASHINFER_INLINE void vec_t<__hip_fp8_e5m2_fnuz, 1>::load(const __hip_fp8_e5m2_fnuz* ptr) { + data = *ptr; } -FLASHINFER_INLINE void -vec_t<__hip_fp8_e5m2_fnuz, 1>::store(__hip_fp8_e5m2_fnuz *ptr) const -{ - *ptr = data; +FLASHINFER_INLINE void vec_t<__hip_fp8_e5m2_fnuz, 1>::store(__hip_fp8_e5m2_fnuz* ptr) const { + *ptr = data; } -FLASHINFER_INLINE void -vec_t<__hip_fp8_e5m2_fnuz, 1>::memcpy(__hip_fp8_e5m2_fnuz *dst, - const __hip_fp8_e5m2_fnuz *src) -{ - *dst = *src; +FLASHINFER_INLINE void vec_t<__hip_fp8_e5m2_fnuz, 1>::memcpy(__hip_fp8_e5m2_fnuz* dst, + const __hip_fp8_e5m2_fnuz* src) { + *dst = *src; } // __hip_fp8_e5m2_fnuz x 2 -template <> struct vec_t<__hip_fp8_e5m2_fnuz, 2> -{ - __hip_fp8x2_e5m2_fnuz data; - - FLASHINFER_INLINE __hip_fp8_e5m2_fnuz &operator[](size_t i) - { - return ((__hip_fp8_e5m2_fnuz *)(&data))[i]; - } - FLASHINFER_INLINE const __hip_fp8_e5m2_fnuz &operator[](size_t i) const - { - return ((const __hip_fp8_e5m2_fnuz *)(&data))[i]; - } - FLASHINFER_INLINE __hip_fp8_e5m2_fnuz *ptr() - { - return reinterpret_cast<__hip_fp8_e5m2_fnuz *>(&data); - } - 
FLASHINFER_INLINE void fill(__hip_fp8_e5m2_fnuz val); - FLASHINFER_INLINE void load(const __hip_fp8_e5m2_fnuz *ptr); - FLASHINFER_INLINE void store(__hip_fp8_e5m2_fnuz *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) - { - cast_from_impl(*this, src); - } - template FLASHINFER_INLINE void cast_load(const T *ptr) - { - cast_load_impl(*this, ptr); - } - template FLASHINFER_INLINE void cast_store(T *ptr) const - { - cast_store_impl(ptr, *this); - } - - FLASHINFER_INLINE static void memcpy(__hip_fp8_e5m2_fnuz *dst, - const __hip_fp8_e5m2_fnuz *src); +template <> +struct vec_t<__hip_fp8_e5m2_fnuz, 2> { + __hip_fp8x2_e5m2_fnuz data; + + FLASHINFER_INLINE __hip_fp8_e5m2_fnuz& operator[](size_t i) { + return ((__hip_fp8_e5m2_fnuz*)(&data))[i]; + } + FLASHINFER_INLINE const __hip_fp8_e5m2_fnuz& operator[](size_t i) const { + return ((const __hip_fp8_e5m2_fnuz*)(&data))[i]; + } + FLASHINFER_INLINE __hip_fp8_e5m2_fnuz* ptr() { + return reinterpret_cast<__hip_fp8_e5m2_fnuz*>(&data); + } + FLASHINFER_INLINE void fill(__hip_fp8_e5m2_fnuz val); + FLASHINFER_INLINE void load(const __hip_fp8_e5m2_fnuz* ptr); + FLASHINFER_INLINE void store(__hip_fp8_e5m2_fnuz* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + + FLASHINFER_INLINE static void memcpy(__hip_fp8_e5m2_fnuz* dst, const __hip_fp8_e5m2_fnuz* src); }; -FLASHINFER_INLINE void -vec_t<__hip_fp8_e5m2_fnuz, 2>::fill(__hip_fp8_e5m2_fnuz val) -{ - data.__x = - (__hip_fp8x2_storage_t(val.__x) << 8) | __hip_fp8x2_storage_t(val.__x); +FLASHINFER_INLINE void vec_t<__hip_fp8_e5m2_fnuz, 2>::fill(__hip_fp8_e5m2_fnuz val) { + data.__x = (__hip_fp8x2_storage_t(val.__x) << 8) | __hip_fp8x2_storage_t(val.__x); } -FLASHINFER_INLINE void -vec_t<__hip_fp8_e5m2_fnuz, 
2>::load(const __hip_fp8_e5m2_fnuz *ptr) -{ - data = *((__hip_fp8x2_e5m2_fnuz *)ptr); +FLASHINFER_INLINE void vec_t<__hip_fp8_e5m2_fnuz, 2>::load(const __hip_fp8_e5m2_fnuz* ptr) { + data = *((__hip_fp8x2_e5m2_fnuz*)ptr); } -FLASHINFER_INLINE void -vec_t<__hip_fp8_e5m2_fnuz, 2>::store(__hip_fp8_e5m2_fnuz *ptr) const -{ - *((__hip_fp8x2_e5m2_fnuz *)ptr) = data; +FLASHINFER_INLINE void vec_t<__hip_fp8_e5m2_fnuz, 2>::store(__hip_fp8_e5m2_fnuz* ptr) const { + *((__hip_fp8x2_e5m2_fnuz*)ptr) = data; } -FLASHINFER_INLINE void -vec_t<__hip_fp8_e5m2_fnuz, 2>::memcpy(__hip_fp8_e5m2_fnuz *dst, - const __hip_fp8_e5m2_fnuz *src) -{ - *((__hip_fp8x2_e5m2_fnuz *)dst) = *((__hip_fp8x2_e5m2_fnuz *)src); +FLASHINFER_INLINE void vec_t<__hip_fp8_e5m2_fnuz, 2>::memcpy(__hip_fp8_e5m2_fnuz* dst, + const __hip_fp8_e5m2_fnuz* src) { + *((__hip_fp8x2_e5m2_fnuz*)dst) = *((__hip_fp8x2_e5m2_fnuz*)src); } // __hip_fp8_e5m2_fnuz x 4 -template <> struct vec_t<__hip_fp8_e5m2_fnuz, 4> -{ - __hip_fp8x4_e5m2_fnuz data; - - FLASHINFER_INLINE __hip_fp8_e5m2_fnuz &operator[](size_t i) - { - return ((__hip_fp8_e5m2_fnuz *)(&data))[i]; - } - FLASHINFER_INLINE const __hip_fp8_e5m2_fnuz &operator[](size_t i) const - { - return ((const __hip_fp8_e5m2_fnuz *)(&data))[i]; - } - FLASHINFER_INLINE __hip_fp8_e5m2_fnuz *ptr() - { - return reinterpret_cast<__hip_fp8_e5m2_fnuz *>(&data); - } - FLASHINFER_INLINE void fill(__hip_fp8_e5m2_fnuz val); - FLASHINFER_INLINE void load(const __hip_fp8_e5m2_fnuz *ptr); - FLASHINFER_INLINE void store(__hip_fp8_e5m2_fnuz *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) - { - cast_from_impl(*this, src); - } - template FLASHINFER_INLINE void cast_load(const T *ptr) - { - cast_load_impl(*this, ptr); - } - template FLASHINFER_INLINE void cast_store(T *ptr) const - { - cast_store_impl(ptr, *this); - } - - FLASHINFER_INLINE static void memcpy(__hip_fp8_e5m2_fnuz *dst, - const __hip_fp8_e5m2_fnuz *src); +template <> +struct vec_t<__hip_fp8_e5m2_fnuz, 4> { + 
__hip_fp8x4_e5m2_fnuz data; + + FLASHINFER_INLINE __hip_fp8_e5m2_fnuz& operator[](size_t i) { + return ((__hip_fp8_e5m2_fnuz*)(&data))[i]; + } + FLASHINFER_INLINE const __hip_fp8_e5m2_fnuz& operator[](size_t i) const { + return ((const __hip_fp8_e5m2_fnuz*)(&data))[i]; + } + FLASHINFER_INLINE __hip_fp8_e5m2_fnuz* ptr() { + return reinterpret_cast<__hip_fp8_e5m2_fnuz*>(&data); + } + FLASHINFER_INLINE void fill(__hip_fp8_e5m2_fnuz val); + FLASHINFER_INLINE void load(const __hip_fp8_e5m2_fnuz* ptr); + FLASHINFER_INLINE void store(__hip_fp8_e5m2_fnuz* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + + FLASHINFER_INLINE static void memcpy(__hip_fp8_e5m2_fnuz* dst, const __hip_fp8_e5m2_fnuz* src); }; -FLASHINFER_INLINE void -vec_t<__hip_fp8_e5m2_fnuz, 4>::fill(__hip_fp8_e5m2_fnuz val) -{ - data.__x = (__hip_fp8x4_storage_t(val.__x) << 24) | - (__hip_fp8x4_storage_t(val.__x) << 16) | - (__hip_fp8x4_storage_t(val.__x) << 8) | - __hip_fp8x4_storage_t(val.__x); +FLASHINFER_INLINE void vec_t<__hip_fp8_e5m2_fnuz, 4>::fill(__hip_fp8_e5m2_fnuz val) { + data.__x = (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); } -FLASHINFER_INLINE void -vec_t<__hip_fp8_e5m2_fnuz, 4>::load(const __hip_fp8_e5m2_fnuz *ptr) -{ - data = *((__hip_fp8x4_e5m2_fnuz *)ptr); +FLASHINFER_INLINE void vec_t<__hip_fp8_e5m2_fnuz, 4>::load(const __hip_fp8_e5m2_fnuz* ptr) { + data = *((__hip_fp8x4_e5m2_fnuz*)ptr); } -FLASHINFER_INLINE void -vec_t<__hip_fp8_e5m2_fnuz, 4>::store(__hip_fp8_e5m2_fnuz *ptr) const -{ - *((__hip_fp8x4_e5m2_fnuz *)ptr) = data; +FLASHINFER_INLINE void vec_t<__hip_fp8_e5m2_fnuz, 4>::store(__hip_fp8_e5m2_fnuz* ptr) const { + 
*((__hip_fp8x4_e5m2_fnuz*)ptr) = data; } -FLASHINFER_INLINE void -vec_t<__hip_fp8_e5m2_fnuz, 4>::memcpy(__hip_fp8_e5m2_fnuz *dst, - const __hip_fp8_e5m2_fnuz *src) -{ - *((__hip_fp8x4_e5m2_fnuz *)dst) = *((__hip_fp8x4_e5m2_fnuz *)src); +FLASHINFER_INLINE void vec_t<__hip_fp8_e5m2_fnuz, 4>::memcpy(__hip_fp8_e5m2_fnuz* dst, + const __hip_fp8_e5m2_fnuz* src) { + *((__hip_fp8x4_e5m2_fnuz*)dst) = *((__hip_fp8x4_e5m2_fnuz*)src); } // __hip_fp8_e5m2_fnuz x 8 -template <> struct vec_t<__hip_fp8_e5m2_fnuz, 8> -{ - uint2 data; - - FLASHINFER_INLINE __hip_fp8_e5m2_fnuz &operator[](size_t i) - { - return ((__hip_fp8_e5m2_fnuz *)(&data))[i]; - } - FLASHINFER_INLINE const __hip_fp8_e5m2_fnuz &operator[](size_t i) const - { - return ((const __hip_fp8_e5m2_fnuz *)(&data))[i]; - } - FLASHINFER_INLINE __hip_fp8_e5m2_fnuz *ptr() - { - return reinterpret_cast<__hip_fp8_e5m2_fnuz *>(&data); - } - FLASHINFER_INLINE void fill(__hip_fp8_e5m2_fnuz val); - FLASHINFER_INLINE void load(const __hip_fp8_e5m2_fnuz *ptr); - FLASHINFER_INLINE void store(__hip_fp8_e5m2_fnuz *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) - { - cast_from_impl(*this, src); - } - template FLASHINFER_INLINE void cast_load(const T *ptr) - { - cast_load_impl(*this, ptr); - } - template FLASHINFER_INLINE void cast_store(T *ptr) const - { - cast_store_impl(ptr, *this); - } - FLASHINFER_INLINE static void memcpy(__hip_fp8_e5m2_fnuz *dst, - const __hip_fp8_e5m2_fnuz *src); +template <> +struct vec_t<__hip_fp8_e5m2_fnuz, 8> { + uint2 data; + + FLASHINFER_INLINE __hip_fp8_e5m2_fnuz& operator[](size_t i) { + return ((__hip_fp8_e5m2_fnuz*)(&data))[i]; + } + FLASHINFER_INLINE const __hip_fp8_e5m2_fnuz& operator[](size_t i) const { + return ((const __hip_fp8_e5m2_fnuz*)(&data))[i]; + } + FLASHINFER_INLINE __hip_fp8_e5m2_fnuz* ptr() { + return reinterpret_cast<__hip_fp8_e5m2_fnuz*>(&data); + } + FLASHINFER_INLINE void fill(__hip_fp8_e5m2_fnuz val); + FLASHINFER_INLINE void load(const 
__hip_fp8_e5m2_fnuz* ptr); + FLASHINFER_INLINE void store(__hip_fp8_e5m2_fnuz* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + FLASHINFER_INLINE static void memcpy(__hip_fp8_e5m2_fnuz* dst, const __hip_fp8_e5m2_fnuz* src); }; -FLASHINFER_INLINE void -vec_t<__hip_fp8_e5m2_fnuz, 8>::fill(__hip_fp8_e5m2_fnuz val) -{ - ((__hip_fp8x4_e5m2_fnuz *)(&data.x))->__x = - (__hip_fp8x4_storage_t(val.__x) << 24) | - (__hip_fp8x4_storage_t(val.__x) << 16) | - (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); - ((__hip_fp8x4_e5m2_fnuz *)(&data.y))->__x = - (__hip_fp8x4_storage_t(val.__x) << 24) | - (__hip_fp8x4_storage_t(val.__x) << 16) | - (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); +FLASHINFER_INLINE void vec_t<__hip_fp8_e5m2_fnuz, 8>::fill(__hip_fp8_e5m2_fnuz val) { + ((__hip_fp8x4_e5m2_fnuz*)(&data.x))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + ((__hip_fp8x4_e5m2_fnuz*)(&data.y))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); } -FLASHINFER_INLINE void -vec_t<__hip_fp8_e5m2_fnuz, 8>::load(const __hip_fp8_e5m2_fnuz *ptr) -{ - data = *((uint2 *)ptr); +FLASHINFER_INLINE void vec_t<__hip_fp8_e5m2_fnuz, 8>::load(const __hip_fp8_e5m2_fnuz* ptr) { + data = *((uint2*)ptr); } -FLASHINFER_INLINE void -vec_t<__hip_fp8_e5m2_fnuz, 8>::store(__hip_fp8_e5m2_fnuz *ptr) const -{ - *((uint2 *)ptr) = data; +FLASHINFER_INLINE void vec_t<__hip_fp8_e5m2_fnuz, 8>::store(__hip_fp8_e5m2_fnuz* ptr) const { + *((uint2*)ptr) = data; } -FLASHINFER_INLINE void 
-vec_t<__hip_fp8_e5m2_fnuz, 8>::memcpy(__hip_fp8_e5m2_fnuz *dst, - const __hip_fp8_e5m2_fnuz *src) -{ - *((uint2 *)dst) = *((uint2 *)src); +FLASHINFER_INLINE void vec_t<__hip_fp8_e5m2_fnuz, 8>::memcpy(__hip_fp8_e5m2_fnuz* dst, + const __hip_fp8_e5m2_fnuz* src) { + *((uint2*)dst) = *((uint2*)src); } // __hip_fp8_e5m2_fnuz x 16 or more -template struct vec_t<__hip_fp8_e5m2_fnuz, vec_size> -{ - uint4 data[vec_size / 16]; - - FLASHINFER_INLINE __hip_fp8_e5m2_fnuz &operator[](size_t i) - { - return ((__hip_fp8_e5m2_fnuz *)data)[i]; - } - FLASHINFER_INLINE const __hip_fp8_e5m2_fnuz &operator[](size_t i) const - { - return ((const __hip_fp8_e5m2_fnuz *)data)[i]; - } - FLASHINFER_INLINE __hip_fp8_e5m2_fnuz *ptr() - { - return reinterpret_cast<__hip_fp8_e5m2_fnuz *>(&data); - } - FLASHINFER_INLINE void fill(__hip_fp8_e5m2_fnuz val) - { +template +struct vec_t<__hip_fp8_e5m2_fnuz, vec_size> { + uint4 data[vec_size / 16]; + + FLASHINFER_INLINE __hip_fp8_e5m2_fnuz& operator[](size_t i) { + return ((__hip_fp8_e5m2_fnuz*)data)[i]; + } + FLASHINFER_INLINE const __hip_fp8_e5m2_fnuz& operator[](size_t i) const { + return ((const __hip_fp8_e5m2_fnuz*)data)[i]; + } + FLASHINFER_INLINE __hip_fp8_e5m2_fnuz* ptr() { + return reinterpret_cast<__hip_fp8_e5m2_fnuz*>(&data); + } + FLASHINFER_INLINE void fill(__hip_fp8_e5m2_fnuz val) { #pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - ((__hip_fp8x4_e5m2_fnuz *)(&(data[i].x)))->__x = - (__hip_fp8x4_storage_t(val.__x) << 24) | - (__hip_fp8x4_storage_t(val.__x) << 16) | - (__hip_fp8x4_storage_t(val.__x) << 8) | - __hip_fp8x4_storage_t(val.__x); - ((__hip_fp8x4_e5m2_fnuz *)(&(data[i].y)))->__x = - (__hip_fp8x4_storage_t(val.__x) << 24) | - (__hip_fp8x4_storage_t(val.__x) << 16) | - (__hip_fp8x4_storage_t(val.__x) << 8) | - __hip_fp8x4_storage_t(val.__x); - ((__hip_fp8x4_e5m2_fnuz *)(&(data[i].z)))->__x = - (__hip_fp8x4_storage_t(val.__x) << 24) | - (__hip_fp8x4_storage_t(val.__x) << 16) | - (__hip_fp8x4_storage_t(val.__x) << 8) | - 
__hip_fp8x4_storage_t(val.__x); - ((__hip_fp8x4_e5m2_fnuz *)(&(data[i].w)))->__x = - (__hip_fp8x4_storage_t(val.__x) << 24) | - (__hip_fp8x4_storage_t(val.__x) << 16) | - (__hip_fp8x4_storage_t(val.__x) << 8) | - __hip_fp8x4_storage_t(val.__x); - } - } - FLASHINFER_INLINE void load(const __hip_fp8_e5m2_fnuz *ptr) - { + for (size_t i = 0; i < vec_size / 16; ++i) { + ((__hip_fp8x4_e5m2_fnuz*)(&(data[i].x)))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + ((__hip_fp8x4_e5m2_fnuz*)(&(data[i].y)))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + ((__hip_fp8x4_e5m2_fnuz*)(&(data[i].z)))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + ((__hip_fp8x4_e5m2_fnuz*)(&(data[i].w)))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + } + } + FLASHINFER_INLINE void load(const __hip_fp8_e5m2_fnuz* ptr) { #pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - data[i] = ((uint4 *)ptr)[i]; - } + for (size_t i = 0; i < vec_size / 16; ++i) { + data[i] = ((uint4*)ptr)[i]; } - FLASHINFER_INLINE void store(__hip_fp8_e5m2_fnuz *ptr) const - { + } + FLASHINFER_INLINE void store(__hip_fp8_e5m2_fnuz* ptr) const { #pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - ((uint4 *)ptr)[i] = data[i]; - } - } - template - FLASHINFER_INLINE void cast_from(const vec_t &src) - { - cast_from_impl(*this, src); - } - template FLASHINFER_INLINE void cast_load(const T *ptr) - { - cast_load_impl(*this, ptr); - } - template FLASHINFER_INLINE void cast_store(T *ptr) const - { - cast_store_impl(ptr, *this); - } - FLASHINFER_INLINE 
static void memcpy(__hip_fp8_e5m2_fnuz *dst, - const __hip_fp8_e5m2_fnuz *src) - { + for (size_t i = 0; i < vec_size / 16; ++i) { + ((uint4*)ptr)[i] = data[i]; + } + } + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + FLASHINFER_INLINE static void memcpy(__hip_fp8_e5m2_fnuz* dst, const __hip_fp8_e5m2_fnuz* src) { #pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - ((uint4 *)dst)[i] = ((uint4 *)src)[i]; - } + for (size_t i = 0; i < vec_size / 16; ++i) { + ((uint4*)dst)[i] = ((uint4*)src)[i]; } + } }; /******************* vec_t *******************/ // half x 1 -template <> struct vec_t -{ - half data; - - FLASHINFER_INLINE half &operator[](size_t i) - { - return ((half *)(&data))[i]; - } - FLASHINFER_INLINE const half &operator[](size_t i) const - { - return ((const half *)(&data))[i]; - } - FLASHINFER_INLINE half *ptr() { return reinterpret_cast(&data); } - FLASHINFER_INLINE void fill(half val); - FLASHINFER_INLINE void load(const half *ptr); - FLASHINFER_INLINE void store(half *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) - { - cast_from_impl(*this, src); - } - template FLASHINFER_INLINE void cast_load(const T *ptr) - { - cast_load_impl(*this, ptr); - } - template FLASHINFER_INLINE void cast_store(T *ptr) const - { - cast_store_impl(ptr, *this); - } - - FLASHINFER_INLINE static void memcpy(half *dst, const half *src); +template <> +struct vec_t { + half data; + + FLASHINFER_INLINE half& operator[](size_t i) { return ((half*)(&data))[i]; } + FLASHINFER_INLINE const half& operator[](size_t i) const { return ((const half*)(&data))[i]; } + FLASHINFER_INLINE half* ptr() { return reinterpret_cast(&data); } + FLASHINFER_INLINE void fill(half val); + FLASHINFER_INLINE void load(const half* 
ptr); + FLASHINFER_INLINE void store(half* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + + FLASHINFER_INLINE static void memcpy(half* dst, const half* src); }; FLASHINFER_INLINE void vec_t::fill(half val) { data = val; } -FLASHINFER_INLINE void vec_t::load(const half *ptr) { data = *ptr; } +FLASHINFER_INLINE void vec_t::load(const half* ptr) { data = *ptr; } -FLASHINFER_INLINE void vec_t::store(half *ptr) const { *ptr = data; } +FLASHINFER_INLINE void vec_t::store(half* ptr) const { *ptr = data; } -FLASHINFER_INLINE void vec_t::memcpy(half *dst, const half *src) -{ - *dst = *src; -} +FLASHINFER_INLINE void vec_t::memcpy(half* dst, const half* src) { *dst = *src; } // half x 2 -template <> struct vec_t -{ - half2 data; - - FLASHINFER_INLINE half &operator[](size_t i) - { - return ((half *)(&data))[i]; - } - FLASHINFER_INLINE const half &operator[](size_t i) const - { - return ((const half *)(&data))[i]; - } - FLASHINFER_INLINE half *ptr() { return reinterpret_cast(&data); } - FLASHINFER_INLINE void fill(half val); - FLASHINFER_INLINE void load(const half *ptr); - FLASHINFER_INLINE void store(half *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) - { - cast_from_impl(*this, src); - } - template FLASHINFER_INLINE void cast_load(const T *ptr) - { - cast_load_impl(*this, ptr); - } - template FLASHINFER_INLINE void cast_store(T *ptr) const - { - cast_store_impl(ptr, *this); - } - - FLASHINFER_INLINE static void memcpy(half *dst, const half *src); +template <> +struct vec_t { + half2 data; + + FLASHINFER_INLINE half& operator[](size_t i) { return ((half*)(&data))[i]; } + FLASHINFER_INLINE const half& operator[](size_t i) const { return ((const half*)(&data))[i]; } + FLASHINFER_INLINE 
half* ptr() { return reinterpret_cast(&data); } + FLASHINFER_INLINE void fill(half val); + FLASHINFER_INLINE void load(const half* ptr); + FLASHINFER_INLINE void store(half* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + + FLASHINFER_INLINE static void memcpy(half* dst, const half* src); }; -FLASHINFER_INLINE void vec_t::fill(half val) -{ - data = make_half2(val, val); -} +FLASHINFER_INLINE void vec_t::fill(half val) { data = make_half2(val, val); } -FLASHINFER_INLINE void vec_t::load(const half *ptr) -{ - data = *((half2 *)ptr); -} +FLASHINFER_INLINE void vec_t::load(const half* ptr) { data = *((half2*)ptr); } -FLASHINFER_INLINE void vec_t::store(half *ptr) const -{ - *((half2 *)ptr) = data; -} +FLASHINFER_INLINE void vec_t::store(half* ptr) const { *((half2*)ptr) = data; } -FLASHINFER_INLINE void vec_t::memcpy(half *dst, const half *src) -{ - *((half2 *)dst) = *((half2 *)src); +FLASHINFER_INLINE void vec_t::memcpy(half* dst, const half* src) { + *((half2*)dst) = *((half2*)src); } // half x 4 -template <> struct vec_t -{ - uint2 data; - - FLASHINFER_INLINE half &operator[](size_t i) - { - return ((half *)(&data))[i]; - } - FLASHINFER_INLINE const half &operator[](size_t i) const - { - return ((const half *)(&data))[i]; - } - FLASHINFER_INLINE half *ptr() { return reinterpret_cast(&data); } - FLASHINFER_INLINE void fill(half val); - FLASHINFER_INLINE void load(const half *ptr); - FLASHINFER_INLINE void store(half *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) - { - cast_from_impl(*this, src); - } - template FLASHINFER_INLINE void cast_load(const T *ptr) - { - cast_load_impl(*this, ptr); - } - template FLASHINFER_INLINE void cast_store(T *ptr) const - { - cast_store_impl(ptr, 
*this); - } - FLASHINFER_INLINE static void memcpy(half *dst, const half *src); +template <> +struct vec_t { + uint2 data; + + FLASHINFER_INLINE half& operator[](size_t i) { return ((half*)(&data))[i]; } + FLASHINFER_INLINE const half& operator[](size_t i) const { return ((const half*)(&data))[i]; } + FLASHINFER_INLINE half* ptr() { return reinterpret_cast(&data); } + FLASHINFER_INLINE void fill(half val); + FLASHINFER_INLINE void load(const half* ptr); + FLASHINFER_INLINE void store(half* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + FLASHINFER_INLINE static void memcpy(half* dst, const half* src); }; -FLASHINFER_INLINE void vec_t::fill(half val) -{ - *(half2 *)(&data.x) = make_half2(val, val); - *(half2 *)(&data.y) = make_half2(val, val); +FLASHINFER_INLINE void vec_t::fill(half val) { + *(half2*)(&data.x) = make_half2(val, val); + *(half2*)(&data.y) = make_half2(val, val); } -FLASHINFER_INLINE void vec_t::load(const half *ptr) -{ - data = *((uint2 *)ptr); -} +FLASHINFER_INLINE void vec_t::load(const half* ptr) { data = *((uint2*)ptr); } -FLASHINFER_INLINE void vec_t::store(half *ptr) const -{ - *((uint2 *)ptr) = data; -} +FLASHINFER_INLINE void vec_t::store(half* ptr) const { *((uint2*)ptr) = data; } -FLASHINFER_INLINE void vec_t::memcpy(half *dst, const half *src) -{ - *((uint2 *)dst) = *((uint2 *)src); +FLASHINFER_INLINE void vec_t::memcpy(half* dst, const half* src) { + *((uint2*)dst) = *((uint2*)src); } // half x 8 or more -template struct vec_t -{ - uint4 data[vec_size / 8]; - FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)data)[i]; } - FLASHINFER_INLINE const half &operator[](size_t i) const - { - return ((const half *)data)[i]; - } - FLASHINFER_INLINE half *ptr() { return 
reinterpret_cast(&data); } - FLASHINFER_INLINE void fill(half val) - { +template +struct vec_t { + uint4 data[vec_size / 8]; + FLASHINFER_INLINE half& operator[](size_t i) { return ((half*)data)[i]; } + FLASHINFER_INLINE const half& operator[](size_t i) const { return ((const half*)data)[i]; } + FLASHINFER_INLINE half* ptr() { return reinterpret_cast(&data); } + FLASHINFER_INLINE void fill(half val) { #pragma unroll - for (size_t i = 0; i < vec_size / 8; ++i) { - *(half2 *)(&(data[i].x)) = make_half2(val, val); - *(half2 *)(&(data[i].y)) = make_half2(val, val); - *(half2 *)(&(data[i].z)) = make_half2(val, val); - *(half2 *)(&(data[i].w)) = make_half2(val, val); - } - } - FLASHINFER_INLINE void load(const half *ptr) - { + for (size_t i = 0; i < vec_size / 8; ++i) { + *(half2*)(&(data[i].x)) = make_half2(val, val); + *(half2*)(&(data[i].y)) = make_half2(val, val); + *(half2*)(&(data[i].z)) = make_half2(val, val); + *(half2*)(&(data[i].w)) = make_half2(val, val); + } + } + FLASHINFER_INLINE void load(const half* ptr) { #pragma unroll - for (size_t i = 0; i < vec_size / 8; ++i) { - data[i] = ((uint4 *)ptr)[i]; - } + for (size_t i = 0; i < vec_size / 8; ++i) { + data[i] = ((uint4*)ptr)[i]; } - FLASHINFER_INLINE void store(half *ptr) const - { + } + FLASHINFER_INLINE void store(half* ptr) const { #pragma unroll - for (size_t i = 0; i < vec_size / 8; ++i) { - ((uint4 *)ptr)[i] = data[i]; - } - } - template - FLASHINFER_INLINE void cast_from(const vec_t &src) - { - cast_from_impl(*this, src); - } - template FLASHINFER_INLINE void cast_load(const T *ptr) - { - cast_load_impl(*this, ptr); - } - template FLASHINFER_INLINE void cast_store(T *ptr) const - { - cast_store_impl(ptr, *this); - } - FLASHINFER_INLINE static void memcpy(half *dst, const half *src) - { + for (size_t i = 0; i < vec_size / 8; ++i) { + ((uint4*)ptr)[i] = data[i]; + } + } + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void 
cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + FLASHINFER_INLINE static void memcpy(half* dst, const half* src) { #pragma unroll - for (size_t i = 0; i < vec_size / 8; ++i) { - ((uint4 *)dst)[i] = ((uint4 *)src)[i]; - } + for (size_t i = 0; i < vec_size / 8; ++i) { + ((uint4*)dst)[i] = ((uint4*)src)[i]; } + } }; /******************* vec_t<__hip_bfloat16> *******************/ // __hip_bfloat16 x 1 -template <> struct vec_t<__hip_bfloat16, 1> -{ - __hip_bfloat16 data; - FLASHINFER_INLINE __hip_bfloat16 &operator[](size_t i) - { - return ((__hip_bfloat16 *)(&data))[i]; - } - FLASHINFER_INLINE const __hip_bfloat16 &operator[](size_t i) const - { - return ((const __hip_bfloat16 *)(&data))[i]; - } - FLASHINFER_INLINE __hip_bfloat16 *ptr() - { - return reinterpret_cast<__hip_bfloat16 *>(&data); - } - FLASHINFER_INLINE void fill(__hip_bfloat16 val); - FLASHINFER_INLINE void load(const __hip_bfloat16 *ptr); - FLASHINFER_INLINE void store(__hip_bfloat16 *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) - { - cast_from_impl(*this, src); - } - template FLASHINFER_INLINE void cast_load(const T *ptr) - { - cast_load_impl(*this, ptr); - } - template FLASHINFER_INLINE void cast_store(T *ptr) const - { - cast_store_impl(ptr, *this); - } - FLASHINFER_INLINE static void memcpy(__hip_bfloat16 *dst, - const __hip_bfloat16 *src); +template <> +struct vec_t<__hip_bfloat16, 1> { + __hip_bfloat16 data; + FLASHINFER_INLINE __hip_bfloat16& operator[](size_t i) { return ((__hip_bfloat16*)(&data))[i]; } + FLASHINFER_INLINE const __hip_bfloat16& operator[](size_t i) const { + return ((const __hip_bfloat16*)(&data))[i]; + } + FLASHINFER_INLINE __hip_bfloat16* ptr() { return reinterpret_cast<__hip_bfloat16*>(&data); } + FLASHINFER_INLINE void fill(__hip_bfloat16 val); + FLASHINFER_INLINE void load(const __hip_bfloat16* ptr); + FLASHINFER_INLINE void 
store(__hip_bfloat16* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + FLASHINFER_INLINE static void memcpy(__hip_bfloat16* dst, const __hip_bfloat16* src); }; -FLASHINFER_INLINE void vec_t<__hip_bfloat16, 1>::fill(__hip_bfloat16 val) -{ - data = val; -} +FLASHINFER_INLINE void vec_t<__hip_bfloat16, 1>::fill(__hip_bfloat16 val) { data = val; } -FLASHINFER_INLINE void vec_t<__hip_bfloat16, 1>::load(const __hip_bfloat16 *ptr) -{ - data = *ptr; -} +FLASHINFER_INLINE void vec_t<__hip_bfloat16, 1>::load(const __hip_bfloat16* ptr) { data = *ptr; } -FLASHINFER_INLINE void -vec_t<__hip_bfloat16, 1>::store(__hip_bfloat16 *ptr) const -{ - *ptr = data; -} +FLASHINFER_INLINE void vec_t<__hip_bfloat16, 1>::store(__hip_bfloat16* ptr) const { *ptr = data; } -FLASHINFER_INLINE void -vec_t<__hip_bfloat16, 1>::memcpy(__hip_bfloat16 *dst, const __hip_bfloat16 *src) -{ - *dst = *src; +FLASHINFER_INLINE void vec_t<__hip_bfloat16, 1>::memcpy(__hip_bfloat16* dst, + const __hip_bfloat16* src) { + *dst = *src; } // __hip_bfloat16 x 2 -template <> struct vec_t<__hip_bfloat16, 2> -{ - __hip_bfloat162 data; - - FLASHINFER_INLINE __hip_bfloat16 &operator[](size_t i) - { - return ((__hip_bfloat16 *)(&data))[i]; - } - FLASHINFER_INLINE const __hip_bfloat16 &operator[](size_t i) const - { - return ((const __hip_bfloat16 *)(&data))[i]; - } - FLASHINFER_INLINE __hip_bfloat16 *ptr() - { - return reinterpret_cast<__hip_bfloat16 *>(&data); - } - FLASHINFER_INLINE void fill(__hip_bfloat16 val); - FLASHINFER_INLINE void load(const __hip_bfloat16 *ptr); - FLASHINFER_INLINE void store(__hip_bfloat16 *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) - { - cast_from_impl(*this, src); - } - template FLASHINFER_INLINE void 
cast_load(const T *ptr) - { - cast_load_impl(*this, ptr); - } - template FLASHINFER_INLINE void cast_store(T *ptr) const - { - cast_store_impl(ptr, *this); - } - FLASHINFER_INLINE static void memcpy(__hip_bfloat16 *dst, - const __hip_bfloat16 *src); +template <> +struct vec_t<__hip_bfloat16, 2> { + __hip_bfloat162 data; + + FLASHINFER_INLINE __hip_bfloat16& operator[](size_t i) { return ((__hip_bfloat16*)(&data))[i]; } + FLASHINFER_INLINE const __hip_bfloat16& operator[](size_t i) const { + return ((const __hip_bfloat16*)(&data))[i]; + } + FLASHINFER_INLINE __hip_bfloat16* ptr() { return reinterpret_cast<__hip_bfloat16*>(&data); } + FLASHINFER_INLINE void fill(__hip_bfloat16 val); + FLASHINFER_INLINE void load(const __hip_bfloat16* ptr); + FLASHINFER_INLINE void store(__hip_bfloat16* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + FLASHINFER_INLINE static void memcpy(__hip_bfloat16* dst, const __hip_bfloat16* src); }; -FLASHINFER_INLINE void vec_t<__hip_bfloat16, 2>::fill(__hip_bfloat16 val) -{ - data = make_bfloat162(val, val); +FLASHINFER_INLINE void vec_t<__hip_bfloat16, 2>::fill(__hip_bfloat16 val) { + data = make_bfloat162(val, val); } -FLASHINFER_INLINE void vec_t<__hip_bfloat16, 2>::load(const __hip_bfloat16 *ptr) -{ - data = *((__hip_bfloat162 *)ptr); +FLASHINFER_INLINE void vec_t<__hip_bfloat16, 2>::load(const __hip_bfloat16* ptr) { + data = *((__hip_bfloat162*)ptr); } -FLASHINFER_INLINE void -vec_t<__hip_bfloat16, 2>::store(__hip_bfloat16 *ptr) const -{ - *((__hip_bfloat162 *)ptr) = data; +FLASHINFER_INLINE void vec_t<__hip_bfloat16, 2>::store(__hip_bfloat16* ptr) const { + *((__hip_bfloat162*)ptr) = data; } -FLASHINFER_INLINE void -vec_t<__hip_bfloat16, 2>::memcpy(__hip_bfloat16 *dst, const 
__hip_bfloat16 *src) -{ - *((__hip_bfloat162 *)dst) = *((__hip_bfloat162 *)src); +FLASHINFER_INLINE void vec_t<__hip_bfloat16, 2>::memcpy(__hip_bfloat16* dst, + const __hip_bfloat16* src) { + *((__hip_bfloat162*)dst) = *((__hip_bfloat162*)src); } // __hip_bfloat16 x 4 -template <> struct vec_t<__hip_bfloat16, 4> -{ - uint2 data; - - FLASHINFER_INLINE __hip_bfloat16 &operator[](size_t i) - { - return ((__hip_bfloat16 *)(&data))[i]; - } - FLASHINFER_INLINE const __hip_bfloat16 &operator[](size_t i) const - { - return ((const __hip_bfloat16 *)(&data))[i]; - } - FLASHINFER_INLINE __hip_bfloat16 *ptr() - { - return reinterpret_cast<__hip_bfloat16 *>(&data); - } - FLASHINFER_INLINE void fill(__hip_bfloat16 val); - FLASHINFER_INLINE void load(const __hip_bfloat16 *ptr); - FLASHINFER_INLINE void store(__hip_bfloat16 *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) - { - cast_from_impl(*this, src); - } - template FLASHINFER_INLINE void cast_load(const T *ptr) - { - cast_load_impl(*this, ptr); - } - template FLASHINFER_INLINE void cast_store(T *ptr) const - { - cast_store_impl(ptr, *this); - } - FLASHINFER_INLINE static void memcpy(__hip_bfloat16 *dst, - const __hip_bfloat16 *src); +template <> +struct vec_t<__hip_bfloat16, 4> { + uint2 data; + + FLASHINFER_INLINE __hip_bfloat16& operator[](size_t i) { return ((__hip_bfloat16*)(&data))[i]; } + FLASHINFER_INLINE const __hip_bfloat16& operator[](size_t i) const { + return ((const __hip_bfloat16*)(&data))[i]; + } + FLASHINFER_INLINE __hip_bfloat16* ptr() { return reinterpret_cast<__hip_bfloat16*>(&data); } + FLASHINFER_INLINE void fill(__hip_bfloat16 val); + FLASHINFER_INLINE void load(const __hip_bfloat16* ptr); + FLASHINFER_INLINE void store(__hip_bfloat16* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void 
cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + FLASHINFER_INLINE static void memcpy(__hip_bfloat16* dst, const __hip_bfloat16* src); }; -FLASHINFER_INLINE void vec_t<__hip_bfloat16, 4>::fill(__hip_bfloat16 val) -{ - *(__hip_bfloat162 *)(&data.x) = make_bfloat162(val, val); - *(__hip_bfloat162 *)(&data.y) = make_bfloat162(val, val); +FLASHINFER_INLINE void vec_t<__hip_bfloat16, 4>::fill(__hip_bfloat16 val) { + *(__hip_bfloat162*)(&data.x) = make_bfloat162(val, val); + *(__hip_bfloat162*)(&data.y) = make_bfloat162(val, val); } -FLASHINFER_INLINE void vec_t<__hip_bfloat16, 4>::load(const __hip_bfloat16 *ptr) -{ - data = *((uint2 *)ptr); +FLASHINFER_INLINE void vec_t<__hip_bfloat16, 4>::load(const __hip_bfloat16* ptr) { + data = *((uint2*)ptr); } -FLASHINFER_INLINE void -vec_t<__hip_bfloat16, 4>::store(__hip_bfloat16 *ptr) const -{ - *((uint2 *)ptr) = data; +FLASHINFER_INLINE void vec_t<__hip_bfloat16, 4>::store(__hip_bfloat16* ptr) const { + *((uint2*)ptr) = data; } -FLASHINFER_INLINE void -vec_t<__hip_bfloat16, 4>::memcpy(__hip_bfloat16 *dst, const __hip_bfloat16 *src) -{ - *((uint2 *)dst) = *((uint2 *)src); +FLASHINFER_INLINE void vec_t<__hip_bfloat16, 4>::memcpy(__hip_bfloat16* dst, + const __hip_bfloat16* src) { + *((uint2*)dst) = *((uint2*)src); } // __hip_bfloat16 x 8 or more -template struct vec_t<__hip_bfloat16, vec_size> -{ - uint4 data[vec_size / 8]; +template +struct vec_t<__hip_bfloat16, vec_size> { + uint4 data[vec_size / 8]; - FLASHINFER_INLINE __hip_bfloat16 &operator[](size_t i) - { - return ((__hip_bfloat16 *)data)[i]; - } - FLASHINFER_INLINE const __hip_bfloat16 &operator[](size_t i) const - { - return ((const __hip_bfloat16 *)data)[i]; - } - FLASHINFER_INLINE __hip_bfloat16 *ptr() - { - return reinterpret_cast<__hip_bfloat16 *>(&data); - } - FLASHINFER_INLINE void fill(__hip_bfloat16 val) - { + FLASHINFER_INLINE __hip_bfloat16& operator[](size_t i) { return ((__hip_bfloat16*)data)[i]; } + FLASHINFER_INLINE const __hip_bfloat16& 
operator[](size_t i) const { + return ((const __hip_bfloat16*)data)[i]; + } + FLASHINFER_INLINE __hip_bfloat16* ptr() { return reinterpret_cast<__hip_bfloat16*>(&data); } + FLASHINFER_INLINE void fill(__hip_bfloat16 val) { #pragma unoll - for (size_t i = 0; i < vec_size / 8; ++i) { - *(__hip_bfloat162 *)(&(data[i].x)) = make_bfloat162(val, val); - *(__hip_bfloat162 *)(&(data[i].y)) = make_bfloat162(val, val); - *(__hip_bfloat162 *)(&(data[i].z)) = make_bfloat162(val, val); - *(__hip_bfloat162 *)(&(data[i].w)) = make_bfloat162(val, val); - } - } - FLASHINFER_INLINE void load(const __hip_bfloat16 *ptr) - { + for (size_t i = 0; i < vec_size / 8; ++i) { + *(__hip_bfloat162*)(&(data[i].x)) = make_bfloat162(val, val); + *(__hip_bfloat162*)(&(data[i].y)) = make_bfloat162(val, val); + *(__hip_bfloat162*)(&(data[i].z)) = make_bfloat162(val, val); + *(__hip_bfloat162*)(&(data[i].w)) = make_bfloat162(val, val); + } + } + FLASHINFER_INLINE void load(const __hip_bfloat16* ptr) { #pragma unoll - for (size_t i = 0; i < vec_size / 8; ++i) { - data[i] = ((uint4 *)ptr)[i]; - } + for (size_t i = 0; i < vec_size / 8; ++i) { + data[i] = ((uint4*)ptr)[i]; } - FLASHINFER_INLINE void store(__hip_bfloat16 *ptr) const - { + } + FLASHINFER_INLINE void store(__hip_bfloat16* ptr) const { #pragma unoll - for (size_t i = 0; i < vec_size / 8; ++i) { - ((uint4 *)ptr)[i] = data[i]; - } - } - template - FLASHINFER_INLINE void cast_from(const vec_t &src) - { - cast_from_impl(*this, src); - } - template FLASHINFER_INLINE void cast_load(const T *ptr) - { - cast_load_impl(*this, ptr); - } - template FLASHINFER_INLINE void cast_store(T *ptr) const - { - cast_store_impl(ptr, *this); - } - FLASHINFER_INLINE static void memcpy(__hip_bfloat16 *dst, - const __hip_bfloat16 *src) - { + for (size_t i = 0; i < vec_size / 8; ++i) { + ((uint4*)ptr)[i] = data[i]; + } + } + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void 
cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + FLASHINFER_INLINE static void memcpy(__hip_bfloat16* dst, const __hip_bfloat16* src) { #pragma unoll - for (size_t i = 0; i < vec_size / 8; ++i) { - ((uint4 *)dst)[i] = ((uint4 *)src)[i]; - } + for (size_t i = 0; i < vec_size / 8; ++i) { + ((uint4*)dst)[i] = ((uint4*)src)[i]; } + } }; /******************* vec_t *******************/ // float x 1 -template <> struct vec_t -{ - float data; - - FLASHINFER_INLINE float &operator[](size_t i) - { - return ((float *)(&data))[i]; - } - FLASHINFER_INLINE const float &operator[](size_t i) const - { - return ((const float *)(&data))[i]; - } - FLASHINFER_INLINE float *ptr() { return reinterpret_cast(&data); } - FLASHINFER_INLINE void fill(float val); - FLASHINFER_INLINE void load(const float *ptr); - FLASHINFER_INLINE void store(float *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) - { - cast_from_impl(*this, src); - } - template FLASHINFER_INLINE void cast_load(const T *ptr) - { - cast_load_impl(*this, ptr); - } - template FLASHINFER_INLINE void cast_store(T *ptr) const - { - cast_store_impl(ptr, *this); - } - FLASHINFER_INLINE static void memcpy(float *dst, const float *src); +template <> +struct vec_t { + float data; + + FLASHINFER_INLINE float& operator[](size_t i) { return ((float*)(&data))[i]; } + FLASHINFER_INLINE const float& operator[](size_t i) const { return ((const float*)(&data))[i]; } + FLASHINFER_INLINE float* ptr() { return reinterpret_cast(&data); } + FLASHINFER_INLINE void fill(float val); + FLASHINFER_INLINE void load(const float* ptr); + FLASHINFER_INLINE void store(float* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void 
cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + FLASHINFER_INLINE static void memcpy(float* dst, const float* src); }; FLASHINFER_INLINE void vec_t::fill(float val) { data = val; } -FLASHINFER_INLINE void vec_t::load(const float *ptr) { data = *ptr; } +FLASHINFER_INLINE void vec_t::load(const float* ptr) { data = *ptr; } -FLASHINFER_INLINE void vec_t::store(float *ptr) const { *ptr = data; } +FLASHINFER_INLINE void vec_t::store(float* ptr) const { *ptr = data; } -FLASHINFER_INLINE void vec_t::memcpy(float *dst, const float *src) -{ - *dst = *src; -} +FLASHINFER_INLINE void vec_t::memcpy(float* dst, const float* src) { *dst = *src; } // float x 2 -template <> struct vec_t -{ - float2 data; - - FLASHINFER_INLINE float &operator[](size_t i) - { - return ((float *)(&data))[i]; - } - FLASHINFER_INLINE const float &operator[](size_t i) const - { - return ((const float *)(&data))[i]; - } - FLASHINFER_INLINE float *ptr() { return reinterpret_cast(&data); } - FLASHINFER_INLINE void fill(float val); - FLASHINFER_INLINE void load(const float *ptr); - FLASHINFER_INLINE void store(float *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) - { - cast_from_impl(*this, src); - } - template FLASHINFER_INLINE void cast_load(const T *ptr) - { - cast_load_impl(*this, ptr); - } - template FLASHINFER_INLINE void cast_store(T *ptr) const - { - cast_store_impl(ptr, *this); - } - FLASHINFER_INLINE static void memcpy(float *dst, const float *src); +template <> +struct vec_t { + float2 data; + + FLASHINFER_INLINE float& operator[](size_t i) { return ((float*)(&data))[i]; } + FLASHINFER_INLINE const float& operator[](size_t i) const { return ((const float*)(&data))[i]; } + FLASHINFER_INLINE float* ptr() { return reinterpret_cast(&data); } + FLASHINFER_INLINE void fill(float val); + FLASHINFER_INLINE void load(const float* ptr); + FLASHINFER_INLINE void store(float* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + 
cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + FLASHINFER_INLINE static void memcpy(float* dst, const float* src); }; -FLASHINFER_INLINE void vec_t::fill(float val) -{ - data = make_float2(val, val); -} +FLASHINFER_INLINE void vec_t::fill(float val) { data = make_float2(val, val); } -FLASHINFER_INLINE void vec_t::load(const float *ptr) -{ - data = *((float2 *)ptr); -} +FLASHINFER_INLINE void vec_t::load(const float* ptr) { data = *((float2*)ptr); } -FLASHINFER_INLINE void vec_t::store(float *ptr) const -{ - *((float2 *)ptr) = data; -} +FLASHINFER_INLINE void vec_t::store(float* ptr) const { *((float2*)ptr) = data; } -FLASHINFER_INLINE void vec_t::memcpy(float *dst, const float *src) -{ - *((float2 *)dst) = *((float2 *)src); +FLASHINFER_INLINE void vec_t::memcpy(float* dst, const float* src) { + *((float2*)dst) = *((float2*)src); } // float x 4 or more -template struct vec_t -{ - float4 data[vec_size / 4]; - - FLASHINFER_INLINE float &operator[](size_t i) - { - return ((float *)(data))[i]; - } - FLASHINFER_INLINE const float &operator[](size_t i) const - { - return ((const float *)(data))[i]; - } - FLASHINFER_INLINE float *ptr() { return reinterpret_cast(&data); } - FLASHINFER_INLINE void fill(float val) - { +template +struct vec_t { + float4 data[vec_size / 4]; + + FLASHINFER_INLINE float& operator[](size_t i) { return ((float*)(data))[i]; } + FLASHINFER_INLINE const float& operator[](size_t i) const { return ((const float*)(data))[i]; } + FLASHINFER_INLINE float* ptr() { return reinterpret_cast(&data); } + FLASHINFER_INLINE void fill(float val) { #pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - data[i] = make_float4(val, val, val, val); - } + for (size_t i = 0; i < vec_size / 4; ++i) { + data[i] = make_float4(val, val, val, val); } - FLASHINFER_INLINE void load(const float 
*ptr) - { + } + FLASHINFER_INLINE void load(const float* ptr) { #pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - data[i] = ((float4 *)ptr)[i]; - } + for (size_t i = 0; i < vec_size / 4; ++i) { + data[i] = ((float4*)ptr)[i]; } - FLASHINFER_INLINE void store(float *ptr) const - { + } + FLASHINFER_INLINE void store(float* ptr) const { #pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - ((float4 *)ptr)[i] = data[i]; - } - } - template - FLASHINFER_INLINE void cast_from(const vec_t &src) - { - cast_from_impl(*this, src); - } - template FLASHINFER_INLINE void cast_load(const T *ptr) - { - cast_load_impl(*this, ptr); - } - template FLASHINFER_INLINE void cast_store(T *ptr) const - { - cast_store_impl(ptr, *this); - } - FLASHINFER_INLINE static void memcpy(float *dst, const float *src) - { + for (size_t i = 0; i < vec_size / 4; ++i) { + ((float4*)ptr)[i] = data[i]; + } + } + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + FLASHINFER_INLINE static void memcpy(float* dst, const float* src) { #pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - ((float4 *)dst)[i] = ((float4 *)src)[i]; - } + for (size_t i = 0; i < vec_size / 4; ++i) { + ((float4*)dst)[i] = ((float4*)src)[i]; } + } }; -} // namespace flashinfer +} // namespace flashinfer diff --git a/libflashinfer/include/gpu_iface/backend/hip/math_hip.h b/libflashinfer/include/gpu_iface/backend/hip/math_hip.h index 103430082e..a65fcdfaa0 100644 --- a/libflashinfer/include/gpu_iface/backend/hip/math_hip.h +++ b/libflashinfer/include/gpu_iface/backend/hip/math_hip.h @@ -9,15 +9,13 @@ #define HIP_ENABLE_WARP_SYNC_BUILTINS 1 +#include #include #include -#include - #include -namespace flashinfer::math -{ +namespace flashinfer::math { // log2(e) constexpr 
float log2e = 1.44269504088896340736f; @@ -26,33 +24,34 @@ constexpr float loge2 = 0.693147180559945309417f; constexpr float inf = 5e4; -template __forceinline__ __device__ T ptx_exp2(T x); +template +__forceinline__ __device__ T ptx_exp2(T x); /// @brief Wrapper for computing 2 ^ x. We currently do not support a direct /// equivalent of __exp2f() /// @param x Input power to exponentiate /// @return Computes 2 ^ x -template <> __forceinline__ __device__ float ptx_exp2(float x) -{ - return __exp10f(x * __log10f(2.0f)); // Writing 2^x = 10 ^ (x * log_10(2)) +template <> +__forceinline__ __device__ float ptx_exp2(float x) { + return __exp10f(x * __log10f(2.0f)); // Writing 2^x = 10 ^ (x * log_10(2)) } /// @brief Wrapper for computing 2 ^ x. We currently do not support a direct /// equivalent of __exp2f() /// @param x Input power to exponentiate /// @return Computes 2 ^ x -template <> __forceinline__ __device__ __half ptx_exp2<__half>(__half x) -{ - return hexp2(x); +template <> +__forceinline__ __device__ __half ptx_exp2<__half>(__half x) { + return hexp2(x); } /// @brief Wrapper for computing 2 ^ x. We currently do not support a direct /// equivalent of __exp2f() /// @param x Vector of two half dtypes to exponentiate /// @return Computes 2 ^ x -template <> __forceinline__ __device__ __half2 ptx_exp2<__half2>(__half2 x) -{ - return half2(ptx_exp2(x.x), ptx_exp2(x.y)); +template <> +__forceinline__ __device__ __half2 ptx_exp2<__half2>(__half2 x) { + return half2(ptx_exp2(x.x), ptx_exp2(x.y)); } /// @brief Compute log2 @@ -66,16 +65,15 @@ __forceinline__ __device__ float ptx_log2(float x) { return __log2f(x); } __forceinline__ __device__ float ptx_rcp(float x) { return __frcp_rn(x); } template -__forceinline__ __device__ T shfl_xor_sync(T x, int lane_mask) -{ - // FIXME (diptorupd): The shfl_xor_sync is used to implement a butterfly - // reduction pattern. 
The caller in decode.cuh most likely assumes that the - // warp size is 32 and the lane_mask is going from 16, 8, 4, 2, 1. - // Given that AMDGPU for CDNA3 has a warp size of 64, the lane_mask based on - // the warp size of 32 might lead to incorrect exchanges between the - // threads. The issue requires further investigation, for now I have hard - // coded the warp size to 32 when calling shfl_xor. - return __shfl_xor(x, lane_mask, 32); +__forceinline__ __device__ T shfl_xor_sync(T x, int lane_mask) { + // FIXME (diptorupd): The shfl_xor_sync is used to implement a butterfly + // reduction pattern. The caller in decode.cuh most likely assumes that the + // warp size is 32 and the lane_mask is going from 16, 8, 4, 2, 1. + // Given that AMDGPU for CDNA3 has a warp size of 64, the lane_mask based on + // the warp size of 32 might lead to incorrect exchanges between the + // threads. The issue requires further investigation, for now I have hard + // coded the warp size to 32 when calling shfl_xor. 
+ return __shfl_xor(x, lane_mask, 32); } /// @brief Wrapper for math intrinsic 1/sqrt(x) @@ -83,32 +81,33 @@ __forceinline__ __device__ T shfl_xor_sync(T x, int lane_mask) /// @return Returns 1 / sqrt(x) in round to nearest even mode __forceinline__ __device__ float rsqrt(float x) { return __frsqrt_rn(x); } -template __forceinline__ __device__ T tanh(T x); +template +__forceinline__ __device__ T tanh(T x); /// @brief Compute tanhf(x) /// @param x Input param - float dtype /// @return Returns tanhf(x) /// @note ROCm6.3 does not have a fast tanh or instrincs to support this -template <> __forceinline__ __device__ float tanh(float x) -{ - return tanhf(x); +template <> +__forceinline__ __device__ float tanh(float x) { + return tanhf(x); } /// @brief A utility function to compute tanh for half dtype /// @param x Input param - half /// @return Hyperbolic tangent of x -template <> __forceinline__ __device__ __half tanh<__half>(__half x) -{ - return __float2half(tanh(__half2float(x))); +template <> +__forceinline__ __device__ __half tanh<__half>(__half x) { + return __float2half(tanh(__half2float(x))); } /// @brief Compute hyperbolic tangent for a vector of two half dtype /// @param x Vector of two half dtypes /// @return Hyperbolic tangent of x -template <> __forceinline__ __device__ __half2 tanh<__half2>(__half2 x) -{ - return __half2(tanh(x.x), tanh(x.y)); +template <> +__forceinline__ __device__ __half2 tanh<__half2>(__half2 x) { + return __half2(tanh(x.x), tanh(x.y)); } -} // namespace flashinfer::math -#endif // FLASHINFER_MATH_CUH_ +} // namespace flashinfer::math +#endif // FLASHINFER_MATH_CUH_ diff --git a/libflashinfer/include/gpu_iface/backend/hip/memory_ops_hip.h b/libflashinfer/include/gpu_iface/backend/hip/memory_ops_hip.h index f787309954..36bce4a0bc 100644 --- a/libflashinfer/include/gpu_iface/backend/hip/memory_ops_hip.h +++ b/libflashinfer/include/gpu_iface/backend/hip/memory_ops_hip.h @@ -3,111 +3,82 @@ #include #include -namespace flashinfer -{ 
-namespace gpu_iface -{ -namespace memory -{ -namespace detail -{ -namespace hip -{ +namespace flashinfer { +namespace gpu_iface { +namespace memory { +namespace detail { +namespace hip { -__device__ __forceinline__ void commit_group() -{ - // Currently a no-op for HIP +__device__ __forceinline__ void commit_group() { + // Currently a no-op for HIP } -template __device__ __forceinline__ void wait_group() -{ - // Currently a no-op for HIP +template +__device__ __forceinline__ void wait_group() { + // Currently a no-op for HIP } /// @brief loads 128 bits from global to shared memory template -__device__ __forceinline__ void load_128b(T *smem_ptr, const T *gmem_ptr) -{ - *reinterpret_cast(smem_ptr) = - *reinterpret_cast(gmem_ptr); +__device__ __forceinline__ void load_128b(T* smem_ptr, const T* gmem_ptr) { + *reinterpret_cast(smem_ptr) = *reinterpret_cast(gmem_ptr); } template -__device__ __forceinline__ void load_64b(T *smem_ptr, const T *gmem_ptr) -{ - *reinterpret_cast(smem_ptr) = - *reinterpret_cast(gmem_ptr); +__device__ __forceinline__ void load_64b(T* smem_ptr, const T* gmem_ptr) { + *reinterpret_cast(smem_ptr) = *reinterpret_cast(gmem_ptr); } // Predicated 128-bit load template -__device__ __forceinline__ void -pred_load_128b(T *smem_ptr, const T *gmem_ptr, bool predicate) -{ - if (predicate) { - *reinterpret_cast(smem_ptr) = - *reinterpret_cast(gmem_ptr); - } - else { - if constexpr (FillOpt == SharedMemFillMode::kFillZero) { - *reinterpret_cast(smem_ptr) = make_uint4(0, 0, 0, 0); - } +__device__ __forceinline__ void pred_load_128b(T* smem_ptr, const T* gmem_ptr, bool predicate) { + if (predicate) { + *reinterpret_cast(smem_ptr) = *reinterpret_cast(gmem_ptr); + } else { + if constexpr (FillOpt == SharedMemFillMode::kFillZero) { + *reinterpret_cast(smem_ptr) = make_uint4(0, 0, 0, 0); } + } } template -__device__ __forceinline__ void -pred_load_64b(T *smem_ptr, const T *gmem_ptr, bool predicate) -{ - if (predicate) { - *reinterpret_cast(smem_ptr) = - 
*reinterpret_cast(gmem_ptr); - } - else { - if constexpr (FillOpt == SharedMemFillMode::kFillZero) { - *reinterpret_cast(smem_ptr) = make_uint2(0, 0); - } +__device__ __forceinline__ void pred_load_64b(T* smem_ptr, const T* gmem_ptr, bool predicate) { + if (predicate) { + *reinterpret_cast(smem_ptr) = *reinterpret_cast(gmem_ptr); + } else { + if constexpr (FillOpt == SharedMemFillMode::kFillZero) { + *reinterpret_cast(smem_ptr) = make_uint2(0, 0); } + } } // Generic load with NumBits template parameter template -__device__ __forceinline__ void load(T *smem_ptr, const T *gmem_ptr) -{ - static_assert(NumBits == 128 || NumBits == 256, - "NumBits must be 128 or 256"); - if constexpr (NumBits == 128) { - load_128b(smem_ptr, gmem_ptr); - } - else { - load_128b(smem_ptr, gmem_ptr); - load_128b(smem_ptr + 16 / sizeof(T), - gmem_ptr + 16 / sizeof(T)); - } +__device__ __forceinline__ void load(T* smem_ptr, const T* gmem_ptr) { + static_assert(NumBits == 128 || NumBits == 256, "NumBits must be 128 or 256"); + if constexpr (NumBits == 128) { + load_128b(smem_ptr, gmem_ptr); + } else { + load_128b(smem_ptr, gmem_ptr); + load_128b(smem_ptr + 16 / sizeof(T), gmem_ptr + 16 / sizeof(T)); + } } // Generic predicated load with NumBits template parameter -template -__device__ __forceinline__ void -pred_load(T *smem_ptr, const T *gmem_ptr, bool predicate) -{ - static_assert(NumBits == 128 || NumBits == 256, - "NumBits must be 128 or 256"); - if constexpr (NumBits == 128) { - pred_load_128b(smem_ptr, gmem_ptr, predicate); - } - else { - pred_load_128b(smem_ptr, gmem_ptr, predicate); - pred_load_128b( - smem_ptr + 16 / sizeof(T), gmem_ptr + 16 / sizeof(T), predicate); - } +template +__device__ __forceinline__ void pred_load(T* smem_ptr, const T* gmem_ptr, bool predicate) { + static_assert(NumBits == 128 || NumBits == 256, "NumBits must be 128 or 256"); + if constexpr (NumBits == 128) { + pred_load_128b(smem_ptr, gmem_ptr, predicate); + } else { + pred_load_128b(smem_ptr, gmem_ptr, 
predicate); + pred_load_128b(smem_ptr + 16 / sizeof(T), gmem_ptr + 16 / sizeof(T), + predicate); + } } -} // namespace hip -} // namespace detail -} // namespace memory -} // namespace gpu_iface -} // namespace flashinfer +} // namespace hip +} // namespace detail +} // namespace memory +} // namespace gpu_iface +} // namespace flashinfer diff --git a/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h b/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h index 320a2bf818..52229f70fe 100644 --- a/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h +++ b/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h @@ -7,165 +7,142 @@ #include "gpu_iface/mma_types.hpp" #include "gpu_iface/platform.hpp" -namespace -{ +namespace { using f16 = _Float16; using f16x4 = f16 __attribute__((ext_vector_type(4))); using f32x4 = float __attribute__((ext_vector_type(4))); template -__device__ __forceinline__ f32x4 mfma_fp32_16x16x16fp16(f32x4 C, - const f16x4 A, - const f16x4 B) -{ - if constexpr (std::is_same_v) { - return __builtin_amdgcn_mfma_f32_16x16x16f16(A, B, C, 0, 0, 0); - } - else if constexpr (std::is_same_v) { - return __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(A, B, C, 0, 0, 0); - } - return C; +__device__ __forceinline__ f32x4 mfma_fp32_16x16x16fp16(f32x4 C, const f16x4 A, const f16x4 B) { + if constexpr (std::is_same_v) { + return __builtin_amdgcn_mfma_f32_16x16x16f16(A, B, C, 0, 0, 0); + } else if constexpr (std::is_same_v) { + return __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(A, B, C, 0, 0, 0); + } + return C; } -} // namespace +} // namespace -namespace flashinfer -{ -namespace gpu_iface -{ -namespace mma_impl -{ -namespace hip -{ +namespace flashinfer { +namespace gpu_iface { +namespace mma_impl { +namespace hip { #define FLASHINFER_RUNTIME_ASSERT(x) assert(0 && x) -__device__ __forceinline__ void transpose_4x4_half_registers(uint32_t *R) -{ - // Calculate lane within 4-thread group - uint32_t lane_id = threadIdx.x % 64; - uint32_t lane_in_group = lane_id % 4; - 
- // === ROUND 1: Exchange with neighbor (XOR with 1) === - // T0↔T1, T2↔T3 partial exchange - uint32_t reg_idx = (lane_in_group >> 1) & 0x1; - uint32_t exchanged_val = __shfl_xor(R[reg_idx], 0x1); - uint32_t shift = (lane_in_group & 1) * 16; - uint32_t keep_mask = 0xFFFF0000 >> shift; - int right_shift_amount = 16 * (1 - (lane_in_group & 1)); - int left_shift_amount = 16 * (lane_in_group & 1); - R[reg_idx] = (R[reg_idx] & keep_mask) | - ((exchanged_val >> right_shift_amount) << left_shift_amount); - - // === ROUND 2: Exchange with one hop (XOR with 2) === - // T0↔T2, T1↔T3 exchange R[0] and R[1] - // Swap entire registers based on thread position - uint32_t is_top = 1 - reg_idx; - uint32_t temp0 = __shfl_xor(R[0], 0x2); - uint32_t temp1 = __shfl_xor(R[1], 0x2); - - // Compute both possibilities and select - R[0] = R[0] * is_top + temp1 * reg_idx; - R[1] = temp0 * is_top + R[1] * reg_idx; - - // === ROUND 3: Exchange with neighbor again (XOR with 1) === - // T0↔T1, T2↔T3 exchange remaining parts - - reg_idx = 1 - reg_idx; - exchanged_val = __shfl_xor(R[reg_idx], 0x1); - R[reg_idx] = (R[reg_idx] & keep_mask) | - ((exchanged_val >> right_shift_amount) << left_shift_amount); +__device__ __forceinline__ void transpose_4x4_half_registers(uint32_t* R) { + // Calculate lane within 4-thread group + uint32_t lane_id = threadIdx.x % 64; + uint32_t lane_in_group = lane_id % 4; + + // === ROUND 1: Exchange with neighbor (XOR with 1) === + // T0↔T1, T2↔T3 partial exchange + uint32_t reg_idx = (lane_in_group >> 1) & 0x1; + uint32_t exchanged_val = __shfl_xor(R[reg_idx], 0x1); + uint32_t shift = (lane_in_group & 1) * 16; + uint32_t keep_mask = 0xFFFF0000 >> shift; + int right_shift_amount = 16 * (1 - (lane_in_group & 1)); + int left_shift_amount = 16 * (lane_in_group & 1); + R[reg_idx] = + (R[reg_idx] & keep_mask) | ((exchanged_val >> right_shift_amount) << left_shift_amount); + + // === ROUND 2: Exchange with one hop (XOR with 2) === + // T0↔T2, T1↔T3 exchange R[0] and R[1] + // 
Swap entire registers based on thread position + uint32_t is_top = 1 - reg_idx; + uint32_t temp0 = __shfl_xor(R[0], 0x2); + uint32_t temp1 = __shfl_xor(R[1], 0x2); + + // Compute both possibilities and select + R[0] = R[0] * is_top + temp1 * reg_idx; + R[1] = temp0 * is_top + R[1] * reg_idx; + + // === ROUND 3: Exchange with neighbor again (XOR with 1) === + // T0↔T1, T2↔T3 exchange remaining parts + + reg_idx = 1 - reg_idx; + exchanged_val = __shfl_xor(R[reg_idx], 0x1); + R[reg_idx] = + (R[reg_idx] & keep_mask) | ((exchanged_val >> right_shift_amount) << left_shift_amount); } // Single unified load function for all fragment types /// @param R [in] pointer to the register file to load the fragment into /// @param smem_ptr [in] pointer to the shared memory to load the fragment from template -__device__ __forceinline__ void load_fragment(uint32_t *R, const T *smem_ptr) -{ - const uint16_t *v0 = reinterpret_cast(smem_ptr) + 0; - const uint16_t *v1 = reinterpret_cast(++smem_ptr); - const uint16_t *v2 = reinterpret_cast(++smem_ptr); - const uint16_t *v3 = reinterpret_cast(++smem_ptr); - - R[0] = (static_cast(*v0) << 16) | - static_cast(*v1); - R[1] = (static_cast(*v2) << 16) | - static_cast(*v3); +__device__ __forceinline__ void load_fragment(uint32_t* R, const T* smem_ptr) { + const uint16_t* v0 = reinterpret_cast(smem_ptr) + 0; + const uint16_t* v1 = reinterpret_cast(++smem_ptr); + const uint16_t* v2 = reinterpret_cast(++smem_ptr); + const uint16_t* v3 = reinterpret_cast(++smem_ptr); + + R[0] = (static_cast(*v0) << 16) | static_cast(*v1); + R[1] = (static_cast(*v2) << 16) | static_cast(*v3); } template -__device__ __forceinline__ void -load_fragment_transpose(uint32_t *R, const T *smem_ptr, uint32_t stride) -{ - const uint16_t *v0 = reinterpret_cast(smem_ptr) + 0; - const uint16_t *v1 = - reinterpret_cast(smem_ptr + 1 * stride); - const uint16_t *v2 = - reinterpret_cast(smem_ptr + 2 * stride); - const uint16_t *v3 = - reinterpret_cast(smem_ptr + 3 * stride); - - R[0] 
= (static_cast(*v0) << 16) | - static_cast(*v1); - R[1] = (static_cast(*v2) << 16) | - static_cast(*v3); +__device__ __forceinline__ void load_fragment_transpose(uint32_t* R, const T* smem_ptr, + uint32_t stride) { + const uint16_t* v0 = reinterpret_cast(smem_ptr) + 0; + const uint16_t* v1 = reinterpret_cast(smem_ptr + 1 * stride); + const uint16_t* v2 = reinterpret_cast(smem_ptr + 2 * stride); + const uint16_t* v3 = reinterpret_cast(smem_ptr + 3 * stride); + + R[0] = (static_cast(*v0) << 16) | static_cast(*v1); + R[1] = (static_cast(*v2) << 16) | static_cast(*v3); } // MMA operation for FP16 inputs with FP32 accumulator template -__device__ __forceinline__ void -mma_sync_m16n16k16_row_col_f16f16f32(float *C, uint32_t *A, uint32_t *B) -{ - // Ensure T is either __half or __hip_bfloat16 - static_assert(std::is_same_v || - std::is_same_v, - "T must be __half or __hip_bfloat16"); - - // Initialize C if requested - if constexpr (mma_mode == mma::MMAMode::kInit) { - C[0] = 0.0f; - C[1] = 0.0f; - C[2] = 0.0f; - C[3] = 0.0f; - } - - f16x4 B_fp16 = reinterpret_cast(B)[0]; - f16x4 A_fp16 = reinterpret_cast(A)[0]; - f32x4 C_fp32 = reinterpret_cast(C)[0]; - - // Perform MMA operation directly with fragments - C_fp32 = mfma_fp32_16x16x16fp16(C_fp32, A_fp16, B_fp16); - C[0] = C_fp32[0]; - C[1] = C_fp32[1]; - C[2] = C_fp32[2]; - C[3] = C_fp32[3]; +__device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f32(float* C, uint32_t* A, + uint32_t* B) { + // Ensure T is either __half or __hip_bfloat16 + static_assert(std::is_same_v || std::is_same_v, + "T must be __half or __hip_bfloat16"); + + // Initialize C if requested + if constexpr (mma_mode == mma::MMAMode::kInit) { + C[0] = 0.0f; + C[1] = 0.0f; + C[2] = 0.0f; + C[3] = 0.0f; + } + + f16x4 B_fp16 = reinterpret_cast(B)[0]; + f16x4 A_fp16 = reinterpret_cast(A)[0]; + f32x4 C_fp32 = reinterpret_cast(C)[0]; + + // Perform MMA operation directly with fragments + C_fp32 = mfma_fp32_16x16x16fp16(C_fp32, A_fp16, B_fp16); + C[0] = 
C_fp32[0]; + C[1] = C_fp32[1]; + C[2] = C_fp32[2]; + C[3] = C_fp32[3]; } /// Loads a fragment from LDS to two 32bit registers and then transposes /// the registers for a group of four consecuitive threads. template -__device__ __forceinline__ void -load_fragment_4x4_half_registers(uint32_t *R, const T *smem_ptr) -{ - static_assert(std::is_same_v, "Only half type is supported"); - // Each thread loads 4 __half values in two 32b registers. - load_fragment(R, smem_ptr); - // transposes the values in four adjacent threads. The function does the - // following layout transformation: - // Original data in registers for Threads 0-3 after fragment load - // T0 : a b c d - // T1 : e f g h - // T2 : i j k l - // T3 : m n o p - // - // After transposition: - // T0 : a e i m - // T1 : b f j n - // T2 : c g k o - // T3 : d h l p - - transpose_4x4_half_registers(R); +__device__ __forceinline__ void load_fragment_4x4_half_registers(uint32_t* R, const T* smem_ptr) { + static_assert(std::is_same_v, "Only half type is supported"); + // Each thread loads 4 __half values in two 32b registers. + load_fragment(R, smem_ptr); + // transposes the values in four adjacent threads. 
The function does the + // following layout transformation: + // Original data in registers for Threads 0-3 after fragment load + // T0 : a b c d + // T1 : e f g h + // T2 : i j k l + // T3 : m n o p + // + // After transposition: + // T0 : a e i m + // T1 : b f j n + // T2 : c g k o + // T3 : d h l p + + transpose_4x4_half_registers(R); } // TODO: Verify correct matrix multiplication order for rowsum on CDNA3 @@ -179,37 +156,33 @@ load_fragment_4x4_half_registers(uint32_t *R, const T *smem_ptr) // - s_frag layout matches expected Q×K^T result // - rowsum produces correct per-row sums template -__device__ __forceinline__ void m16k16_rowsum_f16f16f32(float *d, DType *s_frag) -{ - static_assert(sizeof(DType) == 2, "DType must be 16-bit type"); - transpose_4x4_half_registers(reinterpret_cast(s_frag)); - f16x4 a = reinterpret_cast(s_frag)[0]; - f16x4 b = {f16(1.0f), f16(1.0f), f16(1.0f), f16(1.0f)}; - f32x4 c = {d[0], d[1], d[2], d[3]}; - f32x4 out = __builtin_amdgcn_mfma_f32_16x16x16f16(a, b, c, 0, 0, 0); - d[0] = out.x; - d[1] = out.y; - d[2] = out.z; - d[3] = out.w; +__device__ __forceinline__ void m16k16_rowsum_f16f16f32(float* d, DType* s_frag) { + static_assert(sizeof(DType) == 2, "DType must be 16-bit type"); + transpose_4x4_half_registers(reinterpret_cast(s_frag)); + f16x4 a = reinterpret_cast(s_frag)[0]; + f16x4 b = {f16(1.0f), f16(1.0f), f16(1.0f), f16(1.0f)}; + f32x4 c = {d[0], d[1], d[2], d[3]}; + f32x4 out = __builtin_amdgcn_mfma_f32_16x16x16f16(a, b, c, 0, 0, 0); + d[0] = out.x; + d[1] = out.y; + d[2] = out.z; + d[3] = out.w; } // TODO (rimaddur) : After release 2025.08 // FP8 operations - not implemented for MI300 yet template -__device__ __forceinline__ void -mma_sync_m16n16k32_row_col_f8f8f32(float *c_frag, T *a_frag, T *b_frag) -{ - FLASHINFER_RUNTIME_ASSERT("FP8 MMA not implemented for AMD"); +__device__ __forceinline__ void mma_sync_m16n16k32_row_col_f8f8f32(float* c_frag, T* a_frag, + T* b_frag) { + FLASHINFER_RUNTIME_ASSERT("FP8 MMA not 
implemented for AMD"); } template -__device__ __forceinline__ void m16k32_rowsum_f8f8f32(float *d_frag, - DType *s_frag) -{ - FLASHINFER_RUNTIME_ASSERT("FP8 rowsum not implemented for AMD"); +__device__ __forceinline__ void m16k32_rowsum_f8f8f32(float* d_frag, DType* s_frag) { + FLASHINFER_RUNTIME_ASSERT("FP8 rowsum not implemented for AMD"); } -} // namespace hip -} // namespace mma_impl -} // namespace gpu_iface -} // namespace flashinfer +} // namespace hip +} // namespace mma_impl +} // namespace gpu_iface +} // namespace flashinfer diff --git a/libflashinfer/include/gpu_iface/backend/hip/vec_dtypes_hip.h b/libflashinfer/include/gpu_iface/backend/hip/vec_dtypes_hip.h index cc903a753f..9fec484196 100644 --- a/libflashinfer/include/gpu_iface/backend/hip/vec_dtypes_hip.h +++ b/libflashinfer/include/gpu_iface/backend/hip/vec_dtypes_hip.h @@ -20,157 +20,133 @@ #define FLASHINFER_INLINE inline __attribute__((always_inline)) __device__ -__host__ __device__ inline __hip_bfloat162 __float2bfloat162_rn(const float a) -{ - return __hip_bfloat162{__float2bfloat16(a), __float2bfloat16(a)}; +__host__ __device__ inline __hip_bfloat162 __float2bfloat162_rn(const float a) { + return __hip_bfloat162{__float2bfloat16(a), __float2bfloat16(a)}; } -FLASHINFER_INLINE __hip_bfloat162 make_bfloat162(const __hip_bfloat16 x, - const __hip_bfloat16 y) -{ - __hip_bfloat162 t; - t.x = x; - t.y = y; - return t; +FLASHINFER_INLINE __hip_bfloat162 make_bfloat162(const __hip_bfloat16 x, const __hip_bfloat16 y) { + __hip_bfloat162 t; + t.x = x; + t.y = y; + return t; } -namespace detail -{ -namespace hip -{ +namespace detail { +namespace hip { #define FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED #define FLASHINFER_INLINE inline __attribute__((always_inline)) __device__ -#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 < 120400) && \ +#if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 < 120400) && \ (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)) // CUDA version < 12.4 
and GPU architecture < 80 -FLASHINFER_INLINE __hip_bfloat16 __hmul(const __hip_bfloat16 a, - const __hip_bfloat16 b) -{ - __hip_bfloat16 val; - const float fa = __bfloat162float(a); - const float fb = __bfloat162float(b); - // avoid ftz in device code - val = __float2bfloat16(__fmaf_ieee_rn(fa, fb, -0.0f)); - return val; +FLASHINFER_INLINE __hip_bfloat16 __hmul(const __hip_bfloat16 a, const __hip_bfloat16 b) { + __hip_bfloat16 val; + const float fa = __bfloat162float(a); + const float fb = __bfloat162float(b); + // avoid ftz in device code + val = __float2bfloat16(__fmaf_ieee_rn(fa, fb, -0.0f)); + return val; } -FLASHINFER_INLINE __hip_bfloat162 __hmul2(const __hip_bfloat162 a, - const __hip_bfloat162 b) -{ - __hip_bfloat162 val; - val.x = __hmul(a.x, b.x); - val.y = __hmul(a.y, b.y); - return val; +FLASHINFER_INLINE __hip_bfloat162 __hmul2(const __hip_bfloat162 a, const __hip_bfloat162 b) { + __hip_bfloat162 val; + val.x = __hmul(a.x, b.x); + val.y = __hmul(a.y, b.y); + return val; } -FLASHINFER_INLINE __hip_bfloat162 __floats2bfloat162_rn(const float a, - const float b) -{ - __hip_bfloat162 val; - val = __hip_bfloat162(__float2bfloat16(a), __float2bfloat16(b)); - return val; +FLASHINFER_INLINE __hip_bfloat162 __floats2bfloat162_rn(const float a, const float b) { + __hip_bfloat162 val; + val = __hip_bfloat162(__float2bfloat16(a), __float2bfloat16(b)); + return val; } -FLASHINFER_INLINE __hip_bfloat162 __float22bfloat162_rn(const float2 a) -{ - __hip_bfloat162 val = __float22bfloat162_rn(a.x, a.y); - return val; +FLASHINFER_INLINE __hip_bfloat162 __float22bfloat162_rn(const float2 a) { + __hip_bfloat162 val = __float22bfloat162_rn(a.x, a.y); + return val; } -FLASHINFER_INLINE float2 __bfloat1622float2(const __hip_bfloat162 a) -{ - float hi_float; - float lo_float; - // lo_float = __internal_bfloat162float(((__gpu_bfloat162_raw)a).x); - // hi_float = __internal_bfloat162float(((__gpu_bfloat162_raw)a).y); - lo_float = __bfloat1622float2(a.x); - hi_float = 
__bfloat1622float2(a.y); - return make_float2(lo_float, hi_float); +FLASHINFER_INLINE float2 __bfloat1622float2(const __hip_bfloat162 a) { + float hi_float; + float lo_float; + // lo_float = __internal_bfloat162float(((__gpu_bfloat162_raw)a).x); + // hi_float = __internal_bfloat162float(((__gpu_bfloat162_raw)a).y); + lo_float = __bfloat1622float2(a.x); + hi_float = __bfloat1622float2(a.y); + return make_float2(lo_float, hi_float); } #endif /******************* vec_t type cast *******************/ -template struct vec_cast -{ - template - FLASHINFER_INLINE static void cast(dst_t *dst, const src_t *src) - { +template +struct vec_cast { + template + FLASHINFER_INLINE static void cast(dst_t* dst, const src_t* src) { #pragma unroll - for (size_t i = 0; i < vec_size; ++i) { - dst[i] = (dst_t)src[i]; - } + for (size_t i = 0; i < vec_size; ++i) { + dst[i] = (dst_t)src[i]; } + } }; -template <> struct vec_cast -{ - template - FLASHINFER_INLINE static void cast(float *dst, const half *src) - { - if constexpr (vec_size == 1) { - // dst[0] = (float)src[0]; - dst[0] = __half2float(src[0]); - } - else { +template <> +struct vec_cast { + template + FLASHINFER_INLINE static void cast(float* dst, const half* src) { + if constexpr (vec_size == 1) { + // dst[0] = (float)src[0]; + dst[0] = __half2float(src[0]); + } else { #pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - ((float2 *)dst)[i] = __half22float2(((half2 *)src)[i]); - } - } + for (size_t i = 0; i < vec_size / 2; ++i) { + ((float2*)dst)[i] = __half22float2(((half2*)src)[i]); + } } + } }; -template <> struct vec_cast -{ - template - FLASHINFER_INLINE static void cast(half *dst, const float *src) - { - if constexpr (vec_size == 1) { - dst[0] = __float2half(src[0]); - } - else { +template <> +struct vec_cast { + template + FLASHINFER_INLINE static void cast(half* dst, const float* src) { + if constexpr (vec_size == 1) { + dst[0] = __float2half(src[0]); + } else { #pragma unroll - for (size_t i = 0; i < vec_size / 2; 
++i) { - ((half2 *)dst)[i] = __float22half2_rn(((float2 *)src)[i]); - } - } + for (size_t i = 0; i < vec_size / 2; ++i) { + ((half2*)dst)[i] = __float22half2_rn(((float2*)src)[i]); + } } + } }; -template constexpr FLASHINFER_INLINE int get_exponent_bits() -{ - if constexpr (std::is_same_v) { - return 4; - } - else if constexpr (std::is_same_v) { - return 5; - } - else if constexpr (std::is_same_v) { - return 5; - } - else if constexpr (std::is_same_v) { - return 8; - } -} - -template constexpr FLASHINFER_INLINE int get_mantissa_bits() -{ - if constexpr (std::is_same_v) { - return 3; - } - else if constexpr (std::is_same_v) { - return 2; - } - else if constexpr (std::is_same_v) { - return 11; - } - else if constexpr (std::is_same_v) { - return 7; - } +template +constexpr FLASHINFER_INLINE int get_exponent_bits() { + if constexpr (std::is_same_v) { + return 4; + } else if constexpr (std::is_same_v) { + return 5; + } else if constexpr (std::is_same_v) { + return 5; + } else if constexpr (std::is_same_v) { + return 8; + } +} + +template +constexpr FLASHINFER_INLINE int get_mantissa_bits() { + if constexpr (std::is_same_v) { + return 3; + } else if constexpr (std::is_same_v) { + return 2; + } else if constexpr (std::is_same_v) { + return 11; + } else if constexpr (std::is_same_v) { + return 7; + } } /*! 
@@ -182,207 +158,180 @@ template constexpr FLASHINFER_INLINE int get_mantissa_bits() * https://github.com/vllm-project/vllm/blob/6dffa4b0a6120159ef2fe44d695a46817aff65bc/csrc/quantization/fp8/fp8_marlin.cu#L120 */ template -__device__ void fast_dequant_f8f16x4(uint32_t *input, uint2 *output) -{ - uint32_t q = *input; - if constexpr (std::is_same_v && - std::is_same_v) - { - output->x = __byte_perm(0U, q, 0x5140); - output->y = __byte_perm(0U, q, 0x7362); - } - else { - constexpr int FP8_EXPONENT = get_exponent_bits(); - constexpr int FP8_MANTISSA = get_mantissa_bits(); - constexpr int FP16_EXPONENT = get_exponent_bits(); - - constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT; - // Calculate MASK for extracting mantissa and exponent - // XXX: duplicate defs of `MASK1` and `MASK2`, - // in the HIP file "include/hip/amd_detail/amd_device_functions.h". - constexpr int MASK1_orig = 0x80000000; - constexpr int MASK2_orig = MASK1_orig >> (FP8_EXPONENT + FP8_MANTISSA); - constexpr int MASK3 = MASK2_orig & 0x7fffffff; - constexpr int MASK = MASK3 | (MASK3 >> 16); - q = __byte_perm(q, q, 0x1302); - - // Extract and shift FP8 values to FP16 format - uint32_t Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); - uint32_t Out2 = - ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT); - - constexpr int BIAS_OFFSET = - (1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); - // Construct and apply exponent bias - if constexpr (std::is_same_v) { - const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET)); - - // Convert to half2 and apply bias - *(half2 *)&(output->x) = - __hmul2(*reinterpret_cast(&Out1), bias_reg); - *(half2 *)&(output->y) = - __hmul2(*reinterpret_cast(&Out2), bias_reg); - } - else { - constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23; - const __hip_bfloat162 bias_reg = - __float2bfloat162_rn(*reinterpret_cast(&BIAS)); - // Convert to bfloat162 and apply bias - *(__hip_bfloat162 *)&(output->x) = __hmul2( - *reinterpret_cast(&Out1), 
bias_reg); - *(__hip_bfloat162 *)&(output->y) = __hmul2( - *reinterpret_cast(&Out2), bias_reg); - } - } -} - -template <> struct vec_cast<__hip_bfloat16, __hip_fp8_e4m3_fnuz> -{ - template - FLASHINFER_INLINE static void cast(__hip_bfloat16 *dst, - const __hip_fp8_e4m3_fnuz *src) - { - if constexpr (vec_size == 1) { - dst[0] = __hip_bfloat16(src[0]); - } - else if constexpr (vec_size == 2) { - dst[0] = __hip_bfloat16(src[0]); - dst[1] = __hip_bfloat16(src[1]); - } - else { - static_assert(vec_size % 4 == 0, - "vec_size must be a multiple of 4"); +__device__ void fast_dequant_f8f16x4(uint32_t* input, uint2* output) { + uint32_t q = *input; + if constexpr (std::is_same_v && + std::is_same_v) { + output->x = __byte_perm(0U, q, 0x5140); + output->y = __byte_perm(0U, q, 0x7362); + } else { + constexpr int FP8_EXPONENT = get_exponent_bits(); + constexpr int FP8_MANTISSA = get_mantissa_bits(); + constexpr int FP16_EXPONENT = get_exponent_bits(); + + constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT; + // Calculate MASK for extracting mantissa and exponent + // XXX: duplicate defs of `MASK1` and `MASK2`, + // in the HIP file "include/hip/amd_detail/amd_device_functions.h". 
+ constexpr int MASK1_orig = 0x80000000; + constexpr int MASK2_orig = MASK1_orig >> (FP8_EXPONENT + FP8_MANTISSA); + constexpr int MASK3 = MASK2_orig & 0x7fffffff; + constexpr int MASK = MASK3 | (MASK3 >> 16); + q = __byte_perm(q, q, 0x1302); + + // Extract and shift FP8 values to FP16 format + uint32_t Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + uint32_t Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT); + + constexpr int BIAS_OFFSET = (1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); + // Construct and apply exponent bias + if constexpr (std::is_same_v) { + const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET)); + + // Convert to half2 and apply bias + *(half2*)&(output->x) = __hmul2(*reinterpret_cast(&Out1), bias_reg); + *(half2*)&(output->y) = __hmul2(*reinterpret_cast(&Out2), bias_reg); + } else { + constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23; + const __hip_bfloat162 bias_reg = __float2bfloat162_rn(*reinterpret_cast(&BIAS)); + // Convert to bfloat162 and apply bias + *(__hip_bfloat162*)&(output->x) = + __hmul2(*reinterpret_cast(&Out1), bias_reg); + *(__hip_bfloat162*)&(output->y) = + __hmul2(*reinterpret_cast(&Out2), bias_reg); + } + } +} + +template <> +struct vec_cast<__hip_bfloat16, __hip_fp8_e4m3_fnuz> { + template + FLASHINFER_INLINE static void cast(__hip_bfloat16* dst, const __hip_fp8_e4m3_fnuz* src) { + if constexpr (vec_size == 1) { + dst[0] = __hip_bfloat16(src[0]); + } else if constexpr (vec_size == 2) { + dst[0] = __hip_bfloat16(src[0]); + dst[1] = __hip_bfloat16(src[1]); + } else { + static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4"); #pragma unroll - for (uint32_t i = 0; i < vec_size / 4; ++i) { - fast_dequant_f8f16x4<__hip_fp8_e4m3_fnuz, __hip_bfloat16>( - (uint32_t *)&src[i * 4], (uint2 *)&dst[i * 4]); - } - } + for (uint32_t i = 0; i < vec_size / 4; ++i) { + fast_dequant_f8f16x4<__hip_fp8_e4m3_fnuz, __hip_bfloat16>((uint32_t*)&src[i * 4], + (uint2*)&dst[i * 4]); + } } + } 
}; -template <> struct vec_cast<__hip_bfloat16, __hip_fp8_e5m2_fnuz> -{ - template - FLASHINFER_INLINE static void cast(__hip_bfloat16 *dst, - const __hip_fp8_e5m2_fnuz *src) - { - if constexpr (vec_size == 1) { - dst[0] = __hip_bfloat16(src[0]); - } - else if constexpr (vec_size == 2) { - dst[0] = __hip_bfloat16(src[0]); - dst[1] = __hip_bfloat16(src[1]); - } - else { - static_assert(vec_size % 4 == 0, - "vec_size must be a multiple of 4"); +template <> +struct vec_cast<__hip_bfloat16, __hip_fp8_e5m2_fnuz> { + template + FLASHINFER_INLINE static void cast(__hip_bfloat16* dst, const __hip_fp8_e5m2_fnuz* src) { + if constexpr (vec_size == 1) { + dst[0] = __hip_bfloat16(src[0]); + } else if constexpr (vec_size == 2) { + dst[0] = __hip_bfloat16(src[0]); + dst[1] = __hip_bfloat16(src[1]); + } else { + static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4"); #pragma unroll - for (uint32_t i = 0; i < vec_size / 4; ++i) { - fast_dequant_f8f16x4<__hip_fp8_e5m2_fnuz, __hip_bfloat16>( - (uint32_t *)&src[i * 4], (uint2 *)&dst[i * 4]); - } - } + for (uint32_t i = 0; i < vec_size / 4; ++i) { + fast_dequant_f8f16x4<__hip_fp8_e5m2_fnuz, __hip_bfloat16>((uint32_t*)&src[i * 4], + (uint2*)&dst[i * 4]); + } } + } }; // Function to convert float to e4m3 -__device__ uint8_t convert_f32_to_e4m3(float val) -{ - // Define the range of e4m3 - // 1. Minimum representable value for e4m3 - // 2. Binary 1000.000 in e4m3 - // 3. FLT_MIN is not suitable for e4m3 because e4m3 has a much smaller - // dynamic range. - float min_e4m3 = -8.0f; - // 1. Maximum representable value for e4m3 - // 2. Binary 0111.111 in e4m3 - // FLT_MAX far exceeds the maximum value representable in e4m3. - float max_e4m3 = 7.875f; - - // Saturate the value to the e4m3 range - val = fminf(fmaxf(val, min_e4m3), max_e4m3); - - // Perform conversion - // Decompose into mantissa and exponent - int exp; - float mantissa = frexpf(val, &exp); - - // Encode sign bit - uint8_t sign = (mantissa < 0) ? 
0x80 : 0x00; - - // Normalize mantissa and encode exponent - mantissa = - fabsf(mantissa) * 16.0f; // Scale mantissa for e4m3's 3-bit precision - uint8_t exponent = static_cast(exp + 7); // Bias of 7 for e4m3 - - // Quantize mantissa - // Apply round-to-nearest-even to the mantissa - uint8_t quant_mantissa = static_cast(roundf(mantissa)) & 0x07; - - // Combine into 8 bits: [sign][exponent][mantissa] - return sign | (exponent << 3) | quant_mantissa; -} - -__device__ __half2 convert_uint32_to_half2(uint32_t input) -{ - // Extract the low and high 16 bits - uint16_t low_val = input & 0xFFFF; - uint16_t high_val = (input >> 16) & 0xFFFF; - // Convert to __half - __half low_half = __float2half(static_cast(low_val)); - __half high_half = __float2half(static_cast(high_val)); - // Pack into __half2 - return __halves2half2(low_half, high_half); +__device__ uint8_t convert_f32_to_e4m3(float val) { + // Define the range of e4m3 + // 1. Minimum representable value for e4m3 + // 2. Binary 1000.000 in e4m3 + // 3. FLT_MIN is not suitable for e4m3 because e4m3 has a much smaller + // dynamic range. + float min_e4m3 = -8.0f; + // 1. Maximum representable value for e4m3 + // 2. Binary 0111.111 in e4m3 + // FLT_MAX far exceeds the maximum value representable in e4m3. + float max_e4m3 = 7.875f; + + // Saturate the value to the e4m3 range + val = fminf(fmaxf(val, min_e4m3), max_e4m3); + + // Perform conversion + // Decompose into mantissa and exponent + int exp; + float mantissa = frexpf(val, &exp); + + // Encode sign bit + uint8_t sign = (mantissa < 0) ? 
0x80 : 0x00; + + // Normalize mantissa and encode exponent + mantissa = fabsf(mantissa) * 16.0f; // Scale mantissa for e4m3's 3-bit precision + uint8_t exponent = static_cast(exp + 7); // Bias of 7 for e4m3 + + // Quantize mantissa + // Apply round-to-nearest-even to the mantissa + uint8_t quant_mantissa = static_cast(roundf(mantissa)) & 0x07; + + // Combine into 8 bits: [sign][exponent][mantissa] + return sign | (exponent << 3) | quant_mantissa; +} + +__device__ __half2 convert_uint32_to_half2(uint32_t input) { + // Extract the low and high 16 bits + uint16_t low_val = input & 0xFFFF; + uint16_t high_val = (input >> 16) & 0xFFFF; + // Convert to __half + __half low_half = __float2half(static_cast(low_val)); + __half high_half = __float2half(static_cast(high_val)); + // Pack into __half2 + return __halves2half2(low_half, high_half); } // Convert f16x2 (__half2) to e4m3x2 (packed 16-bit) -__device__ uint16_t convert_f16x2_to_e4m3x2(__half2 x) -{ - float f32_0 = __half2float(__low2half(x)); - float f32_1 = __half2float(__high2half(x)); - uint8_t e4m3_0 = convert_f32_to_e4m3(f32_0); - uint8_t e4m3_1 = convert_f32_to_e4m3(f32_1); - return (static_cast(e4m3_1) << 8) | e4m3_0; -} - -template <> struct vec_cast<__hip_fp8_e4m3_fnuz, half> -{ - template - FLASHINFER_INLINE static void cast(__hip_fp8_e4m3_fnuz *dst, - const half *src) - { +__device__ uint16_t convert_f16x2_to_e4m3x2(__half2 x) { + float f32_0 = __half2float(__low2half(x)); + float f32_1 = __half2float(__high2half(x)); + uint8_t e4m3_0 = convert_f32_to_e4m3(f32_0); + uint8_t e4m3_1 = convert_f32_to_e4m3(f32_1); + return (static_cast(e4m3_1) << 8) | e4m3_0; +} + +template <> +struct vec_cast<__hip_fp8_e4m3_fnuz, half> { + template + FLASHINFER_INLINE static void cast(__hip_fp8_e4m3_fnuz* dst, const half* src) { #ifdef FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED - if constexpr (vec_size == 1) { - dst[0] = __hip_fp8_e4m3_fnuz(src[0]); - } - else { + if constexpr (vec_size == 1) { + dst[0] = 
__hip_fp8_e4m3_fnuz(src[0]); + } else { #pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - uint16_t y; - uint32_t x = *(uint32_t *)&src[i * 2]; - __half2 x_h2 = convert_uint32_to_half2(x); - y = convert_f16x2_to_e4m3x2(x_h2); - - *(uint16_t *)&dst[i * 2] = y; - } - } + for (size_t i = 0; i < vec_size / 2; ++i) { + uint16_t y; + uint32_t x = *(uint32_t*)&src[i * 2]; + __half2 x_h2 = convert_uint32_to_half2(x); + y = convert_f16x2_to_e4m3x2(x_h2); + + *(uint16_t*)&dst[i * 2] = y; + } + } #else #pragma unroll - for (size_t i = 0; i < vec_size; ++i) { - dst[i] = __hip_fp8_e4m3_fnuz(src[i]); - } -#endif // FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + for (size_t i = 0; i < vec_size; ++i) { + dst[i] = __hip_fp8_e4m3_fnuz(src[i]); } +#endif // FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + } }; -__device__ uint16_t convert_f16x2_to_e5m2x2(uint32_t x) -{ - // Unpack the two 16-bit half-precision floats from the input - // Extract lower 16 bits - __half h1 = __ushort_as_half(x & 0xFFFF); - // Extract upper 16 bits - __half h2 = __ushort_as_half((x >> 16) & 0xFFFF); +__device__ uint16_t convert_f16x2_to_e5m2x2(uint32_t x) { + // Unpack the two 16-bit half-precision floats from the input + // Extract lower 16 bits + __half h1 = __ushort_as_half(x & 0xFFFF); + // Extract upper 16 bits + __half h2 = __ushort_as_half((x >> 16) & 0xFFFF); #if 0 // Alternative with `__uint2half_rn` @@ -392,1621 +341,1293 @@ __device__ uint16_t convert_f16x2_to_e5m2x2(uint32_t x) __half h2 = __uint2half_rn(val2); #endif - // Define the range of e5m2 - // Minimum representable value for e5m2 - const float min_e5m2 = -8.0f; - // Maximum representable value for e5m2 - const float max_e5m2 = 7.75f; + // Define the range of e5m2 + // Minimum representable value for e5m2 + const float min_e5m2 = -8.0f; + // Maximum representable value for e5m2 + const float max_e5m2 = 7.75f; - // Helper lambda for conversion - auto f32_to_e5m2 = [min_e5m2, max_e5m2](float val) -> uint8_t { - // Saturate the val 
- val = fminf(fmaxf(val, min_e5m2), max_e5m2); + // Helper lambda for conversion + auto f32_to_e5m2 = [min_e5m2, max_e5m2](float val) -> uint8_t { + // Saturate the val + val = fminf(fmaxf(val, min_e5m2), max_e5m2); - // Decompose into mantissa and exponent - int exp; - float mantissa = frexpf(val, &exp); + // Decompose into mantissa and exponent + int exp; + float mantissa = frexpf(val, &exp); - // Encode sign bit - uint8_t sign = (mantissa < 0) ? 0x10 : 0x00; // Sign in bit 4 - mantissa = fabsf(mantissa); + // Encode sign bit + uint8_t sign = (mantissa < 0) ? 0x10 : 0x00; // Sign in bit 4 + mantissa = fabsf(mantissa); - // Normalize mantissa and encode exponent - mantissa *= 4.0f; // Scale for 2-bit mantissa - uint8_t exponent = static_cast(exp + 7); // Apply bias for e5m2 + // Normalize mantissa and encode exponent + mantissa *= 4.0f; // Scale for 2-bit mantissa + uint8_t exponent = static_cast(exp + 7); // Apply bias for e5m2 - // Apply round-to-nearest-even - uint8_t quant_mantissa = static_cast(roundf(mantissa)) & 0x03; + // Apply round-to-nearest-even + uint8_t quant_mantissa = static_cast(roundf(mantissa)) & 0x03; - // Combine into 5 bits: [sign][exponent][mantissa] - return sign | (exponent << 2) | quant_mantissa; - }; + // Combine into 5 bits: [sign][exponent][mantissa] + return sign | (exponent << 2) | quant_mantissa; + }; - // Convert the two __half values to e5m2 - uint8_t e5m2_1 = f32_to_e5m2(__half2float(h1)); - uint8_t e5m2_2 = f32_to_e5m2(__half2float(h2)); + // Convert the two __half values to e5m2 + uint8_t e5m2_1 = f32_to_e5m2(__half2float(h1)); + uint8_t e5m2_2 = f32_to_e5m2(__half2float(h2)); - // Pack the two e5m2 values into a single 16-bit output - return (e5m2_2 << 8) | e5m2_1; + // Pack the two e5m2 values into a single 16-bit output + return (e5m2_2 << 8) | e5m2_1; } #endif -template <> struct vec_cast<__hip_fp8_e5m2_fnuz, half> -{ - template - FLASHINFER_INLINE static void cast(__hip_fp8_e5m2_fnuz *dst, - const half *src) - { +template 
<> +struct vec_cast<__hip_fp8_e5m2_fnuz, half> { + template + FLASHINFER_INLINE static void cast(__hip_fp8_e5m2_fnuz* dst, const half* src) { #ifdef FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED - if constexpr (vec_size == 1) { - dst[0] = __hip_fp8_e5m2_fnuz(src[0]); - } - else { + if constexpr (vec_size == 1) { + dst[0] = __hip_fp8_e5m2_fnuz(src[0]); + } else { #pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - uint16_t y; - uint32_t x = *(uint32_t *)&src[i * 2]; - y = convert_f16x2_to_e5m2x2(x); - *(uint16_t *)&dst[i * 2] = y; - } - } + for (size_t i = 0; i < vec_size / 2; ++i) { + uint16_t y; + uint32_t x = *(uint32_t*)&src[i * 2]; + y = convert_f16x2_to_e5m2x2(x); + *(uint16_t*)&dst[i * 2] = y; + } + } #else #pragma unroll - for (size_t i = 0; i < vec_size; ++i) { - dst[i] = __hip_fp8_e5m2_fnuz(src[i]); - } -#endif // FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + for (size_t i = 0; i < vec_size; ++i) { + dst[i] = __hip_fp8_e5m2_fnuz(src[i]); } +#endif // FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + } }; -__device__ uint32_t convert_e4m3x2_to_f16x2(uint16_t x) -{ - // Extract two e4m3 values from the 16-bit input - uint8_t e4m3_1 = x & 0xFF; // Lower 8 bits - uint8_t e4m3_2 = (x >> 8) & 0xFF; // Upper 8 bits - - // Decode e4m3 to float - auto e4m3_to_f32 = [](uint8_t e4m3) -> float { - // Extract sign, exponent, and mantissa - int sign = (e4m3 & 0x80) ? 
-1 : 1; - int exponent = ((e4m3 >> 3) & 0x0F) - 7; // 4-bit exponent with bias 7 - int mantissa = e4m3 & 0x07; // 3-bit mantissa - - // Handle special case: zero - if (exponent == -7 && mantissa == 0) { - return 0.0f; - } - - // Convert to float - float f32_val = sign * ldexpf(1.0f + mantissa / 8.0f, exponent); - return f32_val; - }; - - float f1 = e4m3_to_f32(e4m3_1); - float f2 = e4m3_to_f32(e4m3_2); - - // Convert float to IEEE f16 - __half h1 = __float2half_rn(f1); - __half h2 = __float2half_rn(f2); - - // Pack the two f16 values into a single uint32_t - uint32_t f16x2 = (__half_as_ushort(h2) << 16) | __half_as_ushort(h1); - return f16x2; +__device__ uint32_t convert_e4m3x2_to_f16x2(uint16_t x) { + // Extract two e4m3 values from the 16-bit input + uint8_t e4m3_1 = x & 0xFF; // Lower 8 bits + uint8_t e4m3_2 = (x >> 8) & 0xFF; // Upper 8 bits + + // Decode e4m3 to float + auto e4m3_to_f32 = [](uint8_t e4m3) -> float { + // Extract sign, exponent, and mantissa + int sign = (e4m3 & 0x80) ? 
-1 : 1; + int exponent = ((e4m3 >> 3) & 0x0F) - 7; // 4-bit exponent with bias 7 + int mantissa = e4m3 & 0x07; // 3-bit mantissa + + // Handle special case: zero + if (exponent == -7 && mantissa == 0) { + return 0.0f; + } + + // Convert to float + float f32_val = sign * ldexpf(1.0f + mantissa / 8.0f, exponent); + return f32_val; + }; + + float f1 = e4m3_to_f32(e4m3_1); + float f2 = e4m3_to_f32(e4m3_2); + + // Convert float to IEEE f16 + __half h1 = __float2half_rn(f1); + __half h2 = __float2half_rn(f2); + + // Pack the two f16 values into a single uint32_t + uint32_t f16x2 = (__half_as_ushort(h2) << 16) | __half_as_ushort(h1); + return f16x2; } -template <> struct vec_cast -{ - template - FLASHINFER_INLINE static void cast(half *dst, - const __hip_fp8_e4m3_fnuz *src) - { +template <> +struct vec_cast { + template + FLASHINFER_INLINE static void cast(half* dst, const __hip_fp8_e4m3_fnuz* src) { #ifdef FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED - if constexpr (vec_size == 1) { - dst[0] = half(src[0]); - } - else { + if constexpr (vec_size == 1) { + dst[0] = half(src[0]); + } else { #pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - uint32_t y; - uint16_t x = *(uint16_t *)&src[i * 2]; - y = convert_e4m3x2_to_f16x2(x); - - *(uint32_t *)&dst[i * 2] = y; - } - } + for (size_t i = 0; i < vec_size / 2; ++i) { + uint32_t y; + uint16_t x = *(uint16_t*)&src[i * 2]; + y = convert_e4m3x2_to_f16x2(x); + + *(uint32_t*)&dst[i * 2] = y; + } + } #else - if constexpr (vec_size == 1) { - dst[0] = half(src[0]); - } - else if constexpr (vec_size == 2) { - dst[0] = half(src[0]); - dst[1] = half(src[1]); - } - else { - static_assert(vec_size % 4 == 0, - "vec_size must be a multiple of 4"); + if constexpr (vec_size == 1) { + dst[0] = half(src[0]); + } else if constexpr (vec_size == 2) { + dst[0] = half(src[0]); + dst[1] = half(src[1]); + } else { + static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4"); #pragma unroll - for (uint32_t i = 0; i < vec_size / 4; ++i) 
{ - fast_dequant_f8f16x4<__hip_fp8_e4m3_fnuz, half>( - (uint32_t *)&src[i * 4], (uint2 *)&dst[i * 4]); - } - } -#endif // FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + for (uint32_t i = 0; i < vec_size / 4; ++i) { + fast_dequant_f8f16x4<__hip_fp8_e4m3_fnuz, half>((uint32_t*)&src[i * 4], + (uint2*)&dst[i * 4]); + } } +#endif // FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + } }; -__device__ uint32_t convert_e5m2x2_to_f16x2(uint16_t x) -{ - // Extract two e5m2 values from the 16-bit input - uint8_t e5m2_1 = x & 0xFF; // Lower 8 bits - uint8_t e5m2_2 = (x >> 8) & 0xFF; // Upper 8 bits - - // Decode e5m2 to float - auto e5m2_to_f32 = [](uint8_t e5m2) -> float { - // Extract sign, exponent, and mantissa - int sign = (e5m2 & 0x80) ? -1 : 1; // Sign bit - int exponent = ((e5m2 >> 2) & 0x1F) - 15; // 5-bit exponent with bias 15 - int mantissa = e5m2 & 0x03; // 2-bit mantissa - - // Handle special case: zero - if (exponent == -15 && mantissa == 0) { - return 0.0f; - } - - // Convert to float - float value = sign * ldexpf(1.0f + mantissa / 4.0f, exponent); - return value; - }; - - float f1 = e5m2_to_f32(e5m2_1); - float f2 = e5m2_to_f32(e5m2_2); - - // Convert float to IEEE f16 - __half h1 = __float2half_rn(f1); - __half h2 = __float2half_rn(f2); - - // Pack the two f16 values into a single uint32_t - uint32_t f16x2 = (__half_as_ushort(h2) << 16) | __half_as_ushort(h1); - return f16x2; +__device__ uint32_t convert_e5m2x2_to_f16x2(uint16_t x) { + // Extract two e5m2 values from the 16-bit input + uint8_t e5m2_1 = x & 0xFF; // Lower 8 bits + uint8_t e5m2_2 = (x >> 8) & 0xFF; // Upper 8 bits + + // Decode e5m2 to float + auto e5m2_to_f32 = [](uint8_t e5m2) -> float { + // Extract sign, exponent, and mantissa + int sign = (e5m2 & 0x80) ? 
-1 : 1; // Sign bit + int exponent = ((e5m2 >> 2) & 0x1F) - 15; // 5-bit exponent with bias 15 + int mantissa = e5m2 & 0x03; // 2-bit mantissa + + // Handle special case: zero + if (exponent == -15 && mantissa == 0) { + return 0.0f; + } + + // Convert to float + float value = sign * ldexpf(1.0f + mantissa / 4.0f, exponent); + return value; + }; + + float f1 = e5m2_to_f32(e5m2_1); + float f2 = e5m2_to_f32(e5m2_2); + + // Convert float to IEEE f16 + __half h1 = __float2half_rn(f1); + __half h2 = __float2half_rn(f2); + + // Pack the two f16 values into a single uint32_t + uint32_t f16x2 = (__half_as_ushort(h2) << 16) | __half_as_ushort(h1); + return f16x2; } -template <> struct vec_cast -{ - template - FLASHINFER_INLINE static void cast(half *dst, - const __hip_fp8_e5m2_fnuz *src) - { +template <> +struct vec_cast { + template + FLASHINFER_INLINE static void cast(half* dst, const __hip_fp8_e5m2_fnuz* src) { #ifdef FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED - if constexpr (vec_size == 1) { - dst[0] = half(src[0]); - } - else { + if constexpr (vec_size == 1) { + dst[0] = half(src[0]); + } else { #pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - uint32_t y; - uint16_t x = *(uint16_t *)&src[i * 2]; - y = convert_e5m2x2_to_f16x2(x); - *(uint32_t *)&dst[i * 2] = y; - } - } + for (size_t i = 0; i < vec_size / 2; ++i) { + uint32_t y; + uint16_t x = *(uint16_t*)&src[i * 2]; + y = convert_e5m2x2_to_f16x2(x); + *(uint32_t*)&dst[i * 2] = y; + } + } #else - if constexpr (vec_size == 1) { - dst[0] = half(src[0]); - } - else if constexpr (vec_size == 2) { - dst[0] = half(src[0]); - dst[1] = half(src[1]); - } - else { - static_assert(vec_size % 4 == 0, - "vec_size must be a multiple of 4"); + if constexpr (vec_size == 1) { + dst[0] = half(src[0]); + } else if constexpr (vec_size == 2) { + dst[0] = half(src[0]); + dst[1] = half(src[1]); + } else { + static_assert(vec_size % 4 == 0, "vec_size must be a multiple of 4"); #pragma unroll - for (uint32_t i = 0; i < vec_size / 
4; ++i) { - fast_dequant_f8f16x4<__hip_fp8_e5m2_fnuz, half>( - (uint32_t *)&src[i * 4], (uint2 *)&dst[i * 4]); - } - } -#endif // FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + for (uint32_t i = 0; i < vec_size / 4; ++i) { + fast_dequant_f8f16x4<__hip_fp8_e5m2_fnuz, half>((uint32_t*)&src[i * 4], + (uint2*)&dst[i * 4]); + } } +#endif // FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED + } }; -template <> struct vec_cast -{ - template - FLASHINFER_INLINE static void cast(float *dst, const __hip_bfloat16 *src) - { - if constexpr (vec_size == 1) { - dst[0] = (float)src[0]; - } - else { +template <> +struct vec_cast { + template + FLASHINFER_INLINE static void cast(float* dst, const __hip_bfloat16* src) { + if constexpr (vec_size == 1) { + dst[0] = (float)src[0]; + } else { #pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - ((float2 *)dst)[i] = - __bfloat1622float2(((__hip_bfloat162 *)src)[i]); - } - } + for (size_t i = 0; i < vec_size / 2; ++i) { + ((float2*)dst)[i] = __bfloat1622float2(((__hip_bfloat162*)src)[i]); + } } + } }; -template <> struct vec_cast<__hip_bfloat16, float> -{ - template - FLASHINFER_INLINE static void cast(__hip_bfloat16 *dst, const float *src) - { - if constexpr (vec_size == 1) { - dst[0] = __hip_bfloat16(src[0]); - } - else { +template <> +struct vec_cast<__hip_bfloat16, float> { + template + FLASHINFER_INLINE static void cast(__hip_bfloat16* dst, const float* src) { + if constexpr (vec_size == 1) { + dst[0] = __hip_bfloat16(src[0]); + } else { #pragma unroll - for (size_t i = 0; i < vec_size / 2; ++i) { - ((__hip_bfloat162 *)dst)[i] = - __float22bfloat162_rn(((float2 *)src)[i]); - } - } + for (size_t i = 0; i < vec_size / 2; ++i) { + ((__hip_bfloat162*)dst)[i] = __float22bfloat162_rn(((float2*)src)[i]); + } } + } }; -template struct vec_t -{ - FLASHINFER_INLINE float_t &operator[](size_t i); - FLASHINFER_INLINE const float_t &operator[](size_t i) const; - FLASHINFER_INLINE void fill(float_t val); - FLASHINFER_INLINE void load(const float_t 
*ptr); - FLASHINFER_INLINE void store(float_t *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src); - template FLASHINFER_INLINE void cast_load(const T *ptr); - template FLASHINFER_INLINE void cast_store(T *ptr) const; - FLASHINFER_INLINE static void memcpy(float_t *dst, const float_t *src); - FLASHINFER_INLINE float_t *ptr(); +template +struct vec_t { + FLASHINFER_INLINE float_t& operator[](size_t i); + FLASHINFER_INLINE const float_t& operator[](size_t i) const; + FLASHINFER_INLINE void fill(float_t val); + FLASHINFER_INLINE void load(const float_t* ptr); + FLASHINFER_INLINE void store(float_t* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src); + template + FLASHINFER_INLINE void cast_load(const T* ptr); + template + FLASHINFER_INLINE void cast_store(T* ptr) const; + FLASHINFER_INLINE static void memcpy(float_t* dst, const float_t* src); + FLASHINFER_INLINE float_t* ptr(); }; template -FLASHINFER_INLINE void cast_from_impl(vec_t &dst, - const vec_t &src) -{ - - vec_cast::template cast( - dst.ptr(), const_cast *>(&src)->ptr()); +FLASHINFER_INLINE void cast_from_impl(vec_t& dst, + const vec_t& src) { + vec_cast::template cast( + dst.ptr(), const_cast*>(&src)->ptr()); } template -FLASHINFER_INLINE void cast_load_impl(vec_t &dst, - const src_float_t *src_ptr) -{ - if constexpr (std::is_same_v) { - dst.load(src_ptr); - } - else { - vec_t tmp; - tmp.load(src_ptr); - dst.cast_from(tmp); - } +FLASHINFER_INLINE void cast_load_impl(vec_t& dst, + const src_float_t* src_ptr) { + if constexpr (std::is_same_v) { + dst.load(src_ptr); + } else { + vec_t tmp; + tmp.load(src_ptr); + dst.cast_from(tmp); + } } template -FLASHINFER_INLINE void cast_store_impl(tgt_float_t *dst_ptr, - const vec_t &src) -{ - if constexpr (std::is_same_v) { - src.store(dst_ptr); - } - else { - vec_t tmp; - tmp.cast_from(src); - tmp.store(dst_ptr); - } +FLASHINFER_INLINE void cast_store_impl(tgt_float_t* dst_ptr, + const vec_t& src) { + if constexpr 
(std::is_same_v) { + src.store(dst_ptr); + } else { + vec_t tmp; + tmp.cast_from(src); + tmp.store(dst_ptr); + } } /******************* vec_t<__hip_fp8_e4m3_fnuz> *******************/ // __hip_fp8_e4m3_fnuz x 1 -template <> struct vec_t<__hip_fp8_e4m3_fnuz, 1> -{ - __hip_fp8_e4m3_fnuz data; - - FLASHINFER_INLINE __hip_fp8_e4m3_fnuz &operator[](size_t i) - { - return ((__hip_fp8_e4m3_fnuz *)(&data))[i]; - } - FLASHINFER_INLINE const __hip_fp8_e4m3_fnuz &operator[](size_t i) const - { - return ((const __hip_fp8_e4m3_fnuz *)(&data))[i]; - } - FLASHINFER_INLINE __hip_fp8_e4m3_fnuz *ptr() - { - return reinterpret_cast<__hip_fp8_e4m3_fnuz *>(&data); - } - FLASHINFER_INLINE void fill(__hip_fp8_e4m3_fnuz val); - FLASHINFER_INLINE void load(const __hip_fp8_e4m3_fnuz *ptr); - FLASHINFER_INLINE void store(__hip_fp8_e4m3_fnuz *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) - { - cast_from_impl(*this, src); - } - template FLASHINFER_INLINE void cast_load(const T *ptr) - { - cast_load_impl(*this, ptr); - } - template FLASHINFER_INLINE void cast_store(T *ptr) const - { - cast_store_impl(ptr, *this); - } - - FLASHINFER_INLINE static void memcpy(__hip_fp8_e4m3_fnuz *dst, - const __hip_fp8_e4m3_fnuz *src); +template <> +struct vec_t<__hip_fp8_e4m3_fnuz, 1> { + __hip_fp8_e4m3_fnuz data; + + FLASHINFER_INLINE __hip_fp8_e4m3_fnuz& operator[](size_t i) { + return ((__hip_fp8_e4m3_fnuz*)(&data))[i]; + } + FLASHINFER_INLINE const __hip_fp8_e4m3_fnuz& operator[](size_t i) const { + return ((const __hip_fp8_e4m3_fnuz*)(&data))[i]; + } + FLASHINFER_INLINE __hip_fp8_e4m3_fnuz* ptr() { + return reinterpret_cast<__hip_fp8_e4m3_fnuz*>(&data); + } + FLASHINFER_INLINE void fill(__hip_fp8_e4m3_fnuz val); + FLASHINFER_INLINE void load(const __hip_fp8_e4m3_fnuz* ptr); + FLASHINFER_INLINE void store(__hip_fp8_e4m3_fnuz* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void 
cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + + FLASHINFER_INLINE static void memcpy(__hip_fp8_e4m3_fnuz* dst, const __hip_fp8_e4m3_fnuz* src); }; -FLASHINFER_INLINE void -vec_t<__hip_fp8_e4m3_fnuz, 1>::fill(__hip_fp8_e4m3_fnuz val) -{ - data = val; -} +FLASHINFER_INLINE void vec_t<__hip_fp8_e4m3_fnuz, 1>::fill(__hip_fp8_e4m3_fnuz val) { data = val; } -FLASHINFER_INLINE void -vec_t<__hip_fp8_e4m3_fnuz, 1>::load(const __hip_fp8_e4m3_fnuz *ptr) -{ - data = *ptr; +FLASHINFER_INLINE void vec_t<__hip_fp8_e4m3_fnuz, 1>::load(const __hip_fp8_e4m3_fnuz* ptr) { + data = *ptr; } -FLASHINFER_INLINE void -vec_t<__hip_fp8_e4m3_fnuz, 1>::store(__hip_fp8_e4m3_fnuz *ptr) const -{ - *ptr = data; +FLASHINFER_INLINE void vec_t<__hip_fp8_e4m3_fnuz, 1>::store(__hip_fp8_e4m3_fnuz* ptr) const { + *ptr = data; } -FLASHINFER_INLINE void -vec_t<__hip_fp8_e4m3_fnuz, 1>::memcpy(__hip_fp8_e4m3_fnuz *dst, - const __hip_fp8_e4m3_fnuz *src) -{ - *dst = *src; +FLASHINFER_INLINE void vec_t<__hip_fp8_e4m3_fnuz, 1>::memcpy(__hip_fp8_e4m3_fnuz* dst, + const __hip_fp8_e4m3_fnuz* src) { + *dst = *src; } // __hip_fp8_e4m3_fnuz x 2 -template <> struct vec_t<__hip_fp8_e4m3_fnuz, 2> -{ - __hip_fp8x2_e4m3_fnuz data; - - FLASHINFER_INLINE __hip_fp8_e4m3_fnuz &operator[](size_t i) - { - return ((__hip_fp8_e4m3_fnuz *)(&data))[i]; - } - FLASHINFER_INLINE const __hip_fp8_e4m3_fnuz &operator[](size_t i) const - { - return ((const __hip_fp8_e4m3_fnuz *)(&data))[i]; - } - FLASHINFER_INLINE __hip_fp8_e4m3_fnuz *ptr() - { - return reinterpret_cast<__hip_fp8_e4m3_fnuz *>(&data); - } - FLASHINFER_INLINE void fill(__hip_fp8_e4m3_fnuz val); - FLASHINFER_INLINE void load(const __hip_fp8_e4m3_fnuz *ptr); - FLASHINFER_INLINE void store(__hip_fp8_e4m3_fnuz *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) - { - cast_from_impl(*this, src); - } - template FLASHINFER_INLINE void 
cast_load(const T *ptr) - { - cast_load_impl(*this, ptr); - } - template FLASHINFER_INLINE void cast_store(T *ptr) const - { - cast_store_impl(ptr, *this); - } - FLASHINFER_INLINE static void memcpy(__hip_fp8_e4m3_fnuz *dst, - const __hip_fp8_e4m3_fnuz *src); +template <> +struct vec_t<__hip_fp8_e4m3_fnuz, 2> { + __hip_fp8x2_e4m3_fnuz data; + + FLASHINFER_INLINE __hip_fp8_e4m3_fnuz& operator[](size_t i) { + return ((__hip_fp8_e4m3_fnuz*)(&data))[i]; + } + FLASHINFER_INLINE const __hip_fp8_e4m3_fnuz& operator[](size_t i) const { + return ((const __hip_fp8_e4m3_fnuz*)(&data))[i]; + } + FLASHINFER_INLINE __hip_fp8_e4m3_fnuz* ptr() { + return reinterpret_cast<__hip_fp8_e4m3_fnuz*>(&data); + } + FLASHINFER_INLINE void fill(__hip_fp8_e4m3_fnuz val); + FLASHINFER_INLINE void load(const __hip_fp8_e4m3_fnuz* ptr); + FLASHINFER_INLINE void store(__hip_fp8_e4m3_fnuz* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + FLASHINFER_INLINE static void memcpy(__hip_fp8_e4m3_fnuz* dst, const __hip_fp8_e4m3_fnuz* src); }; -FLASHINFER_INLINE void -vec_t<__hip_fp8_e4m3_fnuz, 2>::fill(__hip_fp8_e4m3_fnuz val) -{ - data.__x = - (__hip_fp8x2_storage_t(val.__x) << 8) | __hip_fp8x2_storage_t(val.__x); +FLASHINFER_INLINE void vec_t<__hip_fp8_e4m3_fnuz, 2>::fill(__hip_fp8_e4m3_fnuz val) { + data.__x = (__hip_fp8x2_storage_t(val.__x) << 8) | __hip_fp8x2_storage_t(val.__x); } -FLASHINFER_INLINE void -vec_t<__hip_fp8_e4m3_fnuz, 2>::load(const __hip_fp8_e4m3_fnuz *ptr) -{ - data = *((__hip_fp8x2_e4m3_fnuz *)ptr); +FLASHINFER_INLINE void vec_t<__hip_fp8_e4m3_fnuz, 2>::load(const __hip_fp8_e4m3_fnuz* ptr) { + data = *((__hip_fp8x2_e4m3_fnuz*)ptr); } -FLASHINFER_INLINE void -vec_t<__hip_fp8_e4m3_fnuz, 2>::store(__hip_fp8_e4m3_fnuz *ptr) const -{ - 
*((__hip_fp8x2_e4m3_fnuz *)ptr) = data; +FLASHINFER_INLINE void vec_t<__hip_fp8_e4m3_fnuz, 2>::store(__hip_fp8_e4m3_fnuz* ptr) const { + *((__hip_fp8x2_e4m3_fnuz*)ptr) = data; } -FLASHINFER_INLINE void -vec_t<__hip_fp8_e4m3_fnuz, 2>::memcpy(__hip_fp8_e4m3_fnuz *dst, - const __hip_fp8_e4m3_fnuz *src) -{ - *((__hip_fp8x2_e4m3_fnuz *)dst) = *((__hip_fp8x2_e4m3_fnuz *)src); +FLASHINFER_INLINE void vec_t<__hip_fp8_e4m3_fnuz, 2>::memcpy(__hip_fp8_e4m3_fnuz* dst, + const __hip_fp8_e4m3_fnuz* src) { + *((__hip_fp8x2_e4m3_fnuz*)dst) = *((__hip_fp8x2_e4m3_fnuz*)src); } // __hip_fp8_e4m3_fnuz x 4 -template <> struct vec_t<__hip_fp8_e4m3_fnuz, 4> -{ - __hip_fp8x4_e4m3_fnuz data; - - FLASHINFER_INLINE __hip_fp8_e4m3_fnuz &operator[](size_t i) - { - return ((__hip_fp8_e4m3_fnuz *)(&data))[i]; - } - FLASHINFER_INLINE const __hip_fp8_e4m3_fnuz &operator[](size_t i) const - { - return ((const __hip_fp8_e4m3_fnuz *)(&data))[i]; - } - FLASHINFER_INLINE __hip_fp8_e4m3_fnuz *ptr() - { - return reinterpret_cast<__hip_fp8_e4m3_fnuz *>(&data); - } - FLASHINFER_INLINE void fill(__hip_fp8_e4m3_fnuz val); - FLASHINFER_INLINE void load(const __hip_fp8_e4m3_fnuz *ptr); - FLASHINFER_INLINE void store(__hip_fp8_e4m3_fnuz *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) - { - cast_from_impl(*this, src); - } - template FLASHINFER_INLINE void cast_load(const T *ptr) - { - cast_load_impl(*this, ptr); - } - template FLASHINFER_INLINE void cast_store(T *ptr) const - { - cast_store_impl(ptr, *this); - } - - FLASHINFER_INLINE static void memcpy(__hip_fp8_e4m3_fnuz *dst, - const __hip_fp8_e4m3_fnuz *src); +template <> +struct vec_t<__hip_fp8_e4m3_fnuz, 4> { + __hip_fp8x4_e4m3_fnuz data; + + FLASHINFER_INLINE __hip_fp8_e4m3_fnuz& operator[](size_t i) { + return ((__hip_fp8_e4m3_fnuz*)(&data))[i]; + } + FLASHINFER_INLINE const __hip_fp8_e4m3_fnuz& operator[](size_t i) const { + return ((const __hip_fp8_e4m3_fnuz*)(&data))[i]; + } + FLASHINFER_INLINE __hip_fp8_e4m3_fnuz* ptr() { 
+ return reinterpret_cast<__hip_fp8_e4m3_fnuz*>(&data); + } + FLASHINFER_INLINE void fill(__hip_fp8_e4m3_fnuz val); + FLASHINFER_INLINE void load(const __hip_fp8_e4m3_fnuz* ptr); + FLASHINFER_INLINE void store(__hip_fp8_e4m3_fnuz* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + + FLASHINFER_INLINE static void memcpy(__hip_fp8_e4m3_fnuz* dst, const __hip_fp8_e4m3_fnuz* src); }; -FLASHINFER_INLINE void -vec_t<__hip_fp8_e4m3_fnuz, 4>::fill(__hip_fp8_e4m3_fnuz val) -{ - data.__x = (__hip_fp8x4_storage_t(val.__x) << 24) | - (__hip_fp8x4_storage_t(val.__x) << 16) | - (__hip_fp8x4_storage_t(val.__x) << 8) | - __hip_fp8x4_storage_t(val.__x); +FLASHINFER_INLINE void vec_t<__hip_fp8_e4m3_fnuz, 4>::fill(__hip_fp8_e4m3_fnuz val) { + data.__x = (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); } -FLASHINFER_INLINE void -vec_t<__hip_fp8_e4m3_fnuz, 4>::load(const __hip_fp8_e4m3_fnuz *ptr) -{ - data = *((__hip_fp8x4_e4m3_fnuz *)ptr); +FLASHINFER_INLINE void vec_t<__hip_fp8_e4m3_fnuz, 4>::load(const __hip_fp8_e4m3_fnuz* ptr) { + data = *((__hip_fp8x4_e4m3_fnuz*)ptr); } -FLASHINFER_INLINE void -vec_t<__hip_fp8_e4m3_fnuz, 4>::store(__hip_fp8_e4m3_fnuz *ptr) const -{ - *((__hip_fp8x4_e4m3_fnuz *)ptr) = data; +FLASHINFER_INLINE void vec_t<__hip_fp8_e4m3_fnuz, 4>::store(__hip_fp8_e4m3_fnuz* ptr) const { + *((__hip_fp8x4_e4m3_fnuz*)ptr) = data; } -FLASHINFER_INLINE void -vec_t<__hip_fp8_e4m3_fnuz, 4>::memcpy(__hip_fp8_e4m3_fnuz *dst, - const __hip_fp8_e4m3_fnuz *src) -{ - *((__hip_fp8x4_e4m3_fnuz *)dst) = *((__hip_fp8x4_e4m3_fnuz *)src); +FLASHINFER_INLINE void vec_t<__hip_fp8_e4m3_fnuz, 4>::memcpy(__hip_fp8_e4m3_fnuz* dst, + 
const __hip_fp8_e4m3_fnuz* src) { + *((__hip_fp8x4_e4m3_fnuz*)dst) = *((__hip_fp8x4_e4m3_fnuz*)src); } // __hip_fp8_e4m3_fnuz x 8 -template <> struct vec_t<__hip_fp8_e4m3_fnuz, 8> -{ - uint2 data; - - FLASHINFER_INLINE __hip_fp8_e4m3_fnuz &operator[](size_t i) - { - return ((__hip_fp8_e4m3_fnuz *)(&data))[i]; - } - FLASHINFER_INLINE const __hip_fp8_e4m3_fnuz &operator[](size_t i) const - { - return ((const __hip_fp8_e4m3_fnuz *)(&data))[i]; - } - FLASHINFER_INLINE __hip_fp8_e4m3_fnuz *ptr() - { - return reinterpret_cast<__hip_fp8_e4m3_fnuz *>(&data); - } - FLASHINFER_INLINE void fill(__hip_fp8_e4m3_fnuz val); - FLASHINFER_INLINE void load(const __hip_fp8_e4m3_fnuz *ptr); - FLASHINFER_INLINE void store(__hip_fp8_e4m3_fnuz *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) - { - cast_from_impl(*this, src); - } - template FLASHINFER_INLINE void cast_load(const T *ptr) - { - cast_load_impl(*this, ptr); - } - template FLASHINFER_INLINE void cast_store(T *ptr) const - { - cast_store_impl(ptr, *this); - } - - FLASHINFER_INLINE static void memcpy(__hip_fp8_e4m3_fnuz *dst, - const __hip_fp8_e4m3_fnuz *src); +template <> +struct vec_t<__hip_fp8_e4m3_fnuz, 8> { + uint2 data; + + FLASHINFER_INLINE __hip_fp8_e4m3_fnuz& operator[](size_t i) { + return ((__hip_fp8_e4m3_fnuz*)(&data))[i]; + } + FLASHINFER_INLINE const __hip_fp8_e4m3_fnuz& operator[](size_t i) const { + return ((const __hip_fp8_e4m3_fnuz*)(&data))[i]; + } + FLASHINFER_INLINE __hip_fp8_e4m3_fnuz* ptr() { + return reinterpret_cast<__hip_fp8_e4m3_fnuz*>(&data); + } + FLASHINFER_INLINE void fill(__hip_fp8_e4m3_fnuz val); + FLASHINFER_INLINE void load(const __hip_fp8_e4m3_fnuz* ptr); + FLASHINFER_INLINE void store(__hip_fp8_e4m3_fnuz* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) 
const { + cast_store_impl(ptr, *this); + } + + FLASHINFER_INLINE static void memcpy(__hip_fp8_e4m3_fnuz* dst, const __hip_fp8_e4m3_fnuz* src); }; -FLASHINFER_INLINE void -vec_t<__hip_fp8_e4m3_fnuz, 8>::fill(__hip_fp8_e4m3_fnuz val) -{ - ((__hip_fp8x4_e4m3_fnuz *)(&data.x))->__x = - (__hip_fp8x4_storage_t(val.__x) << 24) | - (__hip_fp8x4_storage_t(val.__x) << 16) | - (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); - ((__hip_fp8x4_e4m3_fnuz *)(&data.y))->__x = - (__hip_fp8x4_storage_t(val.__x) << 24) | - (__hip_fp8x4_storage_t(val.__x) << 16) | - (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); +FLASHINFER_INLINE void vec_t<__hip_fp8_e4m3_fnuz, 8>::fill(__hip_fp8_e4m3_fnuz val) { + ((__hip_fp8x4_e4m3_fnuz*)(&data.x))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + ((__hip_fp8x4_e4m3_fnuz*)(&data.y))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); } -FLASHINFER_INLINE void -vec_t<__hip_fp8_e4m3_fnuz, 8>::load(const __hip_fp8_e4m3_fnuz *ptr) -{ - data = *((uint2 *)ptr); +FLASHINFER_INLINE void vec_t<__hip_fp8_e4m3_fnuz, 8>::load(const __hip_fp8_e4m3_fnuz* ptr) { + data = *((uint2*)ptr); } -FLASHINFER_INLINE void -vec_t<__hip_fp8_e4m3_fnuz, 8>::store(__hip_fp8_e4m3_fnuz *ptr) const -{ - *((uint2 *)ptr) = data; +FLASHINFER_INLINE void vec_t<__hip_fp8_e4m3_fnuz, 8>::store(__hip_fp8_e4m3_fnuz* ptr) const { + *((uint2*)ptr) = data; } -FLASHINFER_INLINE void -vec_t<__hip_fp8_e4m3_fnuz, 8>::memcpy(__hip_fp8_e4m3_fnuz *dst, - const __hip_fp8_e4m3_fnuz *src) -{ - *((uint2 *)dst) = *((uint2 *)src); +FLASHINFER_INLINE void vec_t<__hip_fp8_e4m3_fnuz, 8>::memcpy(__hip_fp8_e4m3_fnuz* dst, + const __hip_fp8_e4m3_fnuz* src) { + *((uint2*)dst) = *((uint2*)src); } // __hip_fp8_e4m3_fnuz x 16 or more -template 
struct vec_t<__hip_fp8_e4m3_fnuz, vec_size> -{ - uint4 data[vec_size / 16]; - - FLASHINFER_INLINE __hip_fp8_e4m3_fnuz &operator[](size_t i) - { - return ((__hip_fp8_e4m3_fnuz *)data)[i]; - } - FLASHINFER_INLINE const __hip_fp8_e4m3_fnuz &operator[](size_t i) const - { - return ((const __hip_fp8_e4m3_fnuz *)data)[i]; - } - FLASHINFER_INLINE __hip_fp8_e4m3_fnuz *ptr() - { - return reinterpret_cast<__hip_fp8_e4m3_fnuz *>(&data); - } - FLASHINFER_INLINE void fill(__hip_fp8_e4m3_fnuz val) - { +template +struct vec_t<__hip_fp8_e4m3_fnuz, vec_size> { + uint4 data[vec_size / 16]; + + FLASHINFER_INLINE __hip_fp8_e4m3_fnuz& operator[](size_t i) { + return ((__hip_fp8_e4m3_fnuz*)data)[i]; + } + FLASHINFER_INLINE const __hip_fp8_e4m3_fnuz& operator[](size_t i) const { + return ((const __hip_fp8_e4m3_fnuz*)data)[i]; + } + FLASHINFER_INLINE __hip_fp8_e4m3_fnuz* ptr() { + return reinterpret_cast<__hip_fp8_e4m3_fnuz*>(&data); + } + FLASHINFER_INLINE void fill(__hip_fp8_e4m3_fnuz val) { #pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - ((__hip_fp8x4_e4m3_fnuz *)(&(data[i].x)))->__x = - (__hip_fp8x4_storage_t(val.__x) << 24) | - (__hip_fp8x4_storage_t(val.__x) << 16) | - (__hip_fp8x4_storage_t(val.__x) << 8) | - __hip_fp8x4_storage_t(val.__x); - ((__hip_fp8x4_e4m3_fnuz *)(&(data[i].y)))->__x = - (__hip_fp8x4_storage_t(val.__x) << 24) | - (__hip_fp8x4_storage_t(val.__x) << 16) | - (__hip_fp8x4_storage_t(val.__x) << 8) | - __hip_fp8x4_storage_t(val.__x); - ((__hip_fp8x4_e4m3_fnuz *)(&(data[i].z)))->__x = - (__hip_fp8x4_storage_t(val.__x) << 24) | - (__hip_fp8x4_storage_t(val.__x) << 16) | - (__hip_fp8x4_storage_t(val.__x) << 8) | - __hip_fp8x4_storage_t(val.__x); - ((__hip_fp8x4_e4m3_fnuz *)(&(data[i].w)))->__x = - (__hip_fp8x4_storage_t(val.__x) << 24) | - (__hip_fp8x4_storage_t(val.__x) << 16) | - (__hip_fp8x4_storage_t(val.__x) << 8) | - __hip_fp8x4_storage_t(val.__x); - } - } - FLASHINFER_INLINE void load(const __hip_fp8_e4m3_fnuz *ptr) - { + for (size_t i = 0; i < 
vec_size / 16; ++i) { + ((__hip_fp8x4_e4m3_fnuz*)(&(data[i].x)))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + ((__hip_fp8x4_e4m3_fnuz*)(&(data[i].y)))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + ((__hip_fp8x4_e4m3_fnuz*)(&(data[i].z)))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + ((__hip_fp8x4_e4m3_fnuz*)(&(data[i].w)))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + } + } + FLASHINFER_INLINE void load(const __hip_fp8_e4m3_fnuz* ptr) { #pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - data[i] = ((uint4 *)ptr)[i]; - } + for (size_t i = 0; i < vec_size / 16; ++i) { + data[i] = ((uint4*)ptr)[i]; } - FLASHINFER_INLINE void store(__hip_fp8_e4m3_fnuz *ptr) const - { + } + FLASHINFER_INLINE void store(__hip_fp8_e4m3_fnuz* ptr) const { #pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - ((uint4 *)ptr)[i] = data[i]; - } - } - template - FLASHINFER_INLINE void cast_from(const vec_t &src) - { - cast_from_impl(*this, src); - } - template FLASHINFER_INLINE void cast_load(const T *ptr) - { - cast_load_impl(*this, ptr); - } - template FLASHINFER_INLINE void cast_store(T *ptr) const - { - cast_store_impl(ptr, *this); - } - - FLASHINFER_INLINE static void memcpy(__hip_fp8_e4m3_fnuz *dst, - const __hip_fp8_e4m3_fnuz *src) - { + for (size_t i = 0; i < vec_size / 16; ++i) { + ((uint4*)ptr)[i] = data[i]; + } + } + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + 
cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + + FLASHINFER_INLINE static void memcpy(__hip_fp8_e4m3_fnuz* dst, const __hip_fp8_e4m3_fnuz* src) { #pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - ((uint4 *)dst)[i] = ((uint4 *)src)[i]; - } + for (size_t i = 0; i < vec_size / 16; ++i) { + ((uint4*)dst)[i] = ((uint4*)src)[i]; } + } }; /******************* vec_t<__hip_fp8_e5m2_fnuz> *******************/ // __hip_fp8_e5m2_fnuz x 1 -template <> struct vec_t<__hip_fp8_e5m2_fnuz, 1> -{ - __hip_fp8_e5m2_fnuz data; - - FLASHINFER_INLINE __hip_fp8_e5m2_fnuz &operator[](size_t i) - { - return ((__hip_fp8_e5m2_fnuz *)(&data))[i]; - } - FLASHINFER_INLINE const __hip_fp8_e5m2_fnuz &operator[](size_t i) const - { - return ((const __hip_fp8_e5m2_fnuz *)(&data))[i]; - } - FLASHINFER_INLINE __hip_fp8_e5m2_fnuz *ptr() - { - return reinterpret_cast<__hip_fp8_e5m2_fnuz *>(&data); - } - FLASHINFER_INLINE void fill(__hip_fp8_e5m2_fnuz val); - FLASHINFER_INLINE void load(const __hip_fp8_e5m2_fnuz *ptr); - FLASHINFER_INLINE void store(__hip_fp8_e5m2_fnuz *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) - { - cast_from_impl(*this, src); - } - template FLASHINFER_INLINE void cast_load(const T *ptr) - { - cast_load_impl(*this, ptr); - } - template FLASHINFER_INLINE void cast_store(T *ptr) const - { - cast_store_impl(ptr, *this); - } - - FLASHINFER_INLINE static void memcpy(__hip_fp8_e5m2_fnuz *dst, - const __hip_fp8_e5m2_fnuz *src); +template <> +struct vec_t<__hip_fp8_e5m2_fnuz, 1> { + __hip_fp8_e5m2_fnuz data; + + FLASHINFER_INLINE __hip_fp8_e5m2_fnuz& operator[](size_t i) { + return ((__hip_fp8_e5m2_fnuz*)(&data))[i]; + } + FLASHINFER_INLINE const __hip_fp8_e5m2_fnuz& operator[](size_t i) const { + return ((const __hip_fp8_e5m2_fnuz*)(&data))[i]; + } + FLASHINFER_INLINE __hip_fp8_e5m2_fnuz* ptr() { + return reinterpret_cast<__hip_fp8_e5m2_fnuz*>(&data); + } + 
FLASHINFER_INLINE void fill(__hip_fp8_e5m2_fnuz val); + FLASHINFER_INLINE void load(const __hip_fp8_e5m2_fnuz* ptr); + FLASHINFER_INLINE void store(__hip_fp8_e5m2_fnuz* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + + FLASHINFER_INLINE static void memcpy(__hip_fp8_e5m2_fnuz* dst, const __hip_fp8_e5m2_fnuz* src); }; -FLASHINFER_INLINE void -vec_t<__hip_fp8_e5m2_fnuz, 1>::fill(__hip_fp8_e5m2_fnuz val) -{ - data = val; -} +FLASHINFER_INLINE void vec_t<__hip_fp8_e5m2_fnuz, 1>::fill(__hip_fp8_e5m2_fnuz val) { data = val; } -FLASHINFER_INLINE void -vec_t<__hip_fp8_e5m2_fnuz, 1>::load(const __hip_fp8_e5m2_fnuz *ptr) -{ - data = *ptr; +FLASHINFER_INLINE void vec_t<__hip_fp8_e5m2_fnuz, 1>::load(const __hip_fp8_e5m2_fnuz* ptr) { + data = *ptr; } -FLASHINFER_INLINE void -vec_t<__hip_fp8_e5m2_fnuz, 1>::store(__hip_fp8_e5m2_fnuz *ptr) const -{ - *ptr = data; +FLASHINFER_INLINE void vec_t<__hip_fp8_e5m2_fnuz, 1>::store(__hip_fp8_e5m2_fnuz* ptr) const { + *ptr = data; } -FLASHINFER_INLINE void -vec_t<__hip_fp8_e5m2_fnuz, 1>::memcpy(__hip_fp8_e5m2_fnuz *dst, - const __hip_fp8_e5m2_fnuz *src) -{ - *dst = *src; +FLASHINFER_INLINE void vec_t<__hip_fp8_e5m2_fnuz, 1>::memcpy(__hip_fp8_e5m2_fnuz* dst, + const __hip_fp8_e5m2_fnuz* src) { + *dst = *src; } // __hip_fp8_e5m2_fnuz x 2 -template <> struct vec_t<__hip_fp8_e5m2_fnuz, 2> -{ - __hip_fp8x2_e5m2_fnuz data; - - FLASHINFER_INLINE __hip_fp8_e5m2_fnuz &operator[](size_t i) - { - return ((__hip_fp8_e5m2_fnuz *)(&data))[i]; - } - FLASHINFER_INLINE const __hip_fp8_e5m2_fnuz &operator[](size_t i) const - { - return ((const __hip_fp8_e5m2_fnuz *)(&data))[i]; - } - FLASHINFER_INLINE __hip_fp8_e5m2_fnuz *ptr() - { - return reinterpret_cast<__hip_fp8_e5m2_fnuz *>(&data); - } - 
FLASHINFER_INLINE void fill(__hip_fp8_e5m2_fnuz val); - FLASHINFER_INLINE void load(const __hip_fp8_e5m2_fnuz *ptr); - FLASHINFER_INLINE void store(__hip_fp8_e5m2_fnuz *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) - { - cast_from_impl(*this, src); - } - template FLASHINFER_INLINE void cast_load(const T *ptr) - { - cast_load_impl(*this, ptr); - } - template FLASHINFER_INLINE void cast_store(T *ptr) const - { - cast_store_impl(ptr, *this); - } - - FLASHINFER_INLINE static void memcpy(__hip_fp8_e5m2_fnuz *dst, - const __hip_fp8_e5m2_fnuz *src); +template <> +struct vec_t<__hip_fp8_e5m2_fnuz, 2> { + __hip_fp8x2_e5m2_fnuz data; + + FLASHINFER_INLINE __hip_fp8_e5m2_fnuz& operator[](size_t i) { + return ((__hip_fp8_e5m2_fnuz*)(&data))[i]; + } + FLASHINFER_INLINE const __hip_fp8_e5m2_fnuz& operator[](size_t i) const { + return ((const __hip_fp8_e5m2_fnuz*)(&data))[i]; + } + FLASHINFER_INLINE __hip_fp8_e5m2_fnuz* ptr() { + return reinterpret_cast<__hip_fp8_e5m2_fnuz*>(&data); + } + FLASHINFER_INLINE void fill(__hip_fp8_e5m2_fnuz val); + FLASHINFER_INLINE void load(const __hip_fp8_e5m2_fnuz* ptr); + FLASHINFER_INLINE void store(__hip_fp8_e5m2_fnuz* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + + FLASHINFER_INLINE static void memcpy(__hip_fp8_e5m2_fnuz* dst, const __hip_fp8_e5m2_fnuz* src); }; -FLASHINFER_INLINE void -vec_t<__hip_fp8_e5m2_fnuz, 2>::fill(__hip_fp8_e5m2_fnuz val) -{ - data.__x = - (__hip_fp8x2_storage_t(val.__x) << 8) | __hip_fp8x2_storage_t(val.__x); +FLASHINFER_INLINE void vec_t<__hip_fp8_e5m2_fnuz, 2>::fill(__hip_fp8_e5m2_fnuz val) { + data.__x = (__hip_fp8x2_storage_t(val.__x) << 8) | __hip_fp8x2_storage_t(val.__x); } -FLASHINFER_INLINE void -vec_t<__hip_fp8_e5m2_fnuz, 
2>::load(const __hip_fp8_e5m2_fnuz *ptr) -{ - data = *((__hip_fp8x2_e5m2_fnuz *)ptr); +FLASHINFER_INLINE void vec_t<__hip_fp8_e5m2_fnuz, 2>::load(const __hip_fp8_e5m2_fnuz* ptr) { + data = *((__hip_fp8x2_e5m2_fnuz*)ptr); } -FLASHINFER_INLINE void -vec_t<__hip_fp8_e5m2_fnuz, 2>::store(__hip_fp8_e5m2_fnuz *ptr) const -{ - *((__hip_fp8x2_e5m2_fnuz *)ptr) = data; +FLASHINFER_INLINE void vec_t<__hip_fp8_e5m2_fnuz, 2>::store(__hip_fp8_e5m2_fnuz* ptr) const { + *((__hip_fp8x2_e5m2_fnuz*)ptr) = data; } -FLASHINFER_INLINE void -vec_t<__hip_fp8_e5m2_fnuz, 2>::memcpy(__hip_fp8_e5m2_fnuz *dst, - const __hip_fp8_e5m2_fnuz *src) -{ - *((__hip_fp8x2_e5m2_fnuz *)dst) = *((__hip_fp8x2_e5m2_fnuz *)src); +FLASHINFER_INLINE void vec_t<__hip_fp8_e5m2_fnuz, 2>::memcpy(__hip_fp8_e5m2_fnuz* dst, + const __hip_fp8_e5m2_fnuz* src) { + *((__hip_fp8x2_e5m2_fnuz*)dst) = *((__hip_fp8x2_e5m2_fnuz*)src); } // __hip_fp8_e5m2_fnuz x 4 -template <> struct vec_t<__hip_fp8_e5m2_fnuz, 4> -{ - __hip_fp8x4_e5m2_fnuz data; - - FLASHINFER_INLINE __hip_fp8_e5m2_fnuz &operator[](size_t i) - { - return ((__hip_fp8_e5m2_fnuz *)(&data))[i]; - } - FLASHINFER_INLINE const __hip_fp8_e5m2_fnuz &operator[](size_t i) const - { - return ((const __hip_fp8_e5m2_fnuz *)(&data))[i]; - } - FLASHINFER_INLINE __hip_fp8_e5m2_fnuz *ptr() - { - return reinterpret_cast<__hip_fp8_e5m2_fnuz *>(&data); - } - FLASHINFER_INLINE void fill(__hip_fp8_e5m2_fnuz val); - FLASHINFER_INLINE void load(const __hip_fp8_e5m2_fnuz *ptr); - FLASHINFER_INLINE void store(__hip_fp8_e5m2_fnuz *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) - { - cast_from_impl(*this, src); - } - template FLASHINFER_INLINE void cast_load(const T *ptr) - { - cast_load_impl(*this, ptr); - } - template FLASHINFER_INLINE void cast_store(T *ptr) const - { - cast_store_impl(ptr, *this); - } - - FLASHINFER_INLINE static void memcpy(__hip_fp8_e5m2_fnuz *dst, - const __hip_fp8_e5m2_fnuz *src); +template <> +struct vec_t<__hip_fp8_e5m2_fnuz, 4> { + 
__hip_fp8x4_e5m2_fnuz data; + + FLASHINFER_INLINE __hip_fp8_e5m2_fnuz& operator[](size_t i) { + return ((__hip_fp8_e5m2_fnuz*)(&data))[i]; + } + FLASHINFER_INLINE const __hip_fp8_e5m2_fnuz& operator[](size_t i) const { + return ((const __hip_fp8_e5m2_fnuz*)(&data))[i]; + } + FLASHINFER_INLINE __hip_fp8_e5m2_fnuz* ptr() { + return reinterpret_cast<__hip_fp8_e5m2_fnuz*>(&data); + } + FLASHINFER_INLINE void fill(__hip_fp8_e5m2_fnuz val); + FLASHINFER_INLINE void load(const __hip_fp8_e5m2_fnuz* ptr); + FLASHINFER_INLINE void store(__hip_fp8_e5m2_fnuz* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + + FLASHINFER_INLINE static void memcpy(__hip_fp8_e5m2_fnuz* dst, const __hip_fp8_e5m2_fnuz* src); }; -FLASHINFER_INLINE void -vec_t<__hip_fp8_e5m2_fnuz, 4>::fill(__hip_fp8_e5m2_fnuz val) -{ - data.__x = (__hip_fp8x4_storage_t(val.__x) << 24) | - (__hip_fp8x4_storage_t(val.__x) << 16) | - (__hip_fp8x4_storage_t(val.__x) << 8) | - __hip_fp8x4_storage_t(val.__x); +FLASHINFER_INLINE void vec_t<__hip_fp8_e5m2_fnuz, 4>::fill(__hip_fp8_e5m2_fnuz val) { + data.__x = (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); } -FLASHINFER_INLINE void -vec_t<__hip_fp8_e5m2_fnuz, 4>::load(const __hip_fp8_e5m2_fnuz *ptr) -{ - data = *((__hip_fp8x4_e5m2_fnuz *)ptr); +FLASHINFER_INLINE void vec_t<__hip_fp8_e5m2_fnuz, 4>::load(const __hip_fp8_e5m2_fnuz* ptr) { + data = *((__hip_fp8x4_e5m2_fnuz*)ptr); } -FLASHINFER_INLINE void -vec_t<__hip_fp8_e5m2_fnuz, 4>::store(__hip_fp8_e5m2_fnuz *ptr) const -{ - *((__hip_fp8x4_e5m2_fnuz *)ptr) = data; +FLASHINFER_INLINE void vec_t<__hip_fp8_e5m2_fnuz, 4>::store(__hip_fp8_e5m2_fnuz* ptr) const { + 
*((__hip_fp8x4_e5m2_fnuz*)ptr) = data; } -FLASHINFER_INLINE void -vec_t<__hip_fp8_e5m2_fnuz, 4>::memcpy(__hip_fp8_e5m2_fnuz *dst, - const __hip_fp8_e5m2_fnuz *src) -{ - *((__hip_fp8x4_e5m2_fnuz *)dst) = *((__hip_fp8x4_e5m2_fnuz *)src); +FLASHINFER_INLINE void vec_t<__hip_fp8_e5m2_fnuz, 4>::memcpy(__hip_fp8_e5m2_fnuz* dst, + const __hip_fp8_e5m2_fnuz* src) { + *((__hip_fp8x4_e5m2_fnuz*)dst) = *((__hip_fp8x4_e5m2_fnuz*)src); } // __hip_fp8_e5m2_fnuz x 8 -template <> struct vec_t<__hip_fp8_e5m2_fnuz, 8> -{ - uint2 data; - - FLASHINFER_INLINE __hip_fp8_e5m2_fnuz &operator[](size_t i) - { - return ((__hip_fp8_e5m2_fnuz *)(&data))[i]; - } - FLASHINFER_INLINE const __hip_fp8_e5m2_fnuz &operator[](size_t i) const - { - return ((const __hip_fp8_e5m2_fnuz *)(&data))[i]; - } - FLASHINFER_INLINE __hip_fp8_e5m2_fnuz *ptr() - { - return reinterpret_cast<__hip_fp8_e5m2_fnuz *>(&data); - } - FLASHINFER_INLINE void fill(__hip_fp8_e5m2_fnuz val); - FLASHINFER_INLINE void load(const __hip_fp8_e5m2_fnuz *ptr); - FLASHINFER_INLINE void store(__hip_fp8_e5m2_fnuz *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) - { - cast_from_impl(*this, src); - } - template FLASHINFER_INLINE void cast_load(const T *ptr) - { - cast_load_impl(*this, ptr); - } - template FLASHINFER_INLINE void cast_store(T *ptr) const - { - cast_store_impl(ptr, *this); - } - FLASHINFER_INLINE static void memcpy(__hip_fp8_e5m2_fnuz *dst, - const __hip_fp8_e5m2_fnuz *src); +template <> +struct vec_t<__hip_fp8_e5m2_fnuz, 8> { + uint2 data; + + FLASHINFER_INLINE __hip_fp8_e5m2_fnuz& operator[](size_t i) { + return ((__hip_fp8_e5m2_fnuz*)(&data))[i]; + } + FLASHINFER_INLINE const __hip_fp8_e5m2_fnuz& operator[](size_t i) const { + return ((const __hip_fp8_e5m2_fnuz*)(&data))[i]; + } + FLASHINFER_INLINE __hip_fp8_e5m2_fnuz* ptr() { + return reinterpret_cast<__hip_fp8_e5m2_fnuz*>(&data); + } + FLASHINFER_INLINE void fill(__hip_fp8_e5m2_fnuz val); + FLASHINFER_INLINE void load(const 
__hip_fp8_e5m2_fnuz* ptr); + FLASHINFER_INLINE void store(__hip_fp8_e5m2_fnuz* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + FLASHINFER_INLINE static void memcpy(__hip_fp8_e5m2_fnuz* dst, const __hip_fp8_e5m2_fnuz* src); }; -FLASHINFER_INLINE void -vec_t<__hip_fp8_e5m2_fnuz, 8>::fill(__hip_fp8_e5m2_fnuz val) -{ - ((__hip_fp8x4_e5m2_fnuz *)(&data.x))->__x = - (__hip_fp8x4_storage_t(val.__x) << 24) | - (__hip_fp8x4_storage_t(val.__x) << 16) | - (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); - ((__hip_fp8x4_e5m2_fnuz *)(&data.y))->__x = - (__hip_fp8x4_storage_t(val.__x) << 24) | - (__hip_fp8x4_storage_t(val.__x) << 16) | - (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); +FLASHINFER_INLINE void vec_t<__hip_fp8_e5m2_fnuz, 8>::fill(__hip_fp8_e5m2_fnuz val) { + ((__hip_fp8x4_e5m2_fnuz*)(&data.x))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + ((__hip_fp8x4_e5m2_fnuz*)(&data.y))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); } -FLASHINFER_INLINE void -vec_t<__hip_fp8_e5m2_fnuz, 8>::load(const __hip_fp8_e5m2_fnuz *ptr) -{ - data = *((uint2 *)ptr); +FLASHINFER_INLINE void vec_t<__hip_fp8_e5m2_fnuz, 8>::load(const __hip_fp8_e5m2_fnuz* ptr) { + data = *((uint2*)ptr); } -FLASHINFER_INLINE void -vec_t<__hip_fp8_e5m2_fnuz, 8>::store(__hip_fp8_e5m2_fnuz *ptr) const -{ - *((uint2 *)ptr) = data; +FLASHINFER_INLINE void vec_t<__hip_fp8_e5m2_fnuz, 8>::store(__hip_fp8_e5m2_fnuz* ptr) const { + *((uint2*)ptr) = data; } -FLASHINFER_INLINE void 
-vec_t<__hip_fp8_e5m2_fnuz, 8>::memcpy(__hip_fp8_e5m2_fnuz *dst, - const __hip_fp8_e5m2_fnuz *src) -{ - *((uint2 *)dst) = *((uint2 *)src); +FLASHINFER_INLINE void vec_t<__hip_fp8_e5m2_fnuz, 8>::memcpy(__hip_fp8_e5m2_fnuz* dst, + const __hip_fp8_e5m2_fnuz* src) { + *((uint2*)dst) = *((uint2*)src); } // __hip_fp8_e5m2_fnuz x 16 or more -template struct vec_t<__hip_fp8_e5m2_fnuz, vec_size> -{ - uint4 data[vec_size / 16]; - - FLASHINFER_INLINE __hip_fp8_e5m2_fnuz &operator[](size_t i) - { - return ((__hip_fp8_e5m2_fnuz *)data)[i]; - } - FLASHINFER_INLINE const __hip_fp8_e5m2_fnuz &operator[](size_t i) const - { - return ((const __hip_fp8_e5m2_fnuz *)data)[i]; - } - FLASHINFER_INLINE __hip_fp8_e5m2_fnuz *ptr() - { - return reinterpret_cast<__hip_fp8_e5m2_fnuz *>(&data); - } - FLASHINFER_INLINE void fill(__hip_fp8_e5m2_fnuz val) - { +template +struct vec_t<__hip_fp8_e5m2_fnuz, vec_size> { + uint4 data[vec_size / 16]; + + FLASHINFER_INLINE __hip_fp8_e5m2_fnuz& operator[](size_t i) { + return ((__hip_fp8_e5m2_fnuz*)data)[i]; + } + FLASHINFER_INLINE const __hip_fp8_e5m2_fnuz& operator[](size_t i) const { + return ((const __hip_fp8_e5m2_fnuz*)data)[i]; + } + FLASHINFER_INLINE __hip_fp8_e5m2_fnuz* ptr() { + return reinterpret_cast<__hip_fp8_e5m2_fnuz*>(&data); + } + FLASHINFER_INLINE void fill(__hip_fp8_e5m2_fnuz val) { #pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - ((__hip_fp8x4_e5m2_fnuz *)(&(data[i].x)))->__x = - (__hip_fp8x4_storage_t(val.__x) << 24) | - (__hip_fp8x4_storage_t(val.__x) << 16) | - (__hip_fp8x4_storage_t(val.__x) << 8) | - __hip_fp8x4_storage_t(val.__x); - ((__hip_fp8x4_e5m2_fnuz *)(&(data[i].y)))->__x = - (__hip_fp8x4_storage_t(val.__x) << 24) | - (__hip_fp8x4_storage_t(val.__x) << 16) | - (__hip_fp8x4_storage_t(val.__x) << 8) | - __hip_fp8x4_storage_t(val.__x); - ((__hip_fp8x4_e5m2_fnuz *)(&(data[i].z)))->__x = - (__hip_fp8x4_storage_t(val.__x) << 24) | - (__hip_fp8x4_storage_t(val.__x) << 16) | - (__hip_fp8x4_storage_t(val.__x) << 8) | - 
__hip_fp8x4_storage_t(val.__x); - ((__hip_fp8x4_e5m2_fnuz *)(&(data[i].w)))->__x = - (__hip_fp8x4_storage_t(val.__x) << 24) | - (__hip_fp8x4_storage_t(val.__x) << 16) | - (__hip_fp8x4_storage_t(val.__x) << 8) | - __hip_fp8x4_storage_t(val.__x); - } - } - FLASHINFER_INLINE void load(const __hip_fp8_e5m2_fnuz *ptr) - { + for (size_t i = 0; i < vec_size / 16; ++i) { + ((__hip_fp8x4_e5m2_fnuz*)(&(data[i].x)))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + ((__hip_fp8x4_e5m2_fnuz*)(&(data[i].y)))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + ((__hip_fp8x4_e5m2_fnuz*)(&(data[i].z)))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + ((__hip_fp8x4_e5m2_fnuz*)(&(data[i].w)))->__x = + (__hip_fp8x4_storage_t(val.__x) << 24) | (__hip_fp8x4_storage_t(val.__x) << 16) | + (__hip_fp8x4_storage_t(val.__x) << 8) | __hip_fp8x4_storage_t(val.__x); + } + } + FLASHINFER_INLINE void load(const __hip_fp8_e5m2_fnuz* ptr) { #pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - data[i] = ((uint4 *)ptr)[i]; - } + for (size_t i = 0; i < vec_size / 16; ++i) { + data[i] = ((uint4*)ptr)[i]; } - FLASHINFER_INLINE void store(__hip_fp8_e5m2_fnuz *ptr) const - { + } + FLASHINFER_INLINE void store(__hip_fp8_e5m2_fnuz* ptr) const { #pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - ((uint4 *)ptr)[i] = data[i]; - } - } - template - FLASHINFER_INLINE void cast_from(const vec_t &src) - { - cast_from_impl(*this, src); - } - template FLASHINFER_INLINE void cast_load(const T *ptr) - { - cast_load_impl(*this, ptr); - } - template FLASHINFER_INLINE void cast_store(T *ptr) const - { - cast_store_impl(ptr, *this); - } - FLASHINFER_INLINE 
static void memcpy(__hip_fp8_e5m2_fnuz *dst, - const __hip_fp8_e5m2_fnuz *src) - { + for (size_t i = 0; i < vec_size / 16; ++i) { + ((uint4*)ptr)[i] = data[i]; + } + } + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + FLASHINFER_INLINE static void memcpy(__hip_fp8_e5m2_fnuz* dst, const __hip_fp8_e5m2_fnuz* src) { #pragma unroll - for (size_t i = 0; i < vec_size / 16; ++i) { - ((uint4 *)dst)[i] = ((uint4 *)src)[i]; - } + for (size_t i = 0; i < vec_size / 16; ++i) { + ((uint4*)dst)[i] = ((uint4*)src)[i]; } + } }; /******************* vec_t *******************/ // half x 1 -template <> struct vec_t -{ - half data; - - FLASHINFER_INLINE half &operator[](size_t i) - { - return ((half *)(&data))[i]; - } - FLASHINFER_INLINE const half &operator[](size_t i) const - { - return ((const half *)(&data))[i]; - } - FLASHINFER_INLINE half *ptr() { return reinterpret_cast(&data); } - FLASHINFER_INLINE void fill(half val); - FLASHINFER_INLINE void load(const half *ptr); - FLASHINFER_INLINE void store(half *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) - { - cast_from_impl(*this, src); - } - template FLASHINFER_INLINE void cast_load(const T *ptr) - { - cast_load_impl(*this, ptr); - } - template FLASHINFER_INLINE void cast_store(T *ptr) const - { - cast_store_impl(ptr, *this); - } - - FLASHINFER_INLINE static void memcpy(half *dst, const half *src); +template <> +struct vec_t { + half data; + + FLASHINFER_INLINE half& operator[](size_t i) { return ((half*)(&data))[i]; } + FLASHINFER_INLINE const half& operator[](size_t i) const { return ((const half*)(&data))[i]; } + FLASHINFER_INLINE half* ptr() { return reinterpret_cast(&data); } + FLASHINFER_INLINE void fill(half val); + FLASHINFER_INLINE void load(const half* 
ptr); + FLASHINFER_INLINE void store(half* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + + FLASHINFER_INLINE static void memcpy(half* dst, const half* src); }; FLASHINFER_INLINE void vec_t::fill(half val) { data = val; } -FLASHINFER_INLINE void vec_t::load(const half *ptr) { data = *ptr; } +FLASHINFER_INLINE void vec_t::load(const half* ptr) { data = *ptr; } -FLASHINFER_INLINE void vec_t::store(half *ptr) const { *ptr = data; } +FLASHINFER_INLINE void vec_t::store(half* ptr) const { *ptr = data; } -FLASHINFER_INLINE void vec_t::memcpy(half *dst, const half *src) -{ - *dst = *src; -} +FLASHINFER_INLINE void vec_t::memcpy(half* dst, const half* src) { *dst = *src; } // half x 2 -template <> struct vec_t -{ - half2 data; - - FLASHINFER_INLINE half &operator[](size_t i) - { - return ((half *)(&data))[i]; - } - FLASHINFER_INLINE const half &operator[](size_t i) const - { - return ((const half *)(&data))[i]; - } - FLASHINFER_INLINE half *ptr() { return reinterpret_cast(&data); } - FLASHINFER_INLINE void fill(half val); - FLASHINFER_INLINE void load(const half *ptr); - FLASHINFER_INLINE void store(half *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) - { - cast_from_impl(*this, src); - } - template FLASHINFER_INLINE void cast_load(const T *ptr) - { - cast_load_impl(*this, ptr); - } - template FLASHINFER_INLINE void cast_store(T *ptr) const - { - cast_store_impl(ptr, *this); - } - - FLASHINFER_INLINE static void memcpy(half *dst, const half *src); +template <> +struct vec_t { + half2 data; + + FLASHINFER_INLINE half& operator[](size_t i) { return ((half*)(&data))[i]; } + FLASHINFER_INLINE const half& operator[](size_t i) const { return ((const half*)(&data))[i]; } + FLASHINFER_INLINE 
half* ptr() { return reinterpret_cast(&data); } + FLASHINFER_INLINE void fill(half val); + FLASHINFER_INLINE void load(const half* ptr); + FLASHINFER_INLINE void store(half* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + + FLASHINFER_INLINE static void memcpy(half* dst, const half* src); }; -FLASHINFER_INLINE void vec_t::fill(half val) -{ - data = make_half2(val, val); -} +FLASHINFER_INLINE void vec_t::fill(half val) { data = make_half2(val, val); } -FLASHINFER_INLINE void vec_t::load(const half *ptr) -{ - data = *((half2 *)ptr); -} +FLASHINFER_INLINE void vec_t::load(const half* ptr) { data = *((half2*)ptr); } -FLASHINFER_INLINE void vec_t::store(half *ptr) const -{ - *((half2 *)ptr) = data; -} +FLASHINFER_INLINE void vec_t::store(half* ptr) const { *((half2*)ptr) = data; } -FLASHINFER_INLINE void vec_t::memcpy(half *dst, const half *src) -{ - *((half2 *)dst) = *((half2 *)src); +FLASHINFER_INLINE void vec_t::memcpy(half* dst, const half* src) { + *((half2*)dst) = *((half2*)src); } // half x 4 -template <> struct vec_t -{ - uint2 data; - - FLASHINFER_INLINE half &operator[](size_t i) - { - return ((half *)(&data))[i]; - } - FLASHINFER_INLINE const half &operator[](size_t i) const - { - return ((const half *)(&data))[i]; - } - FLASHINFER_INLINE half *ptr() { return reinterpret_cast(&data); } - FLASHINFER_INLINE void fill(half val); - FLASHINFER_INLINE void load(const half *ptr); - FLASHINFER_INLINE void store(half *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) - { - cast_from_impl(*this, src); - } - template FLASHINFER_INLINE void cast_load(const T *ptr) - { - cast_load_impl(*this, ptr); - } - template FLASHINFER_INLINE void cast_store(T *ptr) const - { - cast_store_impl(ptr, 
*this); - } - FLASHINFER_INLINE static void memcpy(half *dst, const half *src); +template <> +struct vec_t { + uint2 data; + + FLASHINFER_INLINE half& operator[](size_t i) { return ((half*)(&data))[i]; } + FLASHINFER_INLINE const half& operator[](size_t i) const { return ((const half*)(&data))[i]; } + FLASHINFER_INLINE half* ptr() { return reinterpret_cast(&data); } + FLASHINFER_INLINE void fill(half val); + FLASHINFER_INLINE void load(const half* ptr); + FLASHINFER_INLINE void store(half* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + FLASHINFER_INLINE static void memcpy(half* dst, const half* src); }; -FLASHINFER_INLINE void vec_t::fill(half val) -{ - *(half2 *)(&data.x) = make_half2(val, val); - *(half2 *)(&data.y) = make_half2(val, val); +FLASHINFER_INLINE void vec_t::fill(half val) { + *(half2*)(&data.x) = make_half2(val, val); + *(half2*)(&data.y) = make_half2(val, val); } -FLASHINFER_INLINE void vec_t::load(const half *ptr) -{ - data = *((uint2 *)ptr); -} +FLASHINFER_INLINE void vec_t::load(const half* ptr) { data = *((uint2*)ptr); } -FLASHINFER_INLINE void vec_t::store(half *ptr) const -{ - *((uint2 *)ptr) = data; -} +FLASHINFER_INLINE void vec_t::store(half* ptr) const { *((uint2*)ptr) = data; } -FLASHINFER_INLINE void vec_t::memcpy(half *dst, const half *src) -{ - *((uint2 *)dst) = *((uint2 *)src); +FLASHINFER_INLINE void vec_t::memcpy(half* dst, const half* src) { + *((uint2*)dst) = *((uint2*)src); } // half x 8 or more -template struct vec_t -{ - uint4 data[vec_size / 8]; - FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)data)[i]; } - FLASHINFER_INLINE const half &operator[](size_t i) const - { - return ((const half *)data)[i]; - } - FLASHINFER_INLINE half *ptr() { return 
reinterpret_cast(&data); } - FLASHINFER_INLINE void fill(half val) - { +template +struct vec_t { + uint4 data[vec_size / 8]; + FLASHINFER_INLINE half& operator[](size_t i) { return ((half*)data)[i]; } + FLASHINFER_INLINE const half& operator[](size_t i) const { return ((const half*)data)[i]; } + FLASHINFER_INLINE half* ptr() { return reinterpret_cast(&data); } + FLASHINFER_INLINE void fill(half val) { #pragma unroll - for (size_t i = 0; i < vec_size / 8; ++i) { - *(half2 *)(&(data[i].x)) = make_half2(val, val); - *(half2 *)(&(data[i].y)) = make_half2(val, val); - *(half2 *)(&(data[i].z)) = make_half2(val, val); - *(half2 *)(&(data[i].w)) = make_half2(val, val); - } - } - FLASHINFER_INLINE void load(const half *ptr) - { + for (size_t i = 0; i < vec_size / 8; ++i) { + *(half2*)(&(data[i].x)) = make_half2(val, val); + *(half2*)(&(data[i].y)) = make_half2(val, val); + *(half2*)(&(data[i].z)) = make_half2(val, val); + *(half2*)(&(data[i].w)) = make_half2(val, val); + } + } + FLASHINFER_INLINE void load(const half* ptr) { #pragma unroll - for (size_t i = 0; i < vec_size / 8; ++i) { - data[i] = ((uint4 *)ptr)[i]; - } + for (size_t i = 0; i < vec_size / 8; ++i) { + data[i] = ((uint4*)ptr)[i]; } - FLASHINFER_INLINE void store(half *ptr) const - { + } + FLASHINFER_INLINE void store(half* ptr) const { #pragma unroll - for (size_t i = 0; i < vec_size / 8; ++i) { - ((uint4 *)ptr)[i] = data[i]; - } - } - template - FLASHINFER_INLINE void cast_from(const vec_t &src) - { - cast_from_impl(*this, src); - } - template FLASHINFER_INLINE void cast_load(const T *ptr) - { - cast_load_impl(*this, ptr); - } - template FLASHINFER_INLINE void cast_store(T *ptr) const - { - cast_store_impl(ptr, *this); - } - FLASHINFER_INLINE static void memcpy(half *dst, const half *src) - { + for (size_t i = 0; i < vec_size / 8; ++i) { + ((uint4*)ptr)[i] = data[i]; + } + } + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void 
cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + FLASHINFER_INLINE static void memcpy(half* dst, const half* src) { #pragma unroll - for (size_t i = 0; i < vec_size / 8; ++i) { - ((uint4 *)dst)[i] = ((uint4 *)src)[i]; - } + for (size_t i = 0; i < vec_size / 8; ++i) { + ((uint4*)dst)[i] = ((uint4*)src)[i]; } + } }; /******************* vec_t<__hip_bfloat16> *******************/ // __hip_bfloat16 x 1 -template <> struct vec_t<__hip_bfloat16, 1> -{ - __hip_bfloat16 data; - FLASHINFER_INLINE __hip_bfloat16 &operator[](size_t i) - { - return ((__hip_bfloat16 *)(&data))[i]; - } - FLASHINFER_INLINE const __hip_bfloat16 &operator[](size_t i) const - { - return ((const __hip_bfloat16 *)(&data))[i]; - } - FLASHINFER_INLINE __hip_bfloat16 *ptr() - { - return reinterpret_cast<__hip_bfloat16 *>(&data); - } - FLASHINFER_INLINE void fill(__hip_bfloat16 val); - FLASHINFER_INLINE void load(const __hip_bfloat16 *ptr); - FLASHINFER_INLINE void store(__hip_bfloat16 *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) - { - cast_from_impl(*this, src); - } - template FLASHINFER_INLINE void cast_load(const T *ptr) - { - cast_load_impl(*this, ptr); - } - template FLASHINFER_INLINE void cast_store(T *ptr) const - { - cast_store_impl(ptr, *this); - } - FLASHINFER_INLINE static void memcpy(__hip_bfloat16 *dst, - const __hip_bfloat16 *src); +template <> +struct vec_t<__hip_bfloat16, 1> { + __hip_bfloat16 data; + FLASHINFER_INLINE __hip_bfloat16& operator[](size_t i) { return ((__hip_bfloat16*)(&data))[i]; } + FLASHINFER_INLINE const __hip_bfloat16& operator[](size_t i) const { + return ((const __hip_bfloat16*)(&data))[i]; + } + FLASHINFER_INLINE __hip_bfloat16* ptr() { return reinterpret_cast<__hip_bfloat16*>(&data); } + FLASHINFER_INLINE void fill(__hip_bfloat16 val); + FLASHINFER_INLINE void load(const __hip_bfloat16* ptr); + FLASHINFER_INLINE void 
store(__hip_bfloat16* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + FLASHINFER_INLINE static void memcpy(__hip_bfloat16* dst, const __hip_bfloat16* src); }; -FLASHINFER_INLINE void vec_t<__hip_bfloat16, 1>::fill(__hip_bfloat16 val) -{ - data = val; -} +FLASHINFER_INLINE void vec_t<__hip_bfloat16, 1>::fill(__hip_bfloat16 val) { data = val; } -FLASHINFER_INLINE void vec_t<__hip_bfloat16, 1>::load(const __hip_bfloat16 *ptr) -{ - data = *ptr; -} +FLASHINFER_INLINE void vec_t<__hip_bfloat16, 1>::load(const __hip_bfloat16* ptr) { data = *ptr; } -FLASHINFER_INLINE void -vec_t<__hip_bfloat16, 1>::store(__hip_bfloat16 *ptr) const -{ - *ptr = data; -} +FLASHINFER_INLINE void vec_t<__hip_bfloat16, 1>::store(__hip_bfloat16* ptr) const { *ptr = data; } -FLASHINFER_INLINE void -vec_t<__hip_bfloat16, 1>::memcpy(__hip_bfloat16 *dst, const __hip_bfloat16 *src) -{ - *dst = *src; +FLASHINFER_INLINE void vec_t<__hip_bfloat16, 1>::memcpy(__hip_bfloat16* dst, + const __hip_bfloat16* src) { + *dst = *src; } // __hip_bfloat16 x 2 -template <> struct vec_t<__hip_bfloat16, 2> -{ - __hip_bfloat162 data; - - FLASHINFER_INLINE __hip_bfloat16 &operator[](size_t i) - { - return ((__hip_bfloat16 *)(&data))[i]; - } - FLASHINFER_INLINE const __hip_bfloat16 &operator[](size_t i) const - { - return ((const __hip_bfloat16 *)(&data))[i]; - } - FLASHINFER_INLINE __hip_bfloat16 *ptr() - { - return reinterpret_cast<__hip_bfloat16 *>(&data); - } - FLASHINFER_INLINE void fill(__hip_bfloat16 val); - FLASHINFER_INLINE void load(const __hip_bfloat16 *ptr); - FLASHINFER_INLINE void store(__hip_bfloat16 *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) - { - cast_from_impl(*this, src); - } - template FLASHINFER_INLINE void 
cast_load(const T *ptr) - { - cast_load_impl(*this, ptr); - } - template FLASHINFER_INLINE void cast_store(T *ptr) const - { - cast_store_impl(ptr, *this); - } - FLASHINFER_INLINE static void memcpy(__hip_bfloat16 *dst, - const __hip_bfloat16 *src); +template <> +struct vec_t<__hip_bfloat16, 2> { + __hip_bfloat162 data; + + FLASHINFER_INLINE __hip_bfloat16& operator[](size_t i) { return ((__hip_bfloat16*)(&data))[i]; } + FLASHINFER_INLINE const __hip_bfloat16& operator[](size_t i) const { + return ((const __hip_bfloat16*)(&data))[i]; + } + FLASHINFER_INLINE __hip_bfloat16* ptr() { return reinterpret_cast<__hip_bfloat16*>(&data); } + FLASHINFER_INLINE void fill(__hip_bfloat16 val); + FLASHINFER_INLINE void load(const __hip_bfloat16* ptr); + FLASHINFER_INLINE void store(__hip_bfloat16* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + FLASHINFER_INLINE static void memcpy(__hip_bfloat16* dst, const __hip_bfloat16* src); }; -FLASHINFER_INLINE void vec_t<__hip_bfloat16, 2>::fill(__hip_bfloat16 val) -{ - data = make_bfloat162(val, val); +FLASHINFER_INLINE void vec_t<__hip_bfloat16, 2>::fill(__hip_bfloat16 val) { + data = make_bfloat162(val, val); } -FLASHINFER_INLINE void vec_t<__hip_bfloat16, 2>::load(const __hip_bfloat16 *ptr) -{ - data = *((__hip_bfloat162 *)ptr); +FLASHINFER_INLINE void vec_t<__hip_bfloat16, 2>::load(const __hip_bfloat16* ptr) { + data = *((__hip_bfloat162*)ptr); } -FLASHINFER_INLINE void -vec_t<__hip_bfloat16, 2>::store(__hip_bfloat16 *ptr) const -{ - *((__hip_bfloat162 *)ptr) = data; +FLASHINFER_INLINE void vec_t<__hip_bfloat16, 2>::store(__hip_bfloat16* ptr) const { + *((__hip_bfloat162*)ptr) = data; } -FLASHINFER_INLINE void -vec_t<__hip_bfloat16, 2>::memcpy(__hip_bfloat16 *dst, const 
__hip_bfloat16 *src) -{ - *((__hip_bfloat162 *)dst) = *((__hip_bfloat162 *)src); +FLASHINFER_INLINE void vec_t<__hip_bfloat16, 2>::memcpy(__hip_bfloat16* dst, + const __hip_bfloat16* src) { + *((__hip_bfloat162*)dst) = *((__hip_bfloat162*)src); } // __hip_bfloat16 x 4 -template <> struct vec_t<__hip_bfloat16, 4> -{ - uint2 data; - - FLASHINFER_INLINE __hip_bfloat16 &operator[](size_t i) - { - return ((__hip_bfloat16 *)(&data))[i]; - } - FLASHINFER_INLINE const __hip_bfloat16 &operator[](size_t i) const - { - return ((const __hip_bfloat16 *)(&data))[i]; - } - FLASHINFER_INLINE __hip_bfloat16 *ptr() - { - return reinterpret_cast<__hip_bfloat16 *>(&data); - } - FLASHINFER_INLINE void fill(__hip_bfloat16 val); - FLASHINFER_INLINE void load(const __hip_bfloat16 *ptr); - FLASHINFER_INLINE void store(__hip_bfloat16 *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) - { - cast_from_impl(*this, src); - } - template FLASHINFER_INLINE void cast_load(const T *ptr) - { - cast_load_impl(*this, ptr); - } - template FLASHINFER_INLINE void cast_store(T *ptr) const - { - cast_store_impl(ptr, *this); - } - FLASHINFER_INLINE static void memcpy(__hip_bfloat16 *dst, - const __hip_bfloat16 *src); +template <> +struct vec_t<__hip_bfloat16, 4> { + uint2 data; + + FLASHINFER_INLINE __hip_bfloat16& operator[](size_t i) { return ((__hip_bfloat16*)(&data))[i]; } + FLASHINFER_INLINE const __hip_bfloat16& operator[](size_t i) const { + return ((const __hip_bfloat16*)(&data))[i]; + } + FLASHINFER_INLINE __hip_bfloat16* ptr() { return reinterpret_cast<__hip_bfloat16*>(&data); } + FLASHINFER_INLINE void fill(__hip_bfloat16 val); + FLASHINFER_INLINE void load(const __hip_bfloat16* ptr); + FLASHINFER_INLINE void store(__hip_bfloat16* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void 
cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + FLASHINFER_INLINE static void memcpy(__hip_bfloat16* dst, const __hip_bfloat16* src); }; -FLASHINFER_INLINE void vec_t<__hip_bfloat16, 4>::fill(__hip_bfloat16 val) -{ - *(__hip_bfloat162 *)(&data.x) = make_bfloat162(val, val); - *(__hip_bfloat162 *)(&data.y) = make_bfloat162(val, val); +FLASHINFER_INLINE void vec_t<__hip_bfloat16, 4>::fill(__hip_bfloat16 val) { + *(__hip_bfloat162*)(&data.x) = make_bfloat162(val, val); + *(__hip_bfloat162*)(&data.y) = make_bfloat162(val, val); } -FLASHINFER_INLINE void vec_t<__hip_bfloat16, 4>::load(const __hip_bfloat16 *ptr) -{ - data = *((uint2 *)ptr); +FLASHINFER_INLINE void vec_t<__hip_bfloat16, 4>::load(const __hip_bfloat16* ptr) { + data = *((uint2*)ptr); } -FLASHINFER_INLINE void -vec_t<__hip_bfloat16, 4>::store(__hip_bfloat16 *ptr) const -{ - *((uint2 *)ptr) = data; +FLASHINFER_INLINE void vec_t<__hip_bfloat16, 4>::store(__hip_bfloat16* ptr) const { + *((uint2*)ptr) = data; } -FLASHINFER_INLINE void -vec_t<__hip_bfloat16, 4>::memcpy(__hip_bfloat16 *dst, const __hip_bfloat16 *src) -{ - *((uint2 *)dst) = *((uint2 *)src); +FLASHINFER_INLINE void vec_t<__hip_bfloat16, 4>::memcpy(__hip_bfloat16* dst, + const __hip_bfloat16* src) { + *((uint2*)dst) = *((uint2*)src); } // __hip_bfloat16 x 8 or more -template struct vec_t<__hip_bfloat16, vec_size> -{ - uint4 data[vec_size / 8]; +template +struct vec_t<__hip_bfloat16, vec_size> { + uint4 data[vec_size / 8]; - FLASHINFER_INLINE __hip_bfloat16 &operator[](size_t i) - { - return ((__hip_bfloat16 *)data)[i]; - } - FLASHINFER_INLINE const __hip_bfloat16 &operator[](size_t i) const - { - return ((const __hip_bfloat16 *)data)[i]; - } - FLASHINFER_INLINE __hip_bfloat16 *ptr() - { - return reinterpret_cast<__hip_bfloat16 *>(&data); - } - FLASHINFER_INLINE void fill(__hip_bfloat16 val) - { + FLASHINFER_INLINE __hip_bfloat16& operator[](size_t i) { return ((__hip_bfloat16*)data)[i]; } + FLASHINFER_INLINE const __hip_bfloat16& 
operator[](size_t i) const { + return ((const __hip_bfloat16*)data)[i]; + } + FLASHINFER_INLINE __hip_bfloat16* ptr() { return reinterpret_cast<__hip_bfloat16*>(&data); } + FLASHINFER_INLINE void fill(__hip_bfloat16 val) { #pragma unroll - for (size_t i = 0; i < vec_size / 8; ++i) { - *(__hip_bfloat162 *)(&(data[i].x)) = make_bfloat162(val, val); - *(__hip_bfloat162 *)(&(data[i].y)) = make_bfloat162(val, val); - *(__hip_bfloat162 *)(&(data[i].z)) = make_bfloat162(val, val); - *(__hip_bfloat162 *)(&(data[i].w)) = make_bfloat162(val, val); - } - } - FLASHINFER_INLINE void load(const __hip_bfloat16 *ptr) - { + for (size_t i = 0; i < vec_size / 8; ++i) { + *(__hip_bfloat162*)(&(data[i].x)) = make_bfloat162(val, val); + *(__hip_bfloat162*)(&(data[i].y)) = make_bfloat162(val, val); + *(__hip_bfloat162*)(&(data[i].z)) = make_bfloat162(val, val); + *(__hip_bfloat162*)(&(data[i].w)) = make_bfloat162(val, val); + } + } + FLASHINFER_INLINE void load(const __hip_bfloat16* ptr) { #pragma unroll - for (size_t i = 0; i < vec_size / 8; ++i) { - data[i] = ((uint4 *)ptr)[i]; - } + for (size_t i = 0; i < vec_size / 8; ++i) { + data[i] = ((uint4*)ptr)[i]; } - FLASHINFER_INLINE void store(__hip_bfloat16 *ptr) const - { + } + FLASHINFER_INLINE void store(__hip_bfloat16* ptr) const { #pragma unroll - for (size_t i = 0; i < vec_size / 8; ++i) { - ((uint4 *)ptr)[i] = data[i]; - } - } - template - FLASHINFER_INLINE void cast_from(const vec_t &src) - { - cast_from_impl(*this, src); - } - template FLASHINFER_INLINE void cast_load(const T *ptr) - { - cast_load_impl(*this, ptr); - } - template FLASHINFER_INLINE void cast_store(T *ptr) const - { - cast_store_impl(ptr, *this); - } - FLASHINFER_INLINE static void memcpy(__hip_bfloat16 *dst, - const __hip_bfloat16 *src) - { + for (size_t i = 0; i < vec_size / 8; ++i) { + ((uint4*)ptr)[i] = data[i]; + } + } + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void 
cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + FLASHINFER_INLINE static void memcpy(__hip_bfloat16* dst, const __hip_bfloat16* src) { #pragma unroll - for (size_t i = 0; i < vec_size / 8; ++i) { - ((uint4 *)dst)[i] = ((uint4 *)src)[i]; - } + for (size_t i = 0; i < vec_size / 8; ++i) { + ((uint4*)dst)[i] = ((uint4*)src)[i]; } + } }; /******************* vec_t *******************/ // float x 1 -template <> struct vec_t -{ - float data; - - FLASHINFER_INLINE float &operator[](size_t i) - { - return ((float *)(&data))[i]; - } - FLASHINFER_INLINE const float &operator[](size_t i) const - { - return ((const float *)(&data))[i]; - } - FLASHINFER_INLINE float *ptr() { return reinterpret_cast(&data); } - FLASHINFER_INLINE void fill(float val); - FLASHINFER_INLINE void load(const float *ptr); - FLASHINFER_INLINE void store(float *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) - { - cast_from_impl(*this, src); - } - template FLASHINFER_INLINE void cast_load(const T *ptr) - { - cast_load_impl(*this, ptr); - } - template FLASHINFER_INLINE void cast_store(T *ptr) const - { - cast_store_impl(ptr, *this); - } - FLASHINFER_INLINE static void memcpy(float *dst, const float *src); +template <> +struct vec_t { + float data; + + FLASHINFER_INLINE float& operator[](size_t i) { return ((float*)(&data))[i]; } + FLASHINFER_INLINE const float& operator[](size_t i) const { return ((const float*)(&data))[i]; } + FLASHINFER_INLINE float* ptr() { return reinterpret_cast(&data); } + FLASHINFER_INLINE void fill(float val); + FLASHINFER_INLINE void load(const float* ptr); + FLASHINFER_INLINE void store(float* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void 
cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + FLASHINFER_INLINE static void memcpy(float* dst, const float* src); }; FLASHINFER_INLINE void vec_t::fill(float val) { data = val; } -FLASHINFER_INLINE void vec_t::load(const float *ptr) { data = *ptr; } +FLASHINFER_INLINE void vec_t::load(const float* ptr) { data = *ptr; } -FLASHINFER_INLINE void vec_t::store(float *ptr) const { *ptr = data; } +FLASHINFER_INLINE void vec_t::store(float* ptr) const { *ptr = data; } -FLASHINFER_INLINE void vec_t::memcpy(float *dst, const float *src) -{ - *dst = *src; -} +FLASHINFER_INLINE void vec_t::memcpy(float* dst, const float* src) { *dst = *src; } // float x 2 -template <> struct vec_t -{ - float2 data; - - FLASHINFER_INLINE float &operator[](size_t i) - { - return ((float *)(&data))[i]; - } - FLASHINFER_INLINE const float &operator[](size_t i) const - { - return ((const float *)(&data))[i]; - } - FLASHINFER_INLINE float *ptr() { return reinterpret_cast(&data); } - FLASHINFER_INLINE void fill(float val); - FLASHINFER_INLINE void load(const float *ptr); - FLASHINFER_INLINE void store(float *ptr) const; - template - FLASHINFER_INLINE void cast_from(const vec_t &src) - { - cast_from_impl(*this, src); - } - template FLASHINFER_INLINE void cast_load(const T *ptr) - { - cast_load_impl(*this, ptr); - } - template FLASHINFER_INLINE void cast_store(T *ptr) const - { - cast_store_impl(ptr, *this); - } - FLASHINFER_INLINE static void memcpy(float *dst, const float *src); +template <> +struct vec_t { + float2 data; + + FLASHINFER_INLINE float& operator[](size_t i) { return ((float*)(&data))[i]; } + FLASHINFER_INLINE const float& operator[](size_t i) const { return ((const float*)(&data))[i]; } + FLASHINFER_INLINE float* ptr() { return reinterpret_cast(&data); } + FLASHINFER_INLINE void fill(float val); + FLASHINFER_INLINE void load(const float* ptr); + FLASHINFER_INLINE void store(float* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + 
cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + FLASHINFER_INLINE static void memcpy(float* dst, const float* src); }; -FLASHINFER_INLINE void vec_t::fill(float val) -{ - data = make_float2(val, val); -} +FLASHINFER_INLINE void vec_t::fill(float val) { data = make_float2(val, val); } -FLASHINFER_INLINE void vec_t::load(const float *ptr) -{ - data = *((float2 *)ptr); -} +FLASHINFER_INLINE void vec_t::load(const float* ptr) { data = *((float2*)ptr); } -FLASHINFER_INLINE void vec_t::store(float *ptr) const -{ - *((float2 *)ptr) = data; -} +FLASHINFER_INLINE void vec_t::store(float* ptr) const { *((float2*)ptr) = data; } -FLASHINFER_INLINE void vec_t::memcpy(float *dst, const float *src) -{ - *((float2 *)dst) = *((float2 *)src); +FLASHINFER_INLINE void vec_t::memcpy(float* dst, const float* src) { + *((float2*)dst) = *((float2*)src); } // float x 4 or more -template struct vec_t -{ - float4 data[vec_size / 4]; - - FLASHINFER_INLINE float &operator[](size_t i) - { - return ((float *)(data))[i]; - } - FLASHINFER_INLINE const float &operator[](size_t i) const - { - return ((const float *)(data))[i]; - } - FLASHINFER_INLINE float *ptr() { return reinterpret_cast(&data); } - FLASHINFER_INLINE void fill(float val) - { +template +struct vec_t { + float4 data[vec_size / 4]; + + FLASHINFER_INLINE float& operator[](size_t i) { return ((float*)(data))[i]; } + FLASHINFER_INLINE const float& operator[](size_t i) const { return ((const float*)(data))[i]; } + FLASHINFER_INLINE float* ptr() { return reinterpret_cast(&data); } + FLASHINFER_INLINE void fill(float val) { #pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - data[i] = make_float4(val, val, val, val); - } + for (size_t i = 0; i < vec_size / 4; ++i) { + data[i] = make_float4(val, val, val, val); } - FLASHINFER_INLINE void load(const float 
*ptr) - { + } + FLASHINFER_INLINE void load(const float* ptr) { #pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - data[i] = ((float4 *)ptr)[i]; - } + for (size_t i = 0; i < vec_size / 4; ++i) { + data[i] = ((float4*)ptr)[i]; } - FLASHINFER_INLINE void store(float *ptr) const - { + } + FLASHINFER_INLINE void store(float* ptr) const { #pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - ((float4 *)ptr)[i] = data[i]; - } - } - template - FLASHINFER_INLINE void cast_from(const vec_t &src) - { - cast_from_impl(*this, src); - } - template FLASHINFER_INLINE void cast_load(const T *ptr) - { - cast_load_impl(*this, ptr); - } - template FLASHINFER_INLINE void cast_store(T *ptr) const - { - cast_store_impl(ptr, *this); - } - FLASHINFER_INLINE static void memcpy(float *dst, const float *src) - { + for (size_t i = 0; i < vec_size / 4; ++i) { + ((float4*)ptr)[i] = data[i]; + } + } + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + FLASHINFER_INLINE static void memcpy(float* dst, const float* src) { #pragma unroll - for (size_t i = 0; i < vec_size / 4; ++i) { - ((float4 *)dst)[i] = ((float4 *)src)[i]; - } + for (size_t i = 0; i < vec_size / 4; ++i) { + ((float4*)dst)[i] = ((float4*)src)[i]; } + } }; -} // namespace flashinfer +} // namespace flashinfer } diff --git a/libflashinfer/include/gpu_iface/conversion_utils.h b/libflashinfer/include/gpu_iface/conversion_utils.h index 768bde165d..0263722cb3 100644 --- a/libflashinfer/include/gpu_iface/conversion_utils.h +++ b/libflashinfer/include/gpu_iface/conversion_utils.h @@ -8,59 +8,46 @@ #include #include -namespace fi::con -{ +namespace fi::con { template -__host__ __device__ __inline__ DTypeOut explicit_casting(DTypeIn value) -{ - return DTypeOut(value); +__host__ 
__device__ __inline__ DTypeOut explicit_casting(DTypeIn value) { + return DTypeOut(value); } template <> -__host__ __device__ __inline__ float -explicit_casting<__half, float>(__half value) -{ - return __half2float(value); +__host__ __device__ __inline__ float explicit_casting<__half, float>(__half value) { + return __half2float(value); } template <> -__host__ __device__ __inline__ float -explicit_casting<__hip_bfloat16, float>(__hip_bfloat16 value) -{ - return __bfloat162float(value); +__host__ __device__ __inline__ float explicit_casting<__hip_bfloat16, float>(__hip_bfloat16 value) { + return __bfloat162float(value); } template <> -__host__ __device__ __inline__ __half -explicit_casting(float value) -{ - return __float2half(value); +__host__ __device__ __inline__ __half explicit_casting(float value) { + return __float2half(value); } template <> -__host__ __device__ __inline__ __hip_bfloat16 -explicit_casting<__half, __hip_bfloat16>(__half value) -{ - return __float2bfloat16(__half2float(value)); +__host__ __device__ __inline__ __hip_bfloat16 explicit_casting<__half, __hip_bfloat16>( + __half value) { + return __float2bfloat16(__half2float(value)); } template <> -__host__ __device__ __inline__ float explicit_casting(float value) -{ - return value; +__host__ __device__ __inline__ float explicit_casting(float value) { + return value; } template <> -__host__ __device__ __inline__ __half -explicit_casting<__half, __half>(__half value) -{ - return value; +__host__ __device__ __inline__ __half explicit_casting<__half, __half>(__half value) { + return value; } template <> -__host__ __device__ __inline__ __hip_bfloat16 -explicit_casting<__hip_bfloat16, __hip_bfloat16>(__hip_bfloat16 value) -{ - return value; +__host__ __device__ __inline__ __hip_bfloat16 explicit_casting<__hip_bfloat16, __hip_bfloat16>( + __hip_bfloat16 value) { + return value; } -} // namespace fi::con +} // namespace fi::con diff --git a/libflashinfer/include/gpu_iface/cooperative_groups.h 
b/libflashinfer/include/gpu_iface/cooperative_groups.h index 4617698071..4f75fe1566 100644 --- a/libflashinfer/include/gpu_iface/cooperative_groups.h +++ b/libflashinfer/include/gpu_iface/cooperative_groups.h @@ -7,21 +7,17 @@ #if defined(PLATFORM_CUDA_DEVICE) #include -namespace flashinfer -{ -namespace gpu_iface -{ +namespace flashinfer { +namespace gpu_iface { namespace cg = ::cooperative_groups; -} // namespace gpu_iface -} // namespace flashinfer +} // namespace gpu_iface +} // namespace flashinfer #elif defined(PLATFORM_HIP_DEVICE) #include -namespace flashinfer -{ -namespace gpu_iface -{ +namespace flashinfer { +namespace gpu_iface { namespace cg = ::cooperative_groups; -} // namespace gpu_iface -} // namespace flashinfer +} // namespace gpu_iface +} // namespace flashinfer #endif diff --git a/libflashinfer/include/gpu_iface/enums.hpp b/libflashinfer/include/gpu_iface/enums.hpp index 4b10b57a50..42fd246e83 100644 --- a/libflashinfer/include/gpu_iface/enums.hpp +++ b/libflashinfer/include/gpu_iface/enums.hpp @@ -3,28 +3,25 @@ #pragma once -namespace flashinfer -{ +namespace flashinfer { /*! * \brief An enumeration class that defines different modes for applying RoPE * (Rotary Positional Embeddings). */ -enum class PosEncodingMode -{ - // No rotary positional embeddings - kNone = 0U, - // Apply Llama-style rope. - kRoPELlama = 1U, - // Apply ALiBi bias - kALiBi = 2U +enum class PosEncodingMode { + // No rotary positional embeddings + kNone = 0U, + // Apply Llama-style rope. 
+ kRoPELlama = 1U, + // Apply ALiBi bias + kALiBi = 2U }; -enum class MaskMode -{ - kNone = 0U, // No mask - kCausal = 1U, // Causal mask - kCustom = 2U, // Custom mask +enum class MaskMode { + kNone = 0U, // No mask + kCausal = 1U, // Causal mask + kCustom = 2U, // Custom mask }; -} // namespace flashinfer +} // namespace flashinfer diff --git a/libflashinfer/include/gpu_iface/error.hpp b/libflashinfer/include/gpu_iface/error.hpp index 035dafc54f..f2de69457f 100644 --- a/libflashinfer/include/gpu_iface/error.hpp +++ b/libflashinfer/include/gpu_iface/error.hpp @@ -3,50 +3,44 @@ // SPDX - License - Identifier : Apache - 2.0 #pragma once -#include "platform.hpp" #include #include #include -namespace flashinfer -{ -namespace gpu_iface -{ +#include "platform.hpp" + +namespace flashinfer { +namespace gpu_iface { // Platform-agnostic error type -class GpuError -{ -private: - int code_; - std::string message_; - -public: - GpuError() : code_(0) {} - GpuError(int code, std::string message) - : code_(code), message_(std::move(message)) - { - } - - bool isSuccess() const { return code_ == 0; } - int code() const { return code_; } - const std::string &message() const { return message_; } +class GpuError { + private: + int code_; + std::string message_; + + public: + GpuError() : code_(0) {} + GpuError(int code, std::string message) : code_(code), message_(std::move(message)) {} + + bool isSuccess() const { return code_ == 0; } + int code() const { return code_; } + const std::string& message() const { return message_; } #if defined(PLATFORM_CUDA_DEVICE) - cudaError_t getNative() const { return static_cast(code_); } + cudaError_t getNative() const { return static_cast(code_); } #elif defined(PLATFORM_HIP_DEVICE) - hipError_t getNative() const { return static_cast(code_); } + hipError_t getNative() const { return static_cast(code_); } #endif }; // Create error from message -inline GpuError CreateError(std::string message) -{ +inline GpuError CreateError(std::string message) { 
#if defined(PLATFORM_CUDA_DEVICE) - return GpuError(static_cast(cudaErrorUnknown), std::move(message)); + return GpuError(static_cast(cudaErrorUnknown), std::move(message)); #elif defined(PLATFORM_HIP_DEVICE) - return GpuError(static_cast(hipErrorUnknown), std::move(message)); + return GpuError(static_cast(hipErrorUnknown), std::move(message)); #endif } -} // namespace gpu_iface -} // namespace flashinfer +} // namespace gpu_iface +} // namespace flashinfer diff --git a/libflashinfer/include/gpu_iface/exception.h b/libflashinfer/include/gpu_iface/exception.h index 2630f5e44b..9d4f9d7832 100644 --- a/libflashinfer/include/gpu_iface/exception.h +++ b/libflashinfer/include/gpu_iface/exception.h @@ -19,40 +19,30 @@ #include #include -namespace flashinfer -{ - -class Error : public std::exception -{ -private: - std::string message_; - -public: - Error(const std::string &func, - const std::string &file, - int line, - const std::string &message) - { - std::ostringstream oss; - oss << "Error in function '" << func << "' " - << "at " << file << ":" << line << ": " << message; - message_ = oss.str(); - } - - virtual const char *what() const noexcept override - { - return message_.c_str(); - } +namespace flashinfer { + +class Error : public std::exception { + private: + std::string message_; + + public: + Error(const std::string& func, const std::string& file, int line, const std::string& message) { + std::ostringstream oss; + oss << "Error in function '" << func << "' " + << "at " << file << ":" << line << ": " << message; + message_ = oss.str(); + } + + virtual const char* what() const noexcept override { return message_.c_str(); } }; -#define FLASHINFER_ERROR(message) \ - throw Error(__FUNCTION__, __FILE__, __LINE__, message) +#define FLASHINFER_ERROR(message) throw Error(__FUNCTION__, __FILE__, __LINE__, message) -#define FLASHINFER_CHECK(condition, message) \ - if (!(condition)) { \ - FLASHINFER_ERROR(message); \ - } +#define FLASHINFER_CHECK(condition, message) \ + if 
(!(condition)) { \ + FLASHINFER_ERROR(message); \ + } -} // namespace flashinfer +} // namespace flashinfer -#endif // FLASHINFER_EXCEPTION_H_ +#endif // FLASHINFER_EXCEPTION_H_ diff --git a/libflashinfer/include/gpu_iface/fragment.hpp b/libflashinfer/include/gpu_iface/fragment.hpp index c997d304b6..558a503a02 100644 --- a/libflashinfer/include/gpu_iface/fragment.hpp +++ b/libflashinfer/include/gpu_iface/fragment.hpp @@ -11,138 +11,113 @@ #include #endif -namespace flashinfer -{ -namespace gpu_iface -{ -namespace mma -{ - -enum class FragmentType -{ - row_major, // Row-major matrix layout - col_major, // Column-major matrix layout - accumulator // Accumulator (no layout) +namespace flashinfer { +namespace gpu_iface { +namespace mma { + +enum class FragmentType { + row_major, // Row-major matrix layout + col_major, // Column-major matrix layout + accumulator // Accumulator (no layout) }; template -struct fragment_t -{ - using value_type = T; +struct fragment_t { + using value_type = T; #ifdef PLATFORM_CUDA_DEVICE - // flashinfer's generic CUDA implementation uses raw arrays for matrix - // fragments and the interface is designed to accomodate use of raw arrays - // for such use cases. - static constexpr int elements_per_thread = - (frag_type == FragmentType::accumulator) ? 8 - : (sizeof(T) == 1) ? 8 - : 4; - - // Number of 32-bit registers needed - static constexpr int num_regs = (elements_per_thread * sizeof(T) + 3) / 4; - - uint32_t data[num_regs]; - - // Provide array-like access - __device__ __forceinline__ T &operator[](int i) - { - return reinterpret_cast(data)[i]; - } - __device__ __forceinline__ const T &operator[](int i) const - { - return reinterpret_cast(data)[i]; - } + // flashinfer's generic CUDA implementation uses raw arrays for matrix + // fragments and the interface is designed to accomodate use of raw arrays + // for such use cases. + static constexpr int elements_per_thread = (frag_type == FragmentType::accumulator) ? 8 + : (sizeof(T) == 1) ? 
8 + : 4; - // Get number of elements this thread holds - __device__ __forceinline__ constexpr int size() const - { - return elements_per_thread; - } + // Number of 32-bit registers needed + static constexpr int num_regs = (elements_per_thread * sizeof(T) + 3) / 4; + + uint32_t data[num_regs]; - // Get raw pointer for MMA operations - __device__ __forceinline__ uint32_t *raw_ptr() { return data; } - __device__ __forceinline__ const uint32_t *raw_ptr() const { return data; } + // Provide array-like access + __device__ __forceinline__ T& operator[](int i) { return reinterpret_cast(data)[i]; } + __device__ __forceinline__ const T& operator[](int i) const { + return reinterpret_cast(data)[i]; + } + + // Get number of elements this thread holds + __device__ __forceinline__ constexpr int size() const { return elements_per_thread; } + + // Get raw pointer for MMA operations + __device__ __forceinline__ uint32_t* raw_ptr() { return data; } + __device__ __forceinline__ const uint32_t* raw_ptr() const { return data; } #elif defined(PLATFORM_HIP_DEVICE) - // AMD: Use rocWMMA fragments - using rocwmma_layout = typename std::conditional< - frag_type == FragmentType::row_major, - rocwmma::row_major, - typename std::conditional::type>::type; - - using rocwmma_matrix_t = typename std::conditional< - frag_type == FragmentType::row_major, - rocwmma::matrix_a, - typename std::conditional::type>::type; - - // Select appropriate fragment type based on whether it's accumulator or not - using rocwmma_frag_t = typename std::conditional< - frag_type == FragmentType::accumulator, - rocwmma::fragment, - rocwmma::fragment>::type; - - rocwmma_frag_t frag; - - // Provide array-like access that maps to rocWMMA fragment - __device__ __forceinline__ T operator[](int i) const { return frag.x[i]; } - - // For non-const access, we need to provide a setter since we can't return a - // reference - __device__ __forceinline__ void set(int i, T value) { frag.x[i] = value; } - - // Get number of elements 
this thread holds - __device__ __forceinline__ int size() const { return frag.num_elements; } - - // Get raw pointer for operations that need it - __device__ __forceinline__ rocwmma_frag_t *raw_ptr() { return &frag; } - __device__ __forceinline__ const rocwmma_frag_t *raw_ptr() const - { - return &frag; - } + // AMD: Use rocWMMA fragments + using rocwmma_layout = + typename std::conditional::type>::type; + + using rocwmma_matrix_t = typename std::conditional< + frag_type == FragmentType::row_major, rocwmma::matrix_a, + typename std::conditional::type>::type; + + // Select appropriate fragment type based on whether it's accumulator or not + using rocwmma_frag_t = typename std::conditional< + frag_type == FragmentType::accumulator, rocwmma::fragment, + rocwmma::fragment >::type; + + rocwmma_frag_t frag; + + // Provide array-like access that maps to rocWMMA fragment + __device__ __forceinline__ T operator[](int i) const { return frag.x[i]; } + + // For non-const access, we need to provide a setter since we can't return a + // reference + __device__ __forceinline__ void set(int i, T value) { frag.x[i] = value; } + + // Get number of elements this thread holds + __device__ __forceinline__ int size() const { return frag.num_elements; } + + // Get raw pointer for operations that need it + __device__ __forceinline__ rocwmma_frag_t* raw_ptr() { return &frag; } + __device__ __forceinline__ const rocwmma_frag_t* raw_ptr() const { return &frag; } #endif - // Common interface - update fill method to use setter for HIP - __device__ __forceinline__ void fill(T value) - { + // Common interface - update fill method to use setter for HIP + __device__ __forceinline__ void fill(T value) { #ifdef PLATFORM_CUDA_DEVICE #pragma unroll - for (int i = 0; i < elements_per_thread; ++i) { - (*this)[i] = value; - } + for (int i = 0; i < elements_per_thread; ++i) { + (*this)[i] = value; + } #elif defined(PLATFORM_HIP_DEVICE) - rocwmma::fill_fragment(frag, value); + rocwmma::fill_fragment(frag, 
value); #endif - } + } }; // Convenience typedefs for common fragment types template -using row_major_fragment_m16n16k16 = - fragment_t; +using row_major_fragment_m16n16k16 = fragment_t; template -using col_major_fragment_m16n16k16 = - fragment_t; +using col_major_fragment_m16n16k16 = fragment_t; template -using accumulator_fragment_m16n16k16 = - fragment_t; +using accumulator_fragment_m16n16k16 = fragment_t; // Helper to get compile-time fragment size -template struct fragment_traits -{ +template +struct fragment_traits { #ifdef PLATFORM_CUDA_DEVICE - static constexpr int size = Fragment::elements_per_thread; + static constexpr int size = Fragment::elements_per_thread; #elif defined(PLATFORM_HIP_DEVICE) - // For HIP, we can't make this constexpr, so provide a device function - __device__ static int get_size(const Fragment &f) { return f.size(); } + // For HIP, we can't make this constexpr, so provide a device function + __device__ static int get_size(const Fragment& f) { return f.size(); } #endif }; -} // namespace mma -} // namespace gpu_iface -} // namespace flashinfer +} // namespace mma +} // namespace gpu_iface +} // namespace flashinfer diff --git a/libflashinfer/include/gpu_iface/gpu_runtime_compat.hpp b/libflashinfer/include/gpu_iface/gpu_runtime_compat.hpp index 5c46f62602..6bc083184c 100644 --- a/libflashinfer/include/gpu_iface/gpu_runtime_compat.hpp +++ b/libflashinfer/include/gpu_iface/gpu_runtime_compat.hpp @@ -40,8 +40,8 @@ #elif defined(PLATFORM_HIP_DEVICE) #define gpuGetDevice hipGetDevice #define gpuLaunchKernel hipLaunchKernel -#define gpuFuncSetAttribute(func, attr, val) \ - hipFuncSetAttribute(reinterpret_cast(func), attr, val) +#define gpuFuncSetAttribute(func, attr, val) \ + hipFuncSetAttribute(reinterpret_cast(func), attr, val) #define gpuDeviceGetAttribute hipDeviceGetAttribute #define gpuDeviceSynchronize hipDeviceSynchronize #endif @@ -66,30 +66,22 @@ // Function attribute enums (these have different names) #if 
defined(PLATFORM_CUDA_DEVICE) -#define gpuFuncAttributeMaxDynamicSharedMemorySize \ - cudaFuncAttributeMaxDynamicSharedMemorySize -#define gpuFuncAttributePreferredSharedMemoryCarveout \ - cudaFuncAttributePreferredSharedMemoryCarveout +#define gpuFuncAttributeMaxDynamicSharedMemorySize cudaFuncAttributeMaxDynamicSharedMemorySize +#define gpuFuncAttributePreferredSharedMemoryCarveout cudaFuncAttributePreferredSharedMemoryCarveout #elif defined(PLATFORM_HIP_DEVICE) -#define gpuFuncAttributeMaxDynamicSharedMemorySize \ - hipFuncAttributeMaxDynamicSharedMemorySize -#define gpuFuncAttributePreferredSharedMemoryCarveout \ - hipFuncAttributePreferredSharedMemoryCarveout +#define gpuFuncAttributeMaxDynamicSharedMemorySize hipFuncAttributeMaxDynamicSharedMemorySize +#define gpuFuncAttributePreferredSharedMemoryCarveout hipFuncAttributePreferredSharedMemoryCarveout #endif // Device attribute enums (different names) #if defined(PLATFORM_CUDA_DEVICE) #define gpuDevAttrMultiProcessorCount cudaDevAttrMultiProcessorCount -#define gpuDevAttrMaxSharedMemoryPerMultiProcessor \ - cudaDevAttrMaxSharedMemoryPerMultiprocessor -#define gpuOccupancyMaxActiveBlocksPerMultiprocessor \ - cudaOccupancyMaxActiveBlocksPerMultiprocessor +#define gpuDevAttrMaxSharedMemoryPerMultiProcessor cudaDevAttrMaxSharedMemoryPerMultiprocessor +#define gpuOccupancyMaxActiveBlocksPerMultiprocessor cudaOccupancyMaxActiveBlocksPerMultiprocessor #elif defined(PLATFORM_HIP_DEVICE) #define gpuDevAttrMultiProcessorCount hipDeviceAttributeMultiprocessorCount -#define gpuDevAttrMaxSharedMemoryPerMultiProcessor \ - hipDeviceAttributeMaxSharedMemPerMultiprocessor -#define gpuOccupancyMaxActiveBlocksPerMultiprocessor \ - hipOccupancyMaxActiveBlocksPerMultiprocessor +#define gpuDevAttrMaxSharedMemoryPerMultiProcessor hipDeviceAttributeMaxSharedMemPerMultiprocessor +#define gpuOccupancyMaxActiveBlocksPerMultiprocessor hipOccupancyMaxActiveBlocksPerMultiprocessor #endif // Event iface @@ -134,28 +126,26 @@ #endif // CUDA 
error checking macro (replaces FLASHINFER_CUDA_CALL) -#define FI_GPU_CALL(call) \ - do { \ - gpuError_t err = (call); \ - if (err != gpuSuccess) { \ - std::ostringstream err_msg; \ - err_msg << "GPU error: " << gpuGetErrorString(err) << " at " \ - << __FILE__ << ":" << __LINE__; \ - throw std::runtime_error(err_msg.str()); \ - } \ - } while (0) +#define FI_GPU_CALL(call) \ + do { \ + gpuError_t err = (call); \ + if (err != gpuSuccess) { \ + std::ostringstream err_msg; \ + err_msg << "GPU error: " << gpuGetErrorString(err) << " at " << __FILE__ << ":" << __LINE__; \ + throw std::runtime_error(err_msg.str()); \ + } \ + } while (0) -inline int getMaxSharedMemPerMultiprocessor(int dev_id) -{ - int max_smem_per_sm = 0; +inline int getMaxSharedMemPerMultiprocessor(int dev_id) { + int max_smem_per_sm = 0; #if defined(PLATFORM_CUDA_DEVICE) - FI_GPU_CALL(gpuDeviceGetAttribute( - &max_smem_per_sm, gpuDevAttrMaxSharedMemoryPerMultiProcessor, dev_id)); + FI_GPU_CALL( + gpuDeviceGetAttribute(&max_smem_per_sm, gpuDevAttrMaxSharedMemoryPerMultiProcessor, dev_id)); #elif defined(PLATFORM_HIP_DEVICE) - hipDeviceProp_t deviceProp; - FI_GPU_CALL(hipGetDeviceProperties(&deviceProp, dev_id)); - max_smem_per_sm = deviceProp.sharedMemPerMultiprocessor; + hipDeviceProp_t deviceProp; + FI_GPU_CALL(hipGetDeviceProperties(&deviceProp, dev_id)); + max_smem_per_sm = deviceProp.sharedMemPerMultiprocessor; #endif - return max_smem_per_sm; + return max_smem_per_sm; } diff --git a/libflashinfer/include/gpu_iface/math_ops.hpp b/libflashinfer/include/gpu_iface/math_ops.hpp index 98b124a1a3..160d3b6b3b 100644 --- a/libflashinfer/include/gpu_iface/math_ops.hpp +++ b/libflashinfer/include/gpu_iface/math_ops.hpp @@ -12,12 +12,9 @@ #include "backend/hip/math_hip.h" #endif -namespace flashinfer -{ -namespace gpu_iface -{ -namespace math -{ +namespace flashinfer { +namespace gpu_iface { +namespace math { #if defined(PLATFORM_CUDA_DEVICE) // Re-export CUDA math functions with same names using 
flashinfer::math::half2_as_uint32; @@ -48,6 +45,6 @@ using flashinfer::math::tanh; // Add other functions as needed #endif -} // namespace math -} // namespace gpu_iface -} // namespace flashinfer +} // namespace math +} // namespace gpu_iface +} // namespace flashinfer diff --git a/libflashinfer/include/gpu_iface/memory_ops.hpp b/libflashinfer/include/gpu_iface/memory_ops.hpp index 94048502cd..aea487b99f 100644 --- a/libflashinfer/include/gpu_iface/memory_ops.hpp +++ b/libflashinfer/include/gpu_iface/memory_ops.hpp @@ -5,29 +5,24 @@ #pragma once #include "platform.hpp" -namespace flashinfer -{ -namespace gpu_iface -{ -namespace memory -{ +namespace flashinfer { +namespace gpu_iface { +namespace memory { /** * @brief Control options for shared memory fill behavior */ -enum class SharedMemFillMode -{ - kFillZero, // Fill zero to shared memory when predicate is false - kNoFill // Do not fill zero to shared memory when predicate is false +enum class SharedMemFillMode { + kFillZero, // Fill zero to shared memory when predicate is false + kNoFill // Do not fill zero to shared memory when predicate is false }; /** * @brief Control options for memory prefetch behavior */ -enum class PrefetchMode -{ - kNoPrefetch, // Do not fetch additional data from global memory to L2 - kPrefetch // Fetch additional data from global memory to L2 +enum class PrefetchMode { + kNoPrefetch, // Do not fetch additional data from global memory to L2 + kPrefetch // Fetch additional data from global memory to L2 }; // Include platform-specific implementations @@ -49,9 +44,9 @@ __device__ __forceinline__ void commit_group() { mem_detail::commit_group(); } * * @tparam N Number of most recent groups to wait for (0-7) */ -template __device__ __forceinline__ void wait_group() -{ - mem_detail::wait_group(); +template +__device__ __forceinline__ void wait_group() { + mem_detail::wait_group(); } /** @@ -63,16 +58,14 @@ template __device__ __forceinline__ void wait_group() * @param gmem_ptr Source global 
memory pointer */ template -__device__ __forceinline__ void load_128b(T *smem_ptr, const T *gmem_ptr) -{ - mem_detail::load_128b(smem_ptr, gmem_ptr); +__device__ __forceinline__ void load_128b(T* smem_ptr, const T* gmem_ptr) { + mem_detail::load_128b(smem_ptr, gmem_ptr); } template -__device__ __forceinline__ void load_64b(T *smem_ptr, const T *gmem_ptr) -{ +__device__ __forceinline__ void load_64b(T* smem_ptr, const T* gmem_ptr) { #if defined(PLATFORM_HIP_DEVICE) - mem_detail::load_64b(smem_ptr, gmem_ptr); + mem_detail::load_64b(smem_ptr, gmem_ptr); #else #error "load_64b not implemented for this platform" #endif @@ -89,20 +82,14 @@ __device__ __forceinline__ void load_64b(T *smem_ptr, const T *gmem_ptr) * @param predicate Condition for executing the load */ template -__device__ __forceinline__ void -pred_load_128b(T *smem_ptr, const T *gmem_ptr, bool predicate) -{ - mem_detail::pred_load_128b(smem_ptr, gmem_ptr, - predicate); +__device__ __forceinline__ void pred_load_128b(T* smem_ptr, const T* gmem_ptr, bool predicate) { + mem_detail::pred_load_128b(smem_ptr, gmem_ptr, predicate); } template -__device__ __forceinline__ void -pred_load_64b(T *smem_ptr, const T *gmem_ptr, bool predicate) -{ +__device__ __forceinline__ void pred_load_64b(T* smem_ptr, const T* gmem_ptr, bool predicate) { #if defined(PLATFORM_HIP_DEVICE) - mem_detail::pred_load_64b(smem_ptr, gmem_ptr, - predicate); + mem_detail::pred_load_64b(smem_ptr, gmem_ptr, predicate); #else #error "pred_load_64b not implemented for this platform" #endif @@ -118,9 +105,8 @@ pred_load_64b(T *smem_ptr, const T *gmem_ptr, bool predicate) * @param gmem_ptr Source global memory pointer */ template -__device__ __forceinline__ void load(T *smem_ptr, const T *gmem_ptr) -{ - mem_detail::load(smem_ptr, gmem_ptr); +__device__ __forceinline__ void load(T* smem_ptr, const T* gmem_ptr) { + mem_detail::load(smem_ptr, gmem_ptr); } /** @@ -134,17 +120,11 @@ __device__ __forceinline__ void load(T *smem_ptr, const T *gmem_ptr) * 
@param gmem_ptr Source global memory pointer * @param predicate Condition for executing the load */ -template -__device__ __forceinline__ void -pred_load(T *smem_ptr, const T *gmem_ptr, bool predicate) -{ - mem_detail::pred_load(smem_ptr, gmem_ptr, - predicate); +template +__device__ __forceinline__ void pred_load(T* smem_ptr, const T* gmem_ptr, bool predicate) { + mem_detail::pred_load(smem_ptr, gmem_ptr, predicate); } -} // namespace memory -} // namespace gpu_iface -} // namespace flashinfer +} // namespace memory +} // namespace gpu_iface +} // namespace flashinfer diff --git a/libflashinfer/include/gpu_iface/mma_ops.hpp b/libflashinfer/include/gpu_iface/mma_ops.hpp index 97ae7bf506..fed85d69d2 100644 --- a/libflashinfer/include/gpu_iface/mma_ops.hpp +++ b/libflashinfer/include/gpu_iface/mma_ops.hpp @@ -15,12 +15,9 @@ namespace mma_detail = flashinfer::gpu_iface::mma_impl::cuda; namespace mma_detail = flashinfer::gpu_iface::mma_impl::hip; #endif -namespace flashinfer -{ -namespace gpu_iface -{ -namespace mma -{ +namespace flashinfer { +namespace gpu_iface { +namespace mma { /*! 
* \brief Loads data from shared memory to fragment @@ -32,26 +29,23 @@ namespace mma // inside mma there is impl of load template -__device__ __forceinline__ void load_fragment(uint32_t *R, const T *smem_ptr) -{ - mma_detail::load_fragment(R, smem_ptr); +__device__ __forceinline__ void load_fragment(uint32_t* R, const T* smem_ptr) { + mma_detail::load_fragment(R, smem_ptr); } template -__device__ __forceinline__ void -load_fragment_transpose(uint32_t *R, const T *smem_ptr, uint32_t stride) -{ - mma_detail::load_fragment_transpose(R, smem_ptr, stride); +__device__ __forceinline__ void load_fragment_transpose(uint32_t* R, const T* smem_ptr, + uint32_t stride) { + mma_detail::load_fragment_transpose(R, smem_ptr, stride); } #if defined(PLATFORM_HIP_DEVICE) template -__device__ __forceinline__ void -load_fragment_transpose_4x4_half_registers(uint32_t *R, const T *smem_ptr) -{ - static_assert(std::is_same::value, - "Only __half is supported for the 4x4 register transpose"); - mma_detail::load_fragment_4x4_half_registers(R, smem_ptr); +__device__ __forceinline__ void load_fragment_transpose_4x4_half_registers(uint32_t* R, + const T* smem_ptr) { + static_assert(std::is_same::value, + "Only __half is supported for the 4x4 register transpose"); + mma_detail::load_fragment_4x4_half_registers(R, smem_ptr); } #endif @@ -66,18 +60,16 @@ load_fragment_transpose_4x4_half_registers(uint32_t *R, const T *smem_ptr) * \param B pointer to the fragment of matrix B */ template -__device__ __forceinline__ void -mma_sync_m16n16k16_row_col_f16f16f32(float *C, uint32_t *A, uint32_t *B) -{ - mma_detail::mma_sync_m16n16k16_row_col_f16f16f32(C, A, B); +__device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f32(float* C, uint32_t* A, + uint32_t* B) { + mma_detail::mma_sync_m16n16k16_row_col_f16f16f32(C, A, B); } template -__device__ __forceinline__ void m16k16_rowsum_f16f16f32(float *d, DType *s) -{ - mma_detail::m16k16_rowsum_f16f16f32(d, s); +__device__ __forceinline__ void 
m16k16_rowsum_f16f16f32(float* d, DType* s) { + mma_detail::m16k16_rowsum_f16f16f32(d, s); } -} // namespace mma -} // namespace gpu_iface -} // namespace flashinfer +} // namespace mma +} // namespace gpu_iface +} // namespace flashinfer diff --git a/libflashinfer/include/gpu_iface/mma_types.hpp b/libflashinfer/include/gpu_iface/mma_types.hpp index 3c00bbc7ab..77914cce63 100644 --- a/libflashinfer/include/gpu_iface/mma_types.hpp +++ b/libflashinfer/include/gpu_iface/mma_types.hpp @@ -4,19 +4,15 @@ #pragma once -namespace flashinfer -{ -namespace gpu_iface -{ -namespace mma -{ +namespace flashinfer { +namespace gpu_iface { +namespace mma { -enum class MMAMode -{ - kInit = 0U, - kInplaceUpdate = 1U, +enum class MMAMode { + kInit = 0U, + kInplaceUpdate = 1U, }; -} // namespace mma -} // namespace gpu_iface -} // namespace flashinfer +} // namespace mma +} // namespace gpu_iface +} // namespace flashinfer diff --git a/libflashinfer/include/gpu_iface/platform.hpp b/libflashinfer/include/gpu_iface/platform.hpp index ece3691902..4c143307fe 100644 --- a/libflashinfer/include/gpu_iface/platform.hpp +++ b/libflashinfer/include/gpu_iface/platform.hpp @@ -3,14 +3,11 @@ // SPDX - License - Identifier : Apache - 2.0 #pragma once -#include "macros.hpp" - #include "gpu_runtime_compat.hpp" +#include "macros.hpp" -namespace flashinfer -{ -namespace gpu_iface -{ +namespace flashinfer { +namespace gpu_iface { // Platform-agnostic stream type #if defined(PLATFORM_CUDA_DEVICE) @@ -21,5 +18,5 @@ constexpr int kWarpSize = 64; #endif -} // namespace gpu_iface -} // namespace flashinfer +} // namespace gpu_iface +} // namespace flashinfer diff --git a/libflashinfer/include/gpu_iface/vec_dtypes.hpp b/libflashinfer/include/gpu_iface/vec_dtypes.hpp index b769286a94..95879b0a01 100644 --- a/libflashinfer/include/gpu_iface/vec_dtypes.hpp +++ b/libflashinfer/include/gpu_iface/vec_dtypes.hpp @@ -3,16 +3,14 @@ // SPDX-License-Identifier: Apache-2.0 #pragma once -#include "platform.hpp" #include 
#include -namespace flashinfer -{ -namespace gpu_iface -{ -namespace vec_dtypes -{ +#include "platform.hpp" + +namespace flashinfer { +namespace gpu_iface { +namespace vec_dtypes { // Include the appropriate backend implementation #if defined(PLATFORM_CUDA_DEVICE) @@ -29,6 +27,6 @@ namespace vec_t_detail = flashinfer::gpu_iface::vec_dtypes::detail::hip; using vec_t_detail::vec_cast; using vec_t_detail::vec_t; -} // namespace vec_dtypes -} // namespace gpu_iface -} // namespace flashinfer +} // namespace vec_dtypes +} // namespace gpu_iface +} // namespace flashinfer diff --git a/libflashinfer/tests/hip/test_apply_llama_rope.cpp b/libflashinfer/tests/hip/test_apply_llama_rope.cpp index 390f4d1648..3620388510 100644 --- a/libflashinfer/tests/hip/test_apply_llama_rope.cpp +++ b/libflashinfer/tests/hip/test_apply_llama_rope.cpp @@ -2,306 +2,266 @@ // // SPDX - License - Identifier : Apache 2.0 +#include + +#include +#include + #include "../../utils/cpu_reference_hip.h" #include "../../utils/utils_hip.h" #include "flashinfer/attention/generic/prefill.cuh" #include "gpu_iface/fastdiv.cuh" #include "gpu_iface/gpu_runtime_compat.hpp" -#include -#include -#include - -namespace -{ +namespace { using QParamType = std::tuple; -template struct TestKernelTraits -{ - static constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM / 16; - static constexpr uint32_t NUM_MMA_D_VO = HEAD_DIM / 16; +template +struct TestKernelTraits { + static constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM / 16; + static constexpr uint32_t NUM_MMA_D_VO = HEAD_DIM / 16; }; template -__global__ void test_init_rope_freq_kernel(float *output_freq, - float rope_rcp_scale, - float rope_rcp_theta) -{ - using KTraits = TestKernelTraits; - - // Allocate local frequency array - float rope_freq[KTraits::NUM_MMA_D_VO / 2][4]; // [2][4] for HEAD_DIM=64 - - // Call the init_rope_freq function from prefill.cuh - flashinfer::init_rope_freq(rope_freq, rope_rcp_scale, - rope_rcp_theta, threadIdx.x); - - // Write frequencies to their 
correct feature indices - const uint32_t lane_idx = threadIdx.x; - if (lane_idx < 64) { // Only write for valid threads - for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO / 2; ++mma_d) { - for (uint32_t j = 0; j < 4; ++j) { - // Calculate the actual feature index this frequency corresponds - // to - uint32_t feature_idx = - flashinfer::get_feature_index(mma_d, lane_idx, j); - - // Write frequency to the correct feature index in global array - if (feature_idx < HEAD_DIM) { - output_freq[feature_idx] = rope_freq[mma_d][j]; - if (feature_idx + HEAD_DIM / 2 < HEAD_DIM) { - output_freq[feature_idx + HEAD_DIM / 2] = - rope_freq[mma_d][j]; - } - } - } +__global__ void test_init_rope_freq_kernel(float* output_freq, float rope_rcp_scale, + float rope_rcp_theta) { + using KTraits = TestKernelTraits; + + // Allocate local frequency array + float rope_freq[KTraits::NUM_MMA_D_VO / 2][4]; // [2][4] for HEAD_DIM=64 + + // Call the init_rope_freq function from prefill.cuh + flashinfer::init_rope_freq(rope_freq, rope_rcp_scale, rope_rcp_theta, threadIdx.x); + + // Write frequencies to their correct feature indices + const uint32_t lane_idx = threadIdx.x; + if (lane_idx < 64) { // Only write for valid threads + for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO / 2; ++mma_d) { + for (uint32_t j = 0; j < 4; ++j) { + // Calculate the actual feature index this frequency corresponds + // to + uint32_t feature_idx = flashinfer::get_feature_index(mma_d, lane_idx, j); + + // Write frequency to the correct feature index in global array + if (feature_idx < HEAD_DIM) { + output_freq[feature_idx] = rope_freq[mma_d][j]; + if (feature_idx + HEAD_DIM / 2 < HEAD_DIM) { + output_freq[feature_idx + HEAD_DIM / 2] = rope_freq[mma_d][j]; + } } + } } + } } template -__global__ void -test_q_frag_apply_llama_rope_kernel(__half *q_input, - __half *q_output, - uint32_t qo_len, - uint32_t num_qo_heads, - uint32_t kv_len, - float rope_rcp_scale, - float rope_rcp_theta, - flashinfer::uint_fastdiv 
group_size_fastdiv) -{ - using KTraits = TestKernelTraits; - constexpr uint32_t HALF_ELEMS_PER_THREAD = 4; - constexpr uint32_t INT32_ELEMS_PER_THREAD = 2; - constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM / 16; - - float rope_freq[KTraits::NUM_MMA_D_VO / 2][4]; - flashinfer::init_rope_freq(rope_freq, rope_rcp_scale, - rope_rcp_theta, threadIdx.x); - - const uint32_t lane_idx = threadIdx.x; - const uint32_t warp_idx = blockIdx.x; - - // TODO: Need to check that qo_len is evenly divisible by 16. - for (uint32_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) { - for (uint32_t seq_chunk = 0; seq_chunk < qo_len; seq_chunk += 16) { - - uint32_t seq_idx = seq_chunk + (lane_idx % 16); - if (seq_idx >= qo_len) - continue; - - uint32_t abs_position = seq_idx + kv_len - qo_len; - // Each iteration processes 16*2=32 features (first_half + - // second_half) - for (uint32_t feat_chunk = 0; feat_chunk < NUM_MMA_D_QK / 2; - ++feat_chunk) - { - uint32_t feat_offset_first = feat_chunk * 32; - uint32_t feat_offset_second = feat_offset_first + HEAD_DIM / 2; - - // Load fragments from global memory - __half q_frag_first[HALF_ELEMS_PER_THREAD]; - __half q_frag_second[HALF_ELEMS_PER_THREAD]; - - // Calculate base address for this sequence and head - uint32_t base_offset = qo_head_idx * HEAD_DIM + - seq_idx * (num_qo_heads * HEAD_DIM); - - // Load first half (4 consecutive features per thread) - for (uint32_t i = 0; i < HALF_ELEMS_PER_THREAD; ++i) { - uint32_t feat_idx1 = - flashinfer::get_feature_index(feat_chunk, - lane_idx, i); - uint32_t feat_idx2 = feat_idx1 + HEAD_DIM / 2; - q_frag_first[i] = *(q_input + base_offset + feat_idx1); - q_frag_second[i] = *(q_input + base_offset + feat_idx2); - } - - // Apply RoPE using the validated function - uint32_t mma_di = feat_chunk; - flashinfer::q_frag_apply_llama_rope<__half, - HALF_ELEMS_PER_THREAD>( - q_frag_first, q_frag_second, - rope_freq[mma_di % (KTraits::NUM_MMA_D_VO / 2)], - abs_position, group_size_fastdiv); - - // Store 
results back to global memory - for (uint32_t i = 0; i < HALF_ELEMS_PER_THREAD; ++i) { - uint32_t feat_idx1 = - flashinfer::get_feature_index(feat_chunk, - lane_idx, i); - uint32_t feat_idx2 = feat_idx1 + HEAD_DIM / 2; - *(q_output + base_offset + feat_idx1) = q_frag_first[i]; - *(q_output + base_offset + feat_idx2) = q_frag_second[i]; - } - } +__global__ void test_q_frag_apply_llama_rope_kernel(__half* q_input, __half* q_output, + uint32_t qo_len, uint32_t num_qo_heads, + uint32_t kv_len, float rope_rcp_scale, + float rope_rcp_theta, + flashinfer::uint_fastdiv group_size_fastdiv) { + using KTraits = TestKernelTraits; + constexpr uint32_t HALF_ELEMS_PER_THREAD = 4; + constexpr uint32_t INT32_ELEMS_PER_THREAD = 2; + constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM / 16; + + float rope_freq[KTraits::NUM_MMA_D_VO / 2][4]; + flashinfer::init_rope_freq(rope_freq, rope_rcp_scale, rope_rcp_theta, threadIdx.x); + + const uint32_t lane_idx = threadIdx.x; + const uint32_t warp_idx = blockIdx.x; + + // TODO: Need to check that qo_len is evenly divisible by 16. 
+ for (uint32_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) { + for (uint32_t seq_chunk = 0; seq_chunk < qo_len; seq_chunk += 16) { + uint32_t seq_idx = seq_chunk + (lane_idx % 16); + if (seq_idx >= qo_len) continue; + + uint32_t abs_position = seq_idx + kv_len - qo_len; + // Each iteration processes 16*2=32 features (first_half + + // second_half) + for (uint32_t feat_chunk = 0; feat_chunk < NUM_MMA_D_QK / 2; ++feat_chunk) { + uint32_t feat_offset_first = feat_chunk * 32; + uint32_t feat_offset_second = feat_offset_first + HEAD_DIM / 2; + + // Load fragments from global memory + __half q_frag_first[HALF_ELEMS_PER_THREAD]; + __half q_frag_second[HALF_ELEMS_PER_THREAD]; + + // Calculate base address for this sequence and head + uint32_t base_offset = qo_head_idx * HEAD_DIM + seq_idx * (num_qo_heads * HEAD_DIM); + + // Load first half (4 consecutive features per thread) + for (uint32_t i = 0; i < HALF_ELEMS_PER_THREAD; ++i) { + uint32_t feat_idx1 = flashinfer::get_feature_index(feat_chunk, lane_idx, i); + uint32_t feat_idx2 = feat_idx1 + HEAD_DIM / 2; + q_frag_first[i] = *(q_input + base_offset + feat_idx1); + q_frag_second[i] = *(q_input + base_offset + feat_idx2); + } + + // Apply RoPE using the validated function + uint32_t mma_di = feat_chunk; + flashinfer::q_frag_apply_llama_rope<__half, HALF_ELEMS_PER_THREAD>( + q_frag_first, q_frag_second, rope_freq[mma_di % (KTraits::NUM_MMA_D_VO / 2)], + abs_position, group_size_fastdiv); + + // Store results back to global memory + for (uint32_t i = 0; i < HALF_ELEMS_PER_THREAD; ++i) { + uint32_t feat_idx1 = flashinfer::get_feature_index(feat_chunk, lane_idx, i); + uint32_t feat_idx2 = feat_idx1 + HEAD_DIM / 2; + *(q_output + base_offset + feat_idx1) = q_frag_first[i]; + *(q_output + base_offset + feat_idx2) = q_frag_second[i]; } + } } + } } template -class LLamaRopeTestFixture : public ::testing::TestWithParam -{ -protected: - uint32_t qo_len, num_qo_heads, head_dim; - std::vector q; - - 
LLamaRopeTestFixture() - { - const auto ¶ms = GetParam(); - qo_len = std::get<0>(params); - num_qo_heads = std::get<1>(params); - head_dim = std::get<2>(params); - q.resize(qo_len * num_qo_heads * head_dim); +class LLamaRopeTestFixture : public ::testing::TestWithParam { + protected: + uint32_t qo_len, num_qo_heads, head_dim; + std::vector q; + + LLamaRopeTestFixture() { + const auto& params = GetParam(); + qo_len = std::get<0>(params); + num_qo_heads = std::get<1>(params); + head_dim = std::get<2>(params); + q.resize(qo_len * num_qo_heads * head_dim); + } + + void SetUp() override { utils::vec_normal_(q); } + + void TearDown() override {} + + std::vector apply_cpu_rope(size_t offset, float rope_scale = 1.0f, + float rope_theta = 10000.0f) { + return cpu_reference::apply_llama_rope(q.data(), head_dim, offset, rope_scale, rope_theta); + } + + std::vector get_cpu_rope_frequencies(float rope_scale = 1.0f, + float rope_theta = 10000.0f) { + std::vector frequencies(head_dim); + + for (size_t k = 0; k < head_dim; ++k) { + // Extract ONLY the frequency calculation (without position/offset) + float freq_base = float(2 * (k % (head_dim / 2))) / float(head_dim); + float frequency = (1.0f / rope_scale) / std::pow(rope_theta, freq_base); + frequencies[k] = frequency; } - void SetUp() override { utils::vec_normal_(q); } + return frequencies; + } - void TearDown() override {} + std::vector get_gpu_rope_frequencies(float rope_scale = 1.0f, + float rope_theta = 10000.0f) { + // Convert to reciprocal values as expected by GPU kernel + float rope_rcp_scale = 1.0f / rope_scale; + float rope_rcp_theta = 1.0f / rope_theta; - std::vector apply_cpu_rope(size_t offset, - float rope_scale = 1.0f, - float rope_theta = 10000.0f) - { - return cpu_reference::apply_llama_rope(q.data(), head_dim, offset, - rope_scale, rope_theta); - } + // Allocate GPU memory for output (one frequency per feature) + float* d_output_freq; + size_t output_size = head_dim * sizeof(float); + 
FI_GPU_CALL(hipMalloc(&d_output_freq, output_size)); + FI_GPU_CALL(hipMemset(d_output_freq, 0, output_size)); - std::vector get_cpu_rope_frequencies(float rope_scale = 1.0f, - float rope_theta = 10000.0f) - { - std::vector frequencies(head_dim); - - for (size_t k = 0; k < head_dim; ++k) { - // Extract ONLY the frequency calculation (without position/offset) - float freq_base = float(2 * (k % (head_dim / 2))) / float(head_dim); - float frequency = - (1.0f / rope_scale) / std::pow(rope_theta, freq_base); - frequencies[k] = frequency; - } + // Launch kernel with 64 threads + dim3 grid(1); + dim3 block(64); - return frequencies; + if (head_dim == 64) { + test_init_rope_freq_kernel<64> + <<>>(d_output_freq, rope_rcp_scale, rope_rcp_theta); } - std::vector get_gpu_rope_frequencies(float rope_scale = 1.0f, - float rope_theta = 10000.0f) - { - // Convert to reciprocal values as expected by GPU kernel - float rope_rcp_scale = 1.0f / rope_scale; - float rope_rcp_theta = 1.0f / rope_theta; - - // Allocate GPU memory for output (one frequency per feature) - float *d_output_freq; - size_t output_size = head_dim * sizeof(float); - FI_GPU_CALL(hipMalloc(&d_output_freq, output_size)); - FI_GPU_CALL(hipMemset(d_output_freq, 0, output_size)); - - // Launch kernel with 64 threads - dim3 grid(1); - dim3 block(64); - - if (head_dim == 64) { - test_init_rope_freq_kernel<64><<>>( - d_output_freq, rope_rcp_scale, rope_rcp_theta); - } + FI_GPU_CALL(hipDeviceSynchronize()); - FI_GPU_CALL(hipDeviceSynchronize()); + // Copy all frequencies back + std::vector gpu_frequencies(head_dim); + FI_GPU_CALL( + hipMemcpy(gpu_frequencies.data(), d_output_freq, output_size, hipMemcpyDeviceToHost)); - // Copy all frequencies back - std::vector gpu_frequencies(head_dim); - FI_GPU_CALL(hipMemcpy(gpu_frequencies.data(), d_output_freq, - output_size, hipMemcpyDeviceToHost)); + FI_GPU_CALL(hipFree(d_output_freq)); + return gpu_frequencies; + } - FI_GPU_CALL(hipFree(d_output_freq)); - return gpu_frequencies; - 
} + std::vector> apply_cpu_rope_all_sequences(size_t kv_len = 1000, + float rope_scale = 1.0f, + float rope_theta = 10000.0f) { + std::vector> results; - std::vector> - apply_cpu_rope_all_sequences(size_t kv_len = 1000, - float rope_scale = 1.0f, - float rope_theta = 10000.0f) - { - std::vector> results; - - DISPATCH_head_dim(head_dim, HEAD_DIM, { - using namespace flashinfer; - tensor_info_t info(qo_len, kv_len, num_qo_heads, num_qo_heads, - QKVLayout::kHND, HEAD_DIM); - - // Apply RoPE to all sequences and heads - for (size_t qo_head_idx = 0; qo_head_idx < num_qo_heads; - ++qo_head_idx) - { - for (size_t q_idx = 0; q_idx < qo_len; ++q_idx) { - size_t offset = q_idx + kv_len - qo_len; - - // Apply RoPE to this specific Q sequence/head - auto q_rotary_local = cpu_reference::apply_llama_rope_debug( - q.data() + - info.get_q_elem_offset(q_idx, qo_head_idx, 0), - head_dim, offset, rope_scale, rope_theta); - - results.push_back(std::move(q_rotary_local)); - } - } - }); - - return results; - } + DISPATCH_head_dim(head_dim, HEAD_DIM, { + using namespace flashinfer; + tensor_info_t info(qo_len, kv_len, num_qo_heads, num_qo_heads, QKVLayout::kHND, HEAD_DIM); - std::vector test_gpu_q_frag_apply_rope(size_t kv_len = 1000, - float rope_scale = 1.0f, - float rope_theta = 10000.0f) - { - // Convert to reciprocal values - float rope_rcp_scale = 1.0f / rope_scale; - float rope_rcp_theta = 1.0f / rope_theta; - uint32_t group_size = 1; // Simple case for now - - // Allocate GPU memory for input and output - __half *d_q_input, *d_q_output; - size_t q_size = q.size() * sizeof(__half); - - FI_GPU_CALL(hipMalloc(&d_q_input, q_size)); - FI_GPU_CALL(hipMalloc(&d_q_output, q_size)); - - // Copy input Q to GPU - FI_GPU_CALL( - hipMemcpy(d_q_input, q.data(), q_size, hipMemcpyHostToDevice)); - FI_GPU_CALL(hipMemset(d_q_output, 0, q_size)); - - // Launch kernel - one block with 64 threads - dim3 grid(1); // Single block for simplicity - dim3 block(64); // CDNA3 wavefront size - - if (head_dim 
== 64) { - test_q_frag_apply_llama_rope_kernel<64><<>>( - d_q_input, d_q_output, qo_len, num_qo_heads, kv_len, - rope_rcp_scale, rope_rcp_theta, group_size); - } - - FI_GPU_CALL(hipDeviceSynchronize()); + // Apply RoPE to all sequences and heads + for (size_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) { + for (size_t q_idx = 0; q_idx < qo_len; ++q_idx) { + size_t offset = q_idx + kv_len - qo_len; - // Copy results back to CPU - std::vector<__half> gpu_output(q.size()); - FI_GPU_CALL(hipMemcpy(gpu_output.data(), d_q_output, q_size, - hipMemcpyDeviceToHost)); + // Apply RoPE to this specific Q sequence/head + auto q_rotary_local = cpu_reference::apply_llama_rope_debug( + q.data() + info.get_q_elem_offset(q_idx, qo_head_idx, 0), head_dim, offset, + rope_scale, rope_theta); - // Convert to float for comparison - std::vector result(head_dim); - for (size_t i = 0; i < head_dim; ++i) { - result[i] = float(gpu_output[i]); // First sequence, first head + results.push_back(std::move(q_rotary_local)); } + } + }); + + return results; + } + + std::vector test_gpu_q_frag_apply_rope(size_t kv_len = 1000, float rope_scale = 1.0f, + float rope_theta = 10000.0f) { + // Convert to reciprocal values + float rope_rcp_scale = 1.0f / rope_scale; + float rope_rcp_theta = 1.0f / rope_theta; + uint32_t group_size = 1; // Simple case for now + + // Allocate GPU memory for input and output + __half *d_q_input, *d_q_output; + size_t q_size = q.size() * sizeof(__half); + + FI_GPU_CALL(hipMalloc(&d_q_input, q_size)); + FI_GPU_CALL(hipMalloc(&d_q_output, q_size)); + + // Copy input Q to GPU + FI_GPU_CALL(hipMemcpy(d_q_input, q.data(), q_size, hipMemcpyHostToDevice)); + FI_GPU_CALL(hipMemset(d_q_output, 0, q_size)); + + // Launch kernel - one block with 64 threads + dim3 grid(1); // Single block for simplicity + dim3 block(64); // CDNA3 wavefront size + + if (head_dim == 64) { + test_q_frag_apply_llama_rope_kernel<64><<>>(d_q_input, d_q_output, qo_len, + num_qo_heads, kv_len, 
rope_rcp_scale, + rope_rcp_theta, group_size); + } + + FI_GPU_CALL(hipDeviceSynchronize()); - FI_GPU_CALL(hipFree(d_q_input)); - FI_GPU_CALL(hipFree(d_q_output)); + // Copy results back to CPU + std::vector<__half> gpu_output(q.size()); + FI_GPU_CALL(hipMemcpy(gpu_output.data(), d_q_output, q_size, hipMemcpyDeviceToHost)); - return result; + // Convert to float for comparison + std::vector result(head_dim); + for (size_t i = 0; i < head_dim; ++i) { + result[i] = float(gpu_output[i]); // First sequence, first head } + + FI_GPU_CALL(hipFree(d_q_input)); + FI_GPU_CALL(hipFree(d_q_output)); + + return result; + } }; using LLamaRopeTestWithFP16 = LLamaRopeTestFixture<__half>; -} // namespace +} // namespace // Wrapper to validate freq application // call q_smem_inplace_apply_rotary and copy back results to CPU. @@ -312,96 +272,84 @@ using LLamaRopeTestWithFP16 = LLamaRopeTestFixture<__half>; // Test 2. Copy CPU Q matrix to GPU call freq apply validator // launch kernel -TEST_P(LLamaRopeTestWithFP16, TestInitRopeFreq) -{ - constexpr float RELATIVE_EPSILON = 1e-6f; - size_t num_mismatches = 0; - auto cpu_frequencies = this->get_cpu_rope_frequencies(); - auto gpu_frequencies = this->get_gpu_rope_frequencies(); - - // Print side-by-side comparison for easier visual inspection - std::cout << "\nSide-by-side comparison:\n"; - std::cout << "Index\tCPU\t\tGPU\t\tDifference\n"; - std::cout << "-----\t---\t\t---\t\t----------\n"; - - for (size_t i = 0; i < std::min(16u, this->head_dim); ++i) { - float diff = std::abs(cpu_frequencies[i] - gpu_frequencies[i]); - std::cout << i << "\t" << cpu_frequencies[i] << "\t\t" - << gpu_frequencies[i] << "\t\t" << diff << std::endl; - } - - ASSERT_EQ(cpu_frequencies.size(), this->head_dim); - ASSERT_EQ(gpu_frequencies.size(), this->head_dim); - - for (auto i = 0ul; i < cpu_frequencies.size(); ++i) { - auto diff = std::abs(cpu_frequencies[i] - gpu_frequencies[i]); - if (diff >= RELATIVE_EPSILON) { - std::cout << "Diff : " << diff << " at 
feature index " << i << " " - << "cpu_frequencies[i]: " << cpu_frequencies[i] << " " - << "gpu_frequencies[i]: " << gpu_frequencies[i] << '\n'; - ++num_mismatches; - } +TEST_P(LLamaRopeTestWithFP16, TestInitRopeFreq) { + constexpr float RELATIVE_EPSILON = 1e-6f; + size_t num_mismatches = 0; + auto cpu_frequencies = this->get_cpu_rope_frequencies(); + auto gpu_frequencies = this->get_gpu_rope_frequencies(); + + // Print side-by-side comparison for easier visual inspection + std::cout << "\nSide-by-side comparison:\n"; + std::cout << "Index\tCPU\t\tGPU\t\tDifference\n"; + std::cout << "-----\t---\t\t---\t\t----------\n"; + + for (size_t i = 0; i < std::min(16u, this->head_dim); ++i) { + float diff = std::abs(cpu_frequencies[i] - gpu_frequencies[i]); + std::cout << i << "\t" << cpu_frequencies[i] << "\t\t" << gpu_frequencies[i] << "\t\t" << diff + << std::endl; + } + + ASSERT_EQ(cpu_frequencies.size(), this->head_dim); + ASSERT_EQ(gpu_frequencies.size(), this->head_dim); + + for (auto i = 0ul; i < cpu_frequencies.size(); ++i) { + auto diff = std::abs(cpu_frequencies[i] - gpu_frequencies[i]); + if (diff >= RELATIVE_EPSILON) { + std::cout << "Diff : " << diff << " at feature index " << i << " " + << "cpu_frequencies[i]: " << cpu_frequencies[i] << " " + << "gpu_frequencies[i]: " << gpu_frequencies[i] << '\n'; + ++num_mismatches; } + } - ASSERT_EQ(num_mismatches, 0); + ASSERT_EQ(num_mismatches, 0); } -TEST_P(LLamaRopeTestWithFP16, VectorSizeIsCorrect) -{ - const auto ¶ms = GetParam(); - size_t expected_size = - std::get<0>(params) * std::get<1>(params) * std::get<2>(params); - ASSERT_EQ(this->q.size(), expected_size); +TEST_P(LLamaRopeTestWithFP16, VectorSizeIsCorrect) { + const auto& params = GetParam(); + size_t expected_size = std::get<0>(params) * std::get<1>(params) * std::get<2>(params); + ASSERT_EQ(this->q.size(), expected_size); } -TEST_P(LLamaRopeTestWithFP16, TestQFragApplyRopeComparison) -{ - constexpr float RELATIVE_EPSILON = 1e-2f; - - auto cpu_result = 
this->apply_cpu_rope(744); - auto gpu_result = this->test_gpu_q_frag_apply_rope(); - - std::cout << "\n=== CPU vs GPU RoPE Application Comparison ===\n"; - std::cout << "CPU result (offset=1000, first 8 features): "; - for (size_t i = 0; i < std::min(8u, this->head_dim); ++i) { - std::cout << cpu_result[i] << " "; - } - std::cout << std::endl; - - std::cout << "GPU result (offset=1000, first 8 features): "; - for (size_t i = 0; i < std::min(8u, this->head_dim); ++i) { - std::cout << gpu_result[i] << " "; - } - std::cout << std::endl; - - // Compare element by element - size_t num_mismatches = 0; - for (size_t i = 0; i < std::min(cpu_result.size(), gpu_result.size()); ++i) - { - float diff = std::abs(cpu_result[i] - gpu_result[i]); - float rel_diff = (std::abs(cpu_result[i]) > 1e-6f) - ? diff / std::abs(cpu_result[i]) - : diff; - - if (rel_diff > RELATIVE_EPSILON) { - std::cout << "Mismatch at feature " << i - << ": CPU=" << cpu_result[i] << " GPU=" << gpu_result[i] - << " diff=" << diff << " rel_diff=" << rel_diff - << std::endl; - ++num_mismatches; - } +TEST_P(LLamaRopeTestWithFP16, TestQFragApplyRopeComparison) { + constexpr float RELATIVE_EPSILON = 1e-2f; + + auto cpu_result = this->apply_cpu_rope(744); + auto gpu_result = this->test_gpu_q_frag_apply_rope(); + + std::cout << "\n=== CPU vs GPU RoPE Application Comparison ===\n"; + std::cout << "CPU result (offset=1000, first 8 features): "; + for (size_t i = 0; i < std::min(8u, this->head_dim); ++i) { + std::cout << cpu_result[i] << " "; + } + std::cout << std::endl; + + std::cout << "GPU result (offset=1000, first 8 features): "; + for (size_t i = 0; i < std::min(8u, this->head_dim); ++i) { + std::cout << gpu_result[i] << " "; + } + std::cout << std::endl; + + // Compare element by element + size_t num_mismatches = 0; + for (size_t i = 0; i < std::min(cpu_result.size(), gpu_result.size()); ++i) { + float diff = std::abs(cpu_result[i] - gpu_result[i]); + float rel_diff = (std::abs(cpu_result[i]) > 1e-6f) ? 
diff / std::abs(cpu_result[i]) : diff; + + if (rel_diff > RELATIVE_EPSILON) { + std::cout << "Mismatch at feature " << i << ": CPU=" << cpu_result[i] + << " GPU=" << gpu_result[i] << " diff=" << diff << " rel_diff=" << rel_diff + << std::endl; + ++num_mismatches; } + } - std::cout << "Total mismatches: " << num_mismatches << " out of " - << head_dim << std::endl; + std::cout << "Total mismatches: " << num_mismatches << " out of " << head_dim << std::endl; - EXPECT_EQ(num_mismatches, 0) - << "Found mismatches between CPU and GPU RoPE application"; + EXPECT_EQ(num_mismatches, 0) << "Found mismatches between CPU and GPU RoPE application"; } INSTANTIATE_TEST_SUITE_P( - LLamaRopeTestWithFP16, - LLamaRopeTestWithFP16, - ::testing::Values( - std::make_tuple(256, 1, 64) // qo_len=256, num_qo_heads=1, head_dim=64 - )); + LLamaRopeTestWithFP16, LLamaRopeTestWithFP16, + ::testing::Values(std::make_tuple(256, 1, 64) // qo_len=256, num_qo_heads=1, head_dim=64 + )); diff --git a/libflashinfer/tests/hip/test_batch_decode.cpp b/libflashinfer/tests/hip/test_batch_decode.cpp index 4073e2d64f..30d8cef39a 100644 --- a/libflashinfer/tests/hip/test_batch_decode.cpp +++ b/libflashinfer/tests/hip/test_batch_decode.cpp @@ -3,235 +3,206 @@ // // SPDX - License - Identifier : Apache 2.0 -#include "flashinfer/attention/generic/decode.cuh" -#include "flashinfer/attention/generic/default_decode_params.cuh" -#include "flashinfer/attention/generic/variants.cuh" +#include + +#include #include "../../utils/cpu_reference_hip.h" #include "../../utils/flashinfer_batch_decode_test_ops.hip.h" #include "../../utils/utils_hip.h" - -#include - -#include +#include "flashinfer/attention/generic/decode.cuh" +#include "flashinfer/attention/generic/default_decode_params.cuh" +#include "flashinfer/attention/generic/variants.cuh" using namespace flashinfer; constexpr QKVLayout kv_layout = QKVLayout::kNHD; template -std::pair -nan_detection_and_accuracy(const std::vector &o_host, - const std::vector &o_ref, - 
uint64_t batch_size, - uint64_t num_qo_heads, - uint64_t head_dim) -{ - - uint64_t num_result_errors_atol_1e_3_rtol_1e_3 = 0; - bool nan_detected = false; - uint64_t num_values = batch_size * num_qo_heads * head_dim; - for (size_t i = 0; i < o_host.size(); ++i) { - float o_host_value = fi::con::explicit_casting(o_host[i]); - float o_ref_value = fi::con::explicit_casting(o_ref[i]); - if (std::isnan(o_host_value) || std::isnan(o_ref_value)) { - nan_detected = true; - } - num_result_errors_atol_1e_3_rtol_1e_3 += - (!utils::isclose(o_host_value, o_ref_value, 1e-3, 1e-3)); +std::pair nan_detection_and_accuracy(const std::vector& o_host, + const std::vector& o_ref, + uint64_t batch_size, uint64_t num_qo_heads, + uint64_t head_dim) { + uint64_t num_result_errors_atol_1e_3_rtol_1e_3 = 0; + bool nan_detected = false; + uint64_t num_values = batch_size * num_qo_heads * head_dim; + for (size_t i = 0; i < o_host.size(); ++i) { + float o_host_value = fi::con::explicit_casting(o_host[i]); + float o_ref_value = fi::con::explicit_casting(o_ref[i]); + if (std::isnan(o_host_value) || std::isnan(o_ref_value)) { + nan_detected = true; } + num_result_errors_atol_1e_3_rtol_1e_3 += + (!utils::isclose(o_host_value, o_ref_value, 1e-3, 1e-3)); + } - float result_accuracy = - 1. - float(num_result_errors_atol_1e_3_rtol_1e_3) / float(num_values); + float result_accuracy = 1. 
- float(num_result_errors_atol_1e_3_rtol_1e_3) / float(num_values); - return {result_accuracy, nan_detected}; + return {result_accuracy, nan_detected}; } template -void _TestBatchDecodingKernelCorrectness(size_t page_size, - size_t batch_size, - size_t num_qo_heads, - size_t num_kv_heads, - size_t head_dim, - PosEncodingMode pos_encoding_mode) -{ - - std::vector seq_lens(batch_size); - utils::vec_randint_(seq_lens, 1, 1024); - std::vector append_indptr{0}; - for (size_t i = 0; i < batch_size; ++i) { - append_indptr.push_back(append_indptr.back() + seq_lens[i]); +void _TestBatchDecodingKernelCorrectness(size_t page_size, size_t batch_size, size_t num_qo_heads, + size_t num_kv_heads, size_t head_dim, + PosEncodingMode pos_encoding_mode) { + std::vector seq_lens(batch_size); + utils::vec_randint_(seq_lens, 1, 1024); + std::vector append_indptr{0}; + for (size_t i = 0; i < batch_size; ++i) { + append_indptr.push_back(append_indptr.back() + seq_lens[i]); + } + + std::vector q; + std::vector o_ref; + std::vector k_data; + std::vector v_data; + std::vector kv_indptr{0}; + std::vector kv_indices; + std::vector kv_last_page_len; + size_t page_counter = 0; + + std::vector> keys, values; + for (size_t i = 0; i < batch_size; ++i) { + size_t seq_len = seq_lens[i]; + size_t num_pages = (seq_len + page_size - 1) / page_size; + size_t last_page_len = (seq_len - 1) % page_size + 1; + std::vector qi(num_qo_heads * head_dim); + std::vector ki(seq_len * num_kv_heads * head_dim), + vi(seq_len * num_kv_heads * head_dim); + utils::vec_normal_(qi); + utils::vec_normal_(ki); + utils::vec_normal_(vi); + + // compute reference output + std::vector o_ref_i = cpu_reference::single_mha( + qi, ki, vi, 1, seq_len, num_qo_heads, num_kv_heads, head_dim, false, QKVLayout::kNHD, + pos_encoding_mode); + keys.push_back(ki); + values.push_back(vi); + // append new q and o_ref + q.insert(q.end(), qi.begin(), qi.end()); + o_ref.insert(o_ref.end(), o_ref_i.begin(), o_ref_i.end()); + // append new 
kv_indptr, kv_indices and kv_last_page_len + kv_last_page_len.push_back(last_page_len); + kv_indptr.push_back(kv_indptr.back() + num_pages); + for (size_t j = 0; j < num_pages; ++j) { + kv_indices.push_back(page_counter++); } - - std::vector q; - std::vector o_ref; - std::vector k_data; - std::vector v_data; - std::vector kv_indptr{0}; - std::vector kv_indices; - std::vector kv_last_page_len; - size_t page_counter = 0; - - std::vector> keys, values; - for (size_t i = 0; i < batch_size; ++i) { - size_t seq_len = seq_lens[i]; - size_t num_pages = (seq_len + page_size - 1) / page_size; - size_t last_page_len = (seq_len - 1) % page_size + 1; - std::vector qi(num_qo_heads * head_dim); - std::vector ki(seq_len * num_kv_heads * head_dim), - vi(seq_len * num_kv_heads * head_dim); - utils::vec_normal_(qi); - utils::vec_normal_(ki); - utils::vec_normal_(vi); - - // compute reference output - std::vector o_ref_i = - cpu_reference::single_mha( - qi, ki, vi, 1, seq_len, num_qo_heads, num_kv_heads, head_dim, - false, QKVLayout::kNHD, pos_encoding_mode); - keys.push_back(ki); - values.push_back(vi); - // append new q and o_ref - q.insert(q.end(), qi.begin(), qi.end()); - o_ref.insert(o_ref.end(), o_ref_i.begin(), o_ref_i.end()); - // append new kv_indptr, kv_indices and kv_last_page_len - kv_last_page_len.push_back(last_page_len); - kv_indptr.push_back(kv_indptr.back() + num_pages); - for (size_t j = 0; j < num_pages; ++j) { - kv_indices.push_back(page_counter++); - } - } - - k_data.resize(page_counter * num_kv_heads * page_size * head_dim); - v_data.resize(page_counter * num_kv_heads * page_size * head_dim); - utils::vec_zero_(k_data); - utils::vec_zero_(v_data); - assert(q.size() == batch_size * num_qo_heads * head_dim); - assert(o_ref.size() == batch_size * num_qo_heads * head_dim); - - paged_kv_t paged_kv_cpu( - num_kv_heads, page_size, head_dim, batch_size, kv_layout, k_data.data(), - v_data.data(), kv_indices.data(), kv_indptr.data(), - kv_last_page_len.data()); - 
cpu_reference::append_paged_kv_cache( - paged_kv_cpu, keys, values, append_indptr); - - DTypeKV *k_data_device; - DTypeKV *v_data_device; - int32_t *kv_indptr_device; - int32_t *kv_indices_device; - int32_t *kv_last_page_len_device; - DTypeQO *q_device; - DTypeQO *o_device; - - hipMalloc(&k_data_device, k_data.size() * sizeof(DTypeKV)); - hipMalloc(&v_data_device, v_data.size() * sizeof(DTypeKV)); - hipMalloc(&kv_indptr_device, kv_indptr.size() * sizeof(int32_t)); - hipMalloc(&kv_indices_device, kv_indices.size() * sizeof(int32_t)); - hipMalloc(&kv_last_page_len_device, - kv_last_page_len.size() * sizeof(int32_t)); - hipMalloc(&q_device, q.size() * sizeof(DTypeQO)); - hipMalloc(&o_device, o_ref.size() * sizeof(DTypeQO)); - - hipMemcpy(k_data_device, k_data.data(), k_data.size() * sizeof(DTypeKV), - hipMemcpyHostToDevice); - hipMemcpy(v_data_device, v_data.data(), v_data.size() * sizeof(DTypeKV), - hipMemcpyHostToDevice); - hipMemcpy(kv_indptr_device, kv_indptr.data(), - kv_indptr.size() * sizeof(int32_t), hipMemcpyHostToDevice); - hipMemcpy(kv_indices_device, kv_indices.data(), - kv_indices.size() * sizeof(int32_t), hipMemcpyHostToDevice); - hipMemcpy(kv_last_page_len_device, kv_last_page_len.data(), - kv_last_page_len.size() * sizeof(int32_t), hipMemcpyHostToDevice); - hipMemcpy(q_device, q.data(), q.size() * sizeof(DTypeQO), - hipMemcpyHostToDevice); - - // create paged_kv object - paged_kv_t paged_kv( - num_kv_heads, page_size, head_dim, batch_size, kv_layout, k_data_device, - v_data_device, kv_indices_device, kv_indptr_device, - kv_last_page_len_device); - - BatchDecodeHandler handler; - - size_t float_workspace_size_in_bytes = 32 * 1024 * 1024; - char *float_buffer; - hipMalloc(&float_buffer, float_workspace_size_in_bytes * sizeof(char)); - - size_t int_workspace_size_in_bytes = 8 * 1024 * 1024; - char *int_buffer; - hipMalloc(&int_buffer, int_workspace_size_in_bytes * sizeof(char)); - - BatchDecodeHandlerPlan( - &handler, (void *)float_buffer, 
float_workspace_size_in_bytes, - (void *)int_buffer, int_workspace_size_in_bytes, kv_indptr.data(), - kv_last_page_len.data(), batch_size, num_qo_heads, num_kv_heads, - head_dim, page_size, pos_encoding_mode); - - hipError_t status = - BatchDecodeWithPagedKVCacheWrapper( - &handler, q_device, /*q_rope_offset=*/nullptr, paged_kv, o_device, - /*lse=*/nullptr, num_qo_heads, pos_encoding_mode); - EXPECT_EQ(status, hipSuccess) - << "HIP error: " + std::string(hipGetErrorString(status)); - - // compare result - std::vector o_host(o_ref.size()); - hipMemcpy(o_host.data(), o_device, o_ref.size() * sizeof(DTypeQO), - hipMemcpyDeviceToHost); - - bool is_empty = o_host.empty(); - EXPECT_EQ(is_empty, false) << "Output is empty."; - - auto [result_accuracy, nan_detected] = nan_detection_and_accuracy( - o_host, o_ref, batch_size, num_qo_heads, head_dim); - - std::cout << "page_size=" << page_size << ", num_qo_heads=" << num_qo_heads - << ", num_kv_heads=" << num_kv_heads - << ", batch_size=" << batch_size << ", head_dim=" << head_dim - << ", pos_encoding_mode=" - << PosEncodingModeToString(pos_encoding_mode) - << ", result accuracy (atol=1e-3, rtol=1e-3): " << result_accuracy - << std::endl; - EXPECT_GT(result_accuracy, 0.90) << "Result correctness test failed."; - EXPECT_EQ(nan_detected, false) << "NaN detected."; - - hipFree(k_data_device); - hipFree(v_data_device); - hipFree(kv_indptr_device); - hipFree(kv_indices_device); - hipFree(kv_last_page_len_device); - hipFree(q_device); - hipFree(o_device); - hipFree(float_buffer); - hipFree(int_buffer); + } + + k_data.resize(page_counter * num_kv_heads * page_size * head_dim); + v_data.resize(page_counter * num_kv_heads * page_size * head_dim); + utils::vec_zero_(k_data); + utils::vec_zero_(v_data); + assert(q.size() == batch_size * num_qo_heads * head_dim); + assert(o_ref.size() == batch_size * num_qo_heads * head_dim); + + paged_kv_t paged_kv_cpu( + num_kv_heads, page_size, head_dim, batch_size, kv_layout, k_data.data(), 
v_data.data(), + kv_indices.data(), kv_indptr.data(), kv_last_page_len.data()); + cpu_reference::append_paged_kv_cache(paged_kv_cpu, keys, values, append_indptr); + + DTypeKV* k_data_device; + DTypeKV* v_data_device; + int32_t* kv_indptr_device; + int32_t* kv_indices_device; + int32_t* kv_last_page_len_device; + DTypeQO* q_device; + DTypeQO* o_device; + + hipMalloc(&k_data_device, k_data.size() * sizeof(DTypeKV)); + hipMalloc(&v_data_device, v_data.size() * sizeof(DTypeKV)); + hipMalloc(&kv_indptr_device, kv_indptr.size() * sizeof(int32_t)); + hipMalloc(&kv_indices_device, kv_indices.size() * sizeof(int32_t)); + hipMalloc(&kv_last_page_len_device, kv_last_page_len.size() * sizeof(int32_t)); + hipMalloc(&q_device, q.size() * sizeof(DTypeQO)); + hipMalloc(&o_device, o_ref.size() * sizeof(DTypeQO)); + + hipMemcpy(k_data_device, k_data.data(), k_data.size() * sizeof(DTypeKV), hipMemcpyHostToDevice); + hipMemcpy(v_data_device, v_data.data(), v_data.size() * sizeof(DTypeKV), hipMemcpyHostToDevice); + hipMemcpy(kv_indptr_device, kv_indptr.data(), kv_indptr.size() * sizeof(int32_t), + hipMemcpyHostToDevice); + hipMemcpy(kv_indices_device, kv_indices.data(), kv_indices.size() * sizeof(int32_t), + hipMemcpyHostToDevice); + hipMemcpy(kv_last_page_len_device, kv_last_page_len.data(), + kv_last_page_len.size() * sizeof(int32_t), hipMemcpyHostToDevice); + hipMemcpy(q_device, q.data(), q.size() * sizeof(DTypeQO), hipMemcpyHostToDevice); + + // create paged_kv object + paged_kv_t paged_kv(num_kv_heads, page_size, head_dim, batch_size, kv_layout, + k_data_device, v_data_device, kv_indices_device, + kv_indptr_device, kv_last_page_len_device); + + BatchDecodeHandler handler; + + size_t float_workspace_size_in_bytes = 32 * 1024 * 1024; + char* float_buffer; + hipMalloc(&float_buffer, float_workspace_size_in_bytes * sizeof(char)); + + size_t int_workspace_size_in_bytes = 8 * 1024 * 1024; + char* int_buffer; + hipMalloc(&int_buffer, int_workspace_size_in_bytes * sizeof(char)); + + 
BatchDecodeHandlerPlan( + &handler, (void*)float_buffer, float_workspace_size_in_bytes, (void*)int_buffer, + int_workspace_size_in_bytes, kv_indptr.data(), kv_last_page_len.data(), batch_size, + num_qo_heads, num_kv_heads, head_dim, page_size, pos_encoding_mode); + + hipError_t status = BatchDecodeWithPagedKVCacheWrapper( + &handler, q_device, /*q_rope_offset=*/nullptr, paged_kv, o_device, + /*lse=*/nullptr, num_qo_heads, pos_encoding_mode); + EXPECT_EQ(status, hipSuccess) << "HIP error: " + std::string(hipGetErrorString(status)); + + // compare result + std::vector o_host(o_ref.size()); + hipMemcpy(o_host.data(), o_device, o_ref.size() * sizeof(DTypeQO), hipMemcpyDeviceToHost); + + bool is_empty = o_host.empty(); + EXPECT_EQ(is_empty, false) << "Output is empty."; + + auto [result_accuracy, nan_detected] = + nan_detection_and_accuracy(o_host, o_ref, batch_size, num_qo_heads, head_dim); + + std::cout << "page_size=" << page_size << ", num_qo_heads=" << num_qo_heads + << ", num_kv_heads=" << num_kv_heads << ", batch_size=" << batch_size + << ", head_dim=" << head_dim + << ", pos_encoding_mode=" << PosEncodingModeToString(pos_encoding_mode) + << ", result accuracy (atol=1e-3, rtol=1e-3): " << result_accuracy << std::endl; + EXPECT_GT(result_accuracy, 0.90) << "Result correctness test failed."; + EXPECT_EQ(nan_detected, false) << "NaN detected."; + + hipFree(k_data_device); + hipFree(v_data_device); + hipFree(kv_indptr_device); + hipFree(kv_indices_device); + hipFree(kv_last_page_len_device); + hipFree(q_device); + hipFree(o_device); + hipFree(float_buffer); + hipFree(int_buffer); } template -void TestBatchDecodeKernelCorrectness() -{ - for (size_t page_size : {1, 3, 7, 16}) { - for (size_t batch_size : {1, 2, 4, 8}) { - for (size_t num_qo_heads : {32}) { - for (size_t num_kv_heads : {32, 8, 4}) { - for (size_t head_dim : {64, 128, 256}) { - for (size_t pos_encoding_mode : {0U, 1U}) { - _TestBatchDecodingKernelCorrectness( - page_size, batch_size, num_qo_heads, - 
num_kv_heads, head_dim, - PosEncodingMode(pos_encoding_mode)); - } - } - } +void TestBatchDecodeKernelCorrectness() { + for (size_t page_size : {1, 3, 7, 16}) { + for (size_t batch_size : {1, 2, 4, 8}) { + for (size_t num_qo_heads : {32}) { + for (size_t num_kv_heads : {32, 8, 4}) { + for (size_t head_dim : {64, 128, 256}) { + for (size_t pos_encoding_mode : {0U, 1U}) { + _TestBatchDecodingKernelCorrectness( + page_size, batch_size, num_qo_heads, num_kv_heads, head_dim, + PosEncodingMode(pos_encoding_mode)); } + } } + } } + } } -TEST(FlashInferCorrectnessTest, BatchDecodeKernelCorrectnessTestFP16) -{ - TestBatchDecodeKernelCorrectness<__half, __half>(); +TEST(FlashInferCorrectnessTest, BatchDecodeKernelCorrectnessTestFP16) { + TestBatchDecodeKernelCorrectness<__half, __half>(); } // Disabled for now - Look at https://github.com/AMD-AIOSS/flashinfer/issues/36 @@ -252,8 +223,7 @@ TEST(FlashInferCorrectnessTest, BatchDecodeKernelCorrectnessTestFP16) // } ///************************************************************************/ -int main(int argc, char **argv) -{ - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); } diff --git a/libflashinfer/tests/hip/test_cascade.cpp b/libflashinfer/tests/hip/test_cascade.cpp index 585e74a3b5..e65e6cb82b 100644 --- a/libflashinfer/tests/hip/test_cascade.cpp +++ b/libflashinfer/tests/hip/test_cascade.cpp @@ -3,584 +3,485 @@ // // SPDX - License - Identifier : Apache 2.0 -#include "../../utils/utils_hip.h" -#include "flashinfer/attention/generic/cascade.cuh" -#include "gpu_iface/conversion_utils.h" -#include "gpu_iface/layout.cuh" - +#include #include #include #include #include -#include +#include "../../utils/utils_hip.h" +#include "flashinfer/attention/generic/cascade.cuh" +#include "gpu_iface/conversion_utils.h" +#include "gpu_iface/layout.cuh" using namespace flashinfer; constexpr QKVLayout kv_layout = 
QKVLayout::kHND; -bool is_prime(int x) -{ - for (int i = 2; i < int(std::sqrt(x)); ++i) { - if (x % i == 0) - return false; - } - return true; +bool is_prime(int x) { + for (int i = 2; i < int(std::sqrt(x)); ++i) { + if (x % i == 0) return false; + } + return true; } template -void _TestVariableLengthMergeKernelCorrectness(size_t seq_len, - size_t num_heads, - size_t head_dim, - bool sparse_s) -{ - const uint32_t max_num_index_sets = 512; - std::vector lengths(seq_len); - utils::vec_randint_(lengths, 1, max_num_index_sets); - std::vector indptr{0}; - for (size_t i = 0; i < seq_len; ++i) { - indptr.push_back(indptr.back() + lengths[i]); - } - std::vector V_padded_host(seq_len * max_num_index_sets * num_heads * - head_dim); - std::vector V_ragged_host(indptr.back() * num_heads * head_dim); - std::vector S_padded_host(seq_len * max_num_index_sets * num_heads); - std::vector S_ragged_host(indptr.back() * num_heads); - - utils::vec_normal_(V_ragged_host); - for (uint32_t j = 0; j < seq_len; ++j) { - std::copy(V_ragged_host.begin() + indptr[j] * num_heads * head_dim, - V_ragged_host.begin() + indptr[j + 1] * num_heads * head_dim, - V_padded_host.begin() + - j * max_num_index_sets * num_heads * head_dim); - } - if (sparse_s) { - for (uint32_t i = 0; i < max_num_index_sets; ++i) { - float fill_val = is_prime(i) ? 
10 : -10; - for (uint32_t j = 0; j < seq_len; ++j) { - if (i < lengths[j]) { - std::fill( - S_ragged_host.begin() + (indptr[j] + i) * num_heads, - S_ragged_host.begin() + (indptr[j] + i + 1) * num_heads, - fill_val); - std::fill(S_padded_host.begin() + - (j * max_num_index_sets + i) * num_heads, - S_padded_host.begin() + - (j * max_num_index_sets + i + 1) * num_heads, - fill_val); - } - else { - std::fill(S_padded_host.begin() + - (j * max_num_index_sets + i) * num_heads, - S_padded_host.begin() + - (j * max_num_index_sets + i + 1) * num_heads, - -5e4); - } - } +void _TestVariableLengthMergeKernelCorrectness(size_t seq_len, size_t num_heads, size_t head_dim, + bool sparse_s) { + const uint32_t max_num_index_sets = 512; + std::vector lengths(seq_len); + utils::vec_randint_(lengths, 1, max_num_index_sets); + std::vector indptr{0}; + for (size_t i = 0; i < seq_len; ++i) { + indptr.push_back(indptr.back() + lengths[i]); + } + std::vector V_padded_host(seq_len * max_num_index_sets * num_heads * head_dim); + std::vector V_ragged_host(indptr.back() * num_heads * head_dim); + std::vector S_padded_host(seq_len * max_num_index_sets * num_heads); + std::vector S_ragged_host(indptr.back() * num_heads); + + utils::vec_normal_(V_ragged_host); + for (uint32_t j = 0; j < seq_len; ++j) { + std::copy(V_ragged_host.begin() + indptr[j] * num_heads * head_dim, + V_ragged_host.begin() + indptr[j + 1] * num_heads * head_dim, + V_padded_host.begin() + j * max_num_index_sets * num_heads * head_dim); + } + if (sparse_s) { + for (uint32_t i = 0; i < max_num_index_sets; ++i) { + float fill_val = is_prime(i) ? 
10 : -10; + for (uint32_t j = 0; j < seq_len; ++j) { + if (i < lengths[j]) { + std::fill(S_ragged_host.begin() + (indptr[j] + i) * num_heads, + S_ragged_host.begin() + (indptr[j] + i + 1) * num_heads, fill_val); + std::fill(S_padded_host.begin() + (j * max_num_index_sets + i) * num_heads, + S_padded_host.begin() + (j * max_num_index_sets + i + 1) * num_heads, fill_val); + } else { + std::fill(S_padded_host.begin() + (j * max_num_index_sets + i) * num_heads, + S_padded_host.begin() + (j * max_num_index_sets + i + 1) * num_heads, -5e4); } + } } - else { - utils::vec_uniform_(S_ragged_host, -10, 10); - for (uint32_t j = 0; j < seq_len; ++j) { - std::copy(S_ragged_host.begin() + indptr[j] * num_heads, - S_ragged_host.begin() + indptr[j + 1] * num_heads, - S_padded_host.begin() + - (j * max_num_index_sets) * num_heads); - std::fill(S_padded_host.begin() + - (j * max_num_index_sets + indptr[j + 1] - indptr[j]) * - num_heads, - S_padded_host.begin() + - (j + 1) * max_num_index_sets * num_heads, - -5e4); - } - } - - // Allocate device memory using HIP - T *V_padded_device; - T *V_ragged_device; - float *S_padded_device; - float *S_ragged_device; - int32_t *indptr_device; - T *V_merged_0_device; - T *V_merged_1_device; - float *S_merged_0_device; - float *S_merged_1_device; - - hipMalloc(&V_padded_device, V_padded_host.size() * sizeof(T)); - hipMalloc(&V_ragged_device, V_ragged_host.size() * sizeof(T)); - hipMalloc(&S_padded_device, S_padded_host.size() * sizeof(float)); - hipMalloc(&S_ragged_device, S_ragged_host.size() * sizeof(float)); - hipMalloc(&indptr_device, indptr.size() * sizeof(int32_t)); - hipMalloc(&V_merged_0_device, seq_len * num_heads * head_dim * sizeof(T)); - hipMalloc(&V_merged_1_device, seq_len * num_heads * head_dim * sizeof(T)); - hipMalloc(&S_merged_0_device, seq_len * num_heads * sizeof(float)); - hipMalloc(&S_merged_1_device, seq_len * num_heads * sizeof(float)); - - // Copy data from host to device - hipMemcpy(V_padded_device, V_padded_host.data(), 
- V_padded_host.size() * sizeof(T), hipMemcpyHostToDevice); - hipMemcpy(V_ragged_device, V_ragged_host.data(), - V_ragged_host.size() * sizeof(T), hipMemcpyHostToDevice); - hipMemcpy(S_padded_device, S_padded_host.data(), - S_padded_host.size() * sizeof(float), hipMemcpyHostToDevice); - hipMemcpy(S_ragged_device, S_ragged_host.data(), - S_ragged_host.size() * sizeof(float), hipMemcpyHostToDevice); - hipMemcpy(indptr_device, indptr.data(), indptr.size() * sizeof(int32_t), - hipMemcpyHostToDevice); - - // Initialize merged arrays to zero - hipMemset(V_merged_0_device, 0, seq_len * num_heads * head_dim * sizeof(T)); - hipMemset(V_merged_1_device, 0, seq_len * num_heads * head_dim * sizeof(T)); - hipMemset(S_merged_0_device, 0, seq_len * num_heads * sizeof(float)); - hipMemset(S_merged_1_device, 0, seq_len * num_heads * sizeof(float)); - - // Method 0: use MergeStates on padded data - MergeStates(V_padded_device, S_padded_device, V_merged_0_device, - S_merged_0_device, max_num_index_sets, seq_len, num_heads, - head_dim); - - // Method 1: use VariableLengthMergeStates on ragged data - VariableLengthMergeStates(V_ragged_device, S_ragged_device, indptr_device, - V_merged_1_device, S_merged_1_device, seq_len, - nullptr, num_heads, head_dim); - - // Allocate host memory for results - std::vector V_merged_0_host(seq_len * num_heads * head_dim); - std::vector V_merged_1_host(seq_len * num_heads * head_dim); - std::vector S_merged_0_host(seq_len * num_heads); - std::vector S_merged_1_host(seq_len * num_heads); - - // Copy results from device to host - hipMemcpy(V_merged_0_host.data(), V_merged_0_device, - seq_len * num_heads * head_dim * sizeof(T), - hipMemcpyDeviceToHost); - hipMemcpy(V_merged_1_host.data(), V_merged_1_device, - seq_len * num_heads * head_dim * sizeof(T), - hipMemcpyDeviceToHost); - hipMemcpy(S_merged_0_host.data(), S_merged_0_device, - seq_len * num_heads * sizeof(float), hipMemcpyDeviceToHost); - hipMemcpy(S_merged_1_host.data(), S_merged_1_device, - 
seq_len * num_heads * sizeof(float), hipMemcpyDeviceToHost); - - // Compare results - size_t num_V_result_errors_atol_1e_3_rtol_1e_3 = 0, - num_S_result_errors_atol_1e_3_rtol_1e_3 = 0; - for (size_t i = 0; i < seq_len * num_heads * head_dim; ++i) { - float V_merged_0_host_value = - fi::con::explicit_casting(V_merged_0_host[i]); - float V_merged_1_host_value = - fi::con::explicit_casting(V_merged_1_host[i]); - EXPECT_FALSE(std::isnan(V_merged_0_host_value)) - << "V_merged_0_host[" << i << "] is nan"; - EXPECT_FALSE(std::isnan(V_merged_1_host_value)) - << "V_merged_1_host[" << i << "] is nan"; - num_V_result_errors_atol_1e_3_rtol_1e_3 += (!utils::isclose( - V_merged_0_host_value, V_merged_1_host_value, 1e-3, 1e-3)); - } - for (size_t i = 0; i < seq_len * num_heads; ++i) { - EXPECT_FALSE(std::isnan(S_merged_0_host[i])) - << "S_merged_0_host[" << i << "] is nan"; - EXPECT_FALSE(std::isnan(S_merged_1_host[i])) - << "S_merged_1_host[" << i << "] is nan"; - num_S_result_errors_atol_1e_3_rtol_1e_3 += (!utils::isclose( - S_merged_0_host[i], S_merged_1_host[i], 1e-3, 1e-3)); + } else { + utils::vec_uniform_(S_ragged_host, -10, 10); + for (uint32_t j = 0; j < seq_len; ++j) { + std::copy(S_ragged_host.begin() + indptr[j] * num_heads, + S_ragged_host.begin() + indptr[j + 1] * num_heads, + S_padded_host.begin() + (j * max_num_index_sets) * num_heads); + std::fill( + S_padded_host.begin() + (j * max_num_index_sets + indptr[j + 1] - indptr[j]) * num_heads, + S_padded_host.begin() + (j + 1) * max_num_index_sets * num_heads, -5e4); } - - float V_result_accuracy = - 1.0 - float(num_V_result_errors_atol_1e_3_rtol_1e_3) / - (seq_len * num_heads * head_dim); - float S_result_accuracy = - 1.0 - - float(num_S_result_errors_atol_1e_3_rtol_1e_3) / (seq_len * num_heads); - std::cout << "seq_len=" << seq_len << ", num_heads=" << num_heads - << ", head_dim=" << head_dim << ", sparse_s=" << sparse_s - << ", V accuracy (atol=1e-3, rtol=1e-3)=" << V_result_accuracy - << ", S accuracy (atol=1e-3, 
rtol=1e-3)=" << S_result_accuracy - << std::endl; - - EXPECT_GT(V_result_accuracy, 0.99) << "V result correctness test failed."; - EXPECT_GT(S_result_accuracy, 0.99) << "S result correctness test failed."; - - // Free device memory - hipFree(V_padded_device); - hipFree(V_ragged_device); - hipFree(S_padded_device); - hipFree(S_ragged_device); - hipFree(indptr_device); - hipFree(V_merged_0_device); - hipFree(V_merged_1_device); - hipFree(S_merged_0_device); - hipFree(S_merged_1_device); + } + + // Allocate device memory using HIP + T* V_padded_device; + T* V_ragged_device; + float* S_padded_device; + float* S_ragged_device; + int32_t* indptr_device; + T* V_merged_0_device; + T* V_merged_1_device; + float* S_merged_0_device; + float* S_merged_1_device; + + hipMalloc(&V_padded_device, V_padded_host.size() * sizeof(T)); + hipMalloc(&V_ragged_device, V_ragged_host.size() * sizeof(T)); + hipMalloc(&S_padded_device, S_padded_host.size() * sizeof(float)); + hipMalloc(&S_ragged_device, S_ragged_host.size() * sizeof(float)); + hipMalloc(&indptr_device, indptr.size() * sizeof(int32_t)); + hipMalloc(&V_merged_0_device, seq_len * num_heads * head_dim * sizeof(T)); + hipMalloc(&V_merged_1_device, seq_len * num_heads * head_dim * sizeof(T)); + hipMalloc(&S_merged_0_device, seq_len * num_heads * sizeof(float)); + hipMalloc(&S_merged_1_device, seq_len * num_heads * sizeof(float)); + + // Copy data from host to device + hipMemcpy(V_padded_device, V_padded_host.data(), V_padded_host.size() * sizeof(T), + hipMemcpyHostToDevice); + hipMemcpy(V_ragged_device, V_ragged_host.data(), V_ragged_host.size() * sizeof(T), + hipMemcpyHostToDevice); + hipMemcpy(S_padded_device, S_padded_host.data(), S_padded_host.size() * sizeof(float), + hipMemcpyHostToDevice); + hipMemcpy(S_ragged_device, S_ragged_host.data(), S_ragged_host.size() * sizeof(float), + hipMemcpyHostToDevice); + hipMemcpy(indptr_device, indptr.data(), indptr.size() * sizeof(int32_t), hipMemcpyHostToDevice); + + // Initialize merged 
arrays to zero + hipMemset(V_merged_0_device, 0, seq_len * num_heads * head_dim * sizeof(T)); + hipMemset(V_merged_1_device, 0, seq_len * num_heads * head_dim * sizeof(T)); + hipMemset(S_merged_0_device, 0, seq_len * num_heads * sizeof(float)); + hipMemset(S_merged_1_device, 0, seq_len * num_heads * sizeof(float)); + + // Method 0: use MergeStates on padded data + MergeStates(V_padded_device, S_padded_device, V_merged_0_device, S_merged_0_device, + max_num_index_sets, seq_len, num_heads, head_dim); + + // Method 1: use VariableLengthMergeStates on ragged data + VariableLengthMergeStates(V_ragged_device, S_ragged_device, indptr_device, V_merged_1_device, + S_merged_1_device, seq_len, nullptr, num_heads, head_dim); + + // Allocate host memory for results + std::vector V_merged_0_host(seq_len * num_heads * head_dim); + std::vector V_merged_1_host(seq_len * num_heads * head_dim); + std::vector S_merged_0_host(seq_len * num_heads); + std::vector S_merged_1_host(seq_len * num_heads); + + // Copy results from device to host + hipMemcpy(V_merged_0_host.data(), V_merged_0_device, seq_len * num_heads * head_dim * sizeof(T), + hipMemcpyDeviceToHost); + hipMemcpy(V_merged_1_host.data(), V_merged_1_device, seq_len * num_heads * head_dim * sizeof(T), + hipMemcpyDeviceToHost); + hipMemcpy(S_merged_0_host.data(), S_merged_0_device, seq_len * num_heads * sizeof(float), + hipMemcpyDeviceToHost); + hipMemcpy(S_merged_1_host.data(), S_merged_1_device, seq_len * num_heads * sizeof(float), + hipMemcpyDeviceToHost); + + // Compare results + size_t num_V_result_errors_atol_1e_3_rtol_1e_3 = 0, num_S_result_errors_atol_1e_3_rtol_1e_3 = 0; + for (size_t i = 0; i < seq_len * num_heads * head_dim; ++i) { + float V_merged_0_host_value = fi::con::explicit_casting(V_merged_0_host[i]); + float V_merged_1_host_value = fi::con::explicit_casting(V_merged_1_host[i]); + EXPECT_FALSE(std::isnan(V_merged_0_host_value)) << "V_merged_0_host[" << i << "] is nan"; + 
EXPECT_FALSE(std::isnan(V_merged_1_host_value)) << "V_merged_1_host[" << i << "] is nan"; + num_V_result_errors_atol_1e_3_rtol_1e_3 += + (!utils::isclose(V_merged_0_host_value, V_merged_1_host_value, 1e-3, 1e-3)); + } + for (size_t i = 0; i < seq_len * num_heads; ++i) { + EXPECT_FALSE(std::isnan(S_merged_0_host[i])) << "S_merged_0_host[" << i << "] is nan"; + EXPECT_FALSE(std::isnan(S_merged_1_host[i])) << "S_merged_1_host[" << i << "] is nan"; + num_S_result_errors_atol_1e_3_rtol_1e_3 += + (!utils::isclose(S_merged_0_host[i], S_merged_1_host[i], 1e-3, 1e-3)); + } + + float V_result_accuracy = + 1.0 - float(num_V_result_errors_atol_1e_3_rtol_1e_3) / (seq_len * num_heads * head_dim); + float S_result_accuracy = + 1.0 - float(num_S_result_errors_atol_1e_3_rtol_1e_3) / (seq_len * num_heads); + std::cout << "seq_len=" << seq_len << ", num_heads=" << num_heads << ", head_dim=" << head_dim + << ", sparse_s=" << sparse_s + << ", V accuracy (atol=1e-3, rtol=1e-3)=" << V_result_accuracy + << ", S accuracy (atol=1e-3, rtol=1e-3)=" << S_result_accuracy << std::endl; + + EXPECT_GT(V_result_accuracy, 0.99) << "V result correctness test failed."; + EXPECT_GT(S_result_accuracy, 0.99) << "S result correctness test failed."; + + // Free device memory + hipFree(V_padded_device); + hipFree(V_ragged_device); + hipFree(S_padded_device); + hipFree(S_ragged_device); + hipFree(indptr_device); + hipFree(V_merged_0_device); + hipFree(V_merged_1_device); + hipFree(S_merged_0_device); + hipFree(S_merged_1_device); } template -void _TestVariableLengthMergeKernelPaddedCorrectness(size_t max_seq_len, - size_t seq_len) -{ - ASSERT_LE(seq_len, max_seq_len); - - const size_t num_heads = 4; - const size_t head_dim = 64; - const uint32_t max_num_index_sets = 512; - - std::vector lengths(max_seq_len); - utils::vec_randint_(lengths, 1, max_num_index_sets); - std::vector indptr(max_seq_len + 1, 0); - for (size_t i = 0; i < seq_len; ++i) { - indptr[i + 1] = indptr[i] + lengths[i]; - } - - uint32_t 
last_indptr = indptr[seq_len]; - std::vector V_ragged_host(last_indptr * num_heads * head_dim); - std::vector S_ragged_host(last_indptr * num_heads); - - utils::vec_normal_(V_ragged_host); - utils::vec_uniform_(S_ragged_host, -10, 10); - - // Allocate device memory using HIP - T *V_ragged_device; - float *S_ragged_device; - int32_t *indptr_device; - T *V_merged_0_device; - T *V_merged_1_device; - float *S_merged_0_device; - float *S_merged_1_device; - uint32_t *seq_len_device; - - hipMalloc(&V_ragged_device, V_ragged_host.size() * sizeof(T)); - hipMalloc(&S_ragged_device, S_ragged_host.size() * sizeof(float)); - hipMalloc(&indptr_device, indptr.size() * sizeof(int32_t)); - hipMalloc(&V_merged_0_device, - max_seq_len * num_heads * head_dim * sizeof(T)); - hipMalloc(&V_merged_1_device, - max_seq_len * num_heads * head_dim * sizeof(T)); - hipMalloc(&S_merged_0_device, max_seq_len * num_heads * sizeof(float)); - hipMalloc(&S_merged_1_device, max_seq_len * num_heads * sizeof(float)); - hipMalloc(&seq_len_device, sizeof(uint32_t)); - - // Copy data from host to device - hipMemcpy(V_ragged_device, V_ragged_host.data(), - V_ragged_host.size() * sizeof(T), hipMemcpyHostToDevice); - hipMemcpy(S_ragged_device, S_ragged_host.data(), - S_ragged_host.size() * sizeof(float), hipMemcpyHostToDevice); - hipMemcpy(indptr_device, indptr.data(), indptr.size() * sizeof(int32_t), - hipMemcpyHostToDevice); - uint32_t seq_len_value = static_cast(seq_len); - hipMemcpy(seq_len_device, &seq_len_value, sizeof(uint32_t), - hipMemcpyHostToDevice); - - // Initialize merged arrays to zero - hipMemset(V_merged_0_device, 0, - max_seq_len * num_heads * head_dim * sizeof(T)); - hipMemset(V_merged_1_device, 0, - max_seq_len * num_heads * head_dim * sizeof(T)); - hipMemset(S_merged_0_device, 0, max_seq_len * num_heads * sizeof(float)); - hipMemset(S_merged_1_device, 0, max_seq_len * num_heads * sizeof(float)); - - // Reference: use VariableLengthMergeStates on the precisely-sized input. 
- VariableLengthMergeStates(V_ragged_device, S_ragged_device, indptr_device, - V_merged_0_device, S_merged_0_device, seq_len, - nullptr, num_heads, head_dim); - // Expected: use VariableLengthMergeStates on a padded input - VariableLengthMergeStates(V_ragged_device, S_ragged_device, indptr_device, - V_merged_1_device, S_merged_1_device, max_seq_len, - seq_len_device, num_heads, head_dim); - - // Allocate host memory for results - std::vector V_merged_0_host(max_seq_len * num_heads * head_dim); - std::vector V_merged_1_host(max_seq_len * num_heads * head_dim); - std::vector S_merged_0_host(max_seq_len * num_heads); - std::vector S_merged_1_host(max_seq_len * num_heads); - - // Copy results from device to host - hipMemcpy(V_merged_0_host.data(), V_merged_0_device, - max_seq_len * num_heads * head_dim * sizeof(T), - hipMemcpyDeviceToHost); - hipMemcpy(V_merged_1_host.data(), V_merged_1_device, - max_seq_len * num_heads * head_dim * sizeof(T), - hipMemcpyDeviceToHost); - hipMemcpy(S_merged_0_host.data(), S_merged_0_device, - max_seq_len * num_heads * sizeof(float), hipMemcpyDeviceToHost); - hipMemcpy(S_merged_1_host.data(), S_merged_1_device, - max_seq_len * num_heads * sizeof(float), hipMemcpyDeviceToHost); - - // Compare results - size_t num_V_result_errors_atol_1e_3_rtol_1e_3 = 0, - num_S_result_errors_atol_1e_3_rtol_1e_3 = 0; - for (size_t i = 0; i < seq_len * num_heads * head_dim; ++i) { - float V_merged_1_host_value = - fi::con::explicit_casting(V_merged_1_host[i]); - float V_merged_0_host_value = - fi::con::explicit_casting(V_merged_0_host[i]); - EXPECT_FALSE(std::isnan(V_merged_1_host_value)) - << "V_merged_1_host[" << i << "] is nan"; - num_V_result_errors_atol_1e_3_rtol_1e_3 += (!utils::isclose( - V_merged_0_host_value, V_merged_1_host_value, 1e-3, 1e-3)); - } - for (size_t i = 0; i < seq_len * num_heads; ++i) { - EXPECT_FALSE(std::isnan(float(S_merged_0_host[i]))) - << "S_merged_0_host[" << i << "] is nan"; - 
EXPECT_FALSE(std::isnan(float(S_merged_1_host[i]))) - << "S_merged_1_host[" << i << "] is nan"; - num_S_result_errors_atol_1e_3_rtol_1e_3 += (!utils::isclose( - float(S_merged_0_host[i]), float(S_merged_1_host[i]), 1e-3, 1e-3)); - } - float V_result_accuracy = - 1.0 - float(num_V_result_errors_atol_1e_3_rtol_1e_3) / - (seq_len * num_heads * head_dim); - float S_result_accuracy = - 1.0 - - float(num_S_result_errors_atol_1e_3_rtol_1e_3) / (seq_len * num_heads); - std::cout << "seq_len=" << seq_len << ", num_heads=" << num_heads - << ", head_dim=" << head_dim - << ", V accuracy (atol=1e-3, rtol=1e-3)=" << V_result_accuracy - << ", S accuracy (atol=1e-3, rtol=1e-3)=" << S_result_accuracy - << std::endl; - - EXPECT_GT(V_result_accuracy, 0.99) << "V result correctness test failed."; - EXPECT_GT(S_result_accuracy, 0.99) << "S result correctness test failed."; - - // Free device memory - hipFree(V_ragged_device); - hipFree(S_ragged_device); - hipFree(indptr_device); - hipFree(V_merged_0_device); - hipFree(V_merged_1_device); - hipFree(S_merged_0_device); - hipFree(S_merged_1_device); - hipFree(seq_len_device); +void _TestVariableLengthMergeKernelPaddedCorrectness(size_t max_seq_len, size_t seq_len) { + ASSERT_LE(seq_len, max_seq_len); + + const size_t num_heads = 4; + const size_t head_dim = 64; + const uint32_t max_num_index_sets = 512; + + std::vector lengths(max_seq_len); + utils::vec_randint_(lengths, 1, max_num_index_sets); + std::vector indptr(max_seq_len + 1, 0); + for (size_t i = 0; i < seq_len; ++i) { + indptr[i + 1] = indptr[i] + lengths[i]; + } + + uint32_t last_indptr = indptr[seq_len]; + std::vector V_ragged_host(last_indptr * num_heads * head_dim); + std::vector S_ragged_host(last_indptr * num_heads); + + utils::vec_normal_(V_ragged_host); + utils::vec_uniform_(S_ragged_host, -10, 10); + + // Allocate device memory using HIP + T* V_ragged_device; + float* S_ragged_device; + int32_t* indptr_device; + T* V_merged_0_device; + T* V_merged_1_device; + float* 
S_merged_0_device; + float* S_merged_1_device; + uint32_t* seq_len_device; + + hipMalloc(&V_ragged_device, V_ragged_host.size() * sizeof(T)); + hipMalloc(&S_ragged_device, S_ragged_host.size() * sizeof(float)); + hipMalloc(&indptr_device, indptr.size() * sizeof(int32_t)); + hipMalloc(&V_merged_0_device, max_seq_len * num_heads * head_dim * sizeof(T)); + hipMalloc(&V_merged_1_device, max_seq_len * num_heads * head_dim * sizeof(T)); + hipMalloc(&S_merged_0_device, max_seq_len * num_heads * sizeof(float)); + hipMalloc(&S_merged_1_device, max_seq_len * num_heads * sizeof(float)); + hipMalloc(&seq_len_device, sizeof(uint32_t)); + + // Copy data from host to device + hipMemcpy(V_ragged_device, V_ragged_host.data(), V_ragged_host.size() * sizeof(T), + hipMemcpyHostToDevice); + hipMemcpy(S_ragged_device, S_ragged_host.data(), S_ragged_host.size() * sizeof(float), + hipMemcpyHostToDevice); + hipMemcpy(indptr_device, indptr.data(), indptr.size() * sizeof(int32_t), hipMemcpyHostToDevice); + uint32_t seq_len_value = static_cast(seq_len); + hipMemcpy(seq_len_device, &seq_len_value, sizeof(uint32_t), hipMemcpyHostToDevice); + + // Initialize merged arrays to zero + hipMemset(V_merged_0_device, 0, max_seq_len * num_heads * head_dim * sizeof(T)); + hipMemset(V_merged_1_device, 0, max_seq_len * num_heads * head_dim * sizeof(T)); + hipMemset(S_merged_0_device, 0, max_seq_len * num_heads * sizeof(float)); + hipMemset(S_merged_1_device, 0, max_seq_len * num_heads * sizeof(float)); + + // Reference: use VariableLengthMergeStates on the precisely-sized input. 
+ VariableLengthMergeStates(V_ragged_device, S_ragged_device, indptr_device, V_merged_0_device, + S_merged_0_device, seq_len, nullptr, num_heads, head_dim); + // Expected: use VariableLengthMergeStates on a padded input + VariableLengthMergeStates(V_ragged_device, S_ragged_device, indptr_device, V_merged_1_device, + S_merged_1_device, max_seq_len, seq_len_device, num_heads, head_dim); + + // Allocate host memory for results + std::vector V_merged_0_host(max_seq_len * num_heads * head_dim); + std::vector V_merged_1_host(max_seq_len * num_heads * head_dim); + std::vector S_merged_0_host(max_seq_len * num_heads); + std::vector S_merged_1_host(max_seq_len * num_heads); + + // Copy results from device to host + hipMemcpy(V_merged_0_host.data(), V_merged_0_device, + max_seq_len * num_heads * head_dim * sizeof(T), hipMemcpyDeviceToHost); + hipMemcpy(V_merged_1_host.data(), V_merged_1_device, + max_seq_len * num_heads * head_dim * sizeof(T), hipMemcpyDeviceToHost); + hipMemcpy(S_merged_0_host.data(), S_merged_0_device, max_seq_len * num_heads * sizeof(float), + hipMemcpyDeviceToHost); + hipMemcpy(S_merged_1_host.data(), S_merged_1_device, max_seq_len * num_heads * sizeof(float), + hipMemcpyDeviceToHost); + + // Compare results + size_t num_V_result_errors_atol_1e_3_rtol_1e_3 = 0, num_S_result_errors_atol_1e_3_rtol_1e_3 = 0; + for (size_t i = 0; i < seq_len * num_heads * head_dim; ++i) { + float V_merged_1_host_value = fi::con::explicit_casting(V_merged_1_host[i]); + float V_merged_0_host_value = fi::con::explicit_casting(V_merged_0_host[i]); + EXPECT_FALSE(std::isnan(V_merged_1_host_value)) << "V_merged_1_host[" << i << "] is nan"; + num_V_result_errors_atol_1e_3_rtol_1e_3 += + (!utils::isclose(V_merged_0_host_value, V_merged_1_host_value, 1e-3, 1e-3)); + } + for (size_t i = 0; i < seq_len * num_heads; ++i) { + EXPECT_FALSE(std::isnan(float(S_merged_0_host[i]))) << "S_merged_0_host[" << i << "] is nan"; + EXPECT_FALSE(std::isnan(float(S_merged_1_host[i]))) << 
"S_merged_1_host[" << i << "] is nan"; + num_S_result_errors_atol_1e_3_rtol_1e_3 += + (!utils::isclose(float(S_merged_0_host[i]), float(S_merged_1_host[i]), 1e-3, 1e-3)); + } + float V_result_accuracy = + 1.0 - float(num_V_result_errors_atol_1e_3_rtol_1e_3) / (seq_len * num_heads * head_dim); + float S_result_accuracy = + 1.0 - float(num_S_result_errors_atol_1e_3_rtol_1e_3) / (seq_len * num_heads); + std::cout << "seq_len=" << seq_len << ", num_heads=" << num_heads << ", head_dim=" << head_dim + << ", V accuracy (atol=1e-3, rtol=1e-3)=" << V_result_accuracy + << ", S accuracy (atol=1e-3, rtol=1e-3)=" << S_result_accuracy << std::endl; + + EXPECT_GT(V_result_accuracy, 0.99) << "V result correctness test failed."; + EXPECT_GT(S_result_accuracy, 0.99) << "S result correctness test failed."; + + // Free device memory + hipFree(V_ragged_device); + hipFree(S_ragged_device); + hipFree(indptr_device); + hipFree(V_merged_0_device); + hipFree(V_merged_1_device); + hipFree(S_merged_0_device); + hipFree(S_merged_1_device); + hipFree(seq_len_device); } template -void _TestMergeKernelCorrectness(size_t num_index_sets, - size_t seq_len, - size_t num_heads, - size_t head_dim, - bool sparse_s) -{ - std::vector V_host(seq_len * num_index_sets * num_heads * head_dim); - std::vector V_host_trans_f32(num_index_sets * seq_len * num_heads * - head_dim); - std::vector S_host(seq_len * num_index_sets * num_heads); - std::vector S_host_trans(num_index_sets * seq_len * num_heads); - - utils::vec_normal_(V_host); - if (sparse_s) { - for (uint32_t i = 0; i < num_index_sets; ++i) { - float fill_val = is_prime(i) ? 
10 : -10; - for (uint32_t j = 0; j < seq_len; ++j) { - for (uint32_t k = 0; k < num_heads; ++k) { - S_host[(j * num_index_sets + i) * num_heads + k] = fill_val; - } - } - } - } - else { - utils::vec_uniform_(S_host, -10, 10); - } - +void _TestMergeKernelCorrectness(size_t num_index_sets, size_t seq_len, size_t num_heads, + size_t head_dim, bool sparse_s) { + std::vector V_host(seq_len * num_index_sets * num_heads * head_dim); + std::vector V_host_trans_f32(num_index_sets * seq_len * num_heads * head_dim); + std::vector S_host(seq_len * num_index_sets * num_heads); + std::vector S_host_trans(num_index_sets * seq_len * num_heads); + + utils::vec_normal_(V_host); + if (sparse_s) { for (uint32_t i = 0; i < num_index_sets; ++i) { - for (uint32_t j = 0; j < seq_len; ++j) { - std::transform( - V_host.begin() + - (j * num_index_sets + i) * num_heads * head_dim, - V_host.begin() + - (j * num_index_sets + i + 1) * num_heads * head_dim, - V_host_trans_f32.begin() + - (i * seq_len + j) * num_heads * head_dim, - [](T x) { return fi::con::explicit_casting(x); }); - std::copy(S_host.begin() + (j * num_index_sets + i) * num_heads, - S_host.begin() + (j * num_index_sets + i + 1) * num_heads, - S_host_trans.begin() + (i * seq_len + j) * num_heads); - } - } - - // Allocate device memory using HIP - T *V_device; - float *V_device_trans_f32; - float *S_device; - float *S_device_trans; - float *V_merged_0_device; - float *S_merged_0_device; - T *V_merged_1_device; - float *S_merged_1_device; - - hipMalloc(&V_device, V_host.size() * sizeof(T)); - hipMalloc(&V_device_trans_f32, V_host_trans_f32.size() * sizeof(float)); - hipMalloc(&S_device, S_host.size() * sizeof(float)); - hipMalloc(&S_device_trans, S_host_trans.size() * sizeof(float)); - hipMalloc(&V_merged_0_device, - seq_len * num_heads * head_dim * sizeof(float)); - hipMalloc(&S_merged_0_device, seq_len * num_heads * sizeof(float)); - hipMalloc(&V_merged_1_device, seq_len * num_heads * head_dim * sizeof(T)); - 
hipMalloc(&S_merged_1_device, seq_len * num_heads * sizeof(float)); - - // Copy data from host to device - hipMemcpy(V_device, V_host.data(), V_host.size() * sizeof(T), - hipMemcpyHostToDevice); - hipMemcpy(V_device_trans_f32, V_host_trans_f32.data(), - V_host_trans_f32.size() * sizeof(float), hipMemcpyHostToDevice); - hipMemcpy(S_device, S_host.data(), S_host.size() * sizeof(float), - hipMemcpyHostToDevice); - hipMemcpy(S_device_trans, S_host_trans.data(), - S_host_trans.size() * sizeof(float), hipMemcpyHostToDevice); - - // Initialize merged arrays to zero - hipMemset(V_merged_0_device, 0, - seq_len * num_heads * head_dim * sizeof(float)); - hipMemset(S_merged_0_device, 0, seq_len * num_heads * sizeof(float)); - hipMemset(V_merged_1_device, 0, seq_len * num_heads * head_dim * sizeof(T)); - hipMemset(S_merged_1_device, 0, seq_len * num_heads * sizeof(float)); - - if (num_index_sets > 1) { - // Method 0: use MergeState - MergeState(V_device_trans_f32, S_device_trans, - V_device_trans_f32 + seq_len * num_heads * head_dim, - S_device_trans + seq_len * num_heads, V_merged_0_device, - S_merged_0_device, seq_len, num_heads, head_dim); - for (uint i = 2; i < num_index_sets; ++i) { - MergeStateInPlace(V_merged_0_device, S_merged_0_device, - V_device_trans_f32 + - i * seq_len * num_heads * head_dim, - S_device_trans + i * seq_len * num_heads, seq_len, - num_heads, head_dim); + float fill_val = is_prime(i) ? 
10 : -10; + for (uint32_t j = 0; j < seq_len; ++j) { + for (uint32_t k = 0; k < num_heads; ++k) { + S_host[(j * num_index_sets + i) * num_heads + k] = fill_val; } + } } - else { - hipMemcpy(V_merged_0_device, V_device, - seq_len * num_heads * head_dim * sizeof(T), - hipMemcpyDeviceToDevice); - hipMemcpy(S_merged_0_device, S_device, - seq_len * num_heads * sizeof(float), hipMemcpyDeviceToDevice); - } + } else { + utils::vec_uniform_(S_host, -10, 10); + } - // Method 1: use MergeStates - MergeStates(V_device, S_device, V_merged_1_device, S_merged_1_device, - num_index_sets, seq_len, num_heads, head_dim); - - // Allocate host memory for results - std::vector V_merged_0_host(seq_len * num_heads * head_dim); - std::vector V_merged_1_host(seq_len * num_heads * head_dim); - std::vector S_merged_0_host(seq_len * num_heads); - std::vector S_merged_1_host(seq_len * num_heads); - - // Copy results from device to host - hipMemcpy(V_merged_0_host.data(), V_merged_0_device, - seq_len * num_heads * head_dim * sizeof(float), - hipMemcpyDeviceToHost); - hipMemcpy(V_merged_1_host.data(), V_merged_1_device, - seq_len * num_heads * head_dim * sizeof(T), - hipMemcpyDeviceToHost); - hipMemcpy(S_merged_0_host.data(), S_merged_0_device, - seq_len * num_heads * sizeof(float), hipMemcpyDeviceToHost); - hipMemcpy(S_merged_1_host.data(), S_merged_1_device, - seq_len * num_heads * sizeof(float), hipMemcpyDeviceToHost); - - // Compare results - size_t num_V_result_errors_atol_1e_3_rtol_1e_3 = 0, - num_S_result_errors_atol_1e_3_rtol_1e_3 = 0; - for (size_t i = 0; i < seq_len * num_heads * head_dim; ++i) { - - float V_merged_0_host_value = - V_merged_0_host[i]; // V_merged_0_host is already float - float V_merged_1_host_value = - fi::con::explicit_casting(V_merged_1_host[i]); - - EXPECT_FALSE(std::isnan(V_merged_0_host_value)) - << "V_merged_0_host[" << i << "] is nan"; - EXPECT_FALSE(std::isnan(V_merged_1_host_value)) - << "V_merged_1_host[" << i << "] is nan"; - 
num_V_result_errors_atol_1e_3_rtol_1e_3 += (!utils::isclose( - V_merged_0_host_value, V_merged_1_host_value, 1e-3, 1e-3)); + for (uint32_t i = 0; i < num_index_sets; ++i) { + for (uint32_t j = 0; j < seq_len; ++j) { + std::transform(V_host.begin() + (j * num_index_sets + i) * num_heads * head_dim, + V_host.begin() + (j * num_index_sets + i + 1) * num_heads * head_dim, + V_host_trans_f32.begin() + (i * seq_len + j) * num_heads * head_dim, + [](T x) { return fi::con::explicit_casting(x); }); + std::copy(S_host.begin() + (j * num_index_sets + i) * num_heads, + S_host.begin() + (j * num_index_sets + i + 1) * num_heads, + S_host_trans.begin() + (i * seq_len + j) * num_heads); } - for (size_t i = 0; i < seq_len * num_heads; ++i) { - EXPECT_FALSE(std::isnan(float(S_merged_0_host[i]))) - << "S_merged_0_host[" << i << "] is nan"; - EXPECT_FALSE(std::isnan(float(S_merged_1_host[i]))) - << "S_merged_1_host[" << i << "] is nan"; - num_S_result_errors_atol_1e_3_rtol_1e_3 += (!utils::isclose( - float(S_merged_0_host[i]), float(S_merged_1_host[i]), 1e-3, 1e-3)); + } + + // Allocate device memory using HIP + T* V_device; + float* V_device_trans_f32; + float* S_device; + float* S_device_trans; + float* V_merged_0_device; + float* S_merged_0_device; + T* V_merged_1_device; + float* S_merged_1_device; + + hipMalloc(&V_device, V_host.size() * sizeof(T)); + hipMalloc(&V_device_trans_f32, V_host_trans_f32.size() * sizeof(float)); + hipMalloc(&S_device, S_host.size() * sizeof(float)); + hipMalloc(&S_device_trans, S_host_trans.size() * sizeof(float)); + hipMalloc(&V_merged_0_device, seq_len * num_heads * head_dim * sizeof(float)); + hipMalloc(&S_merged_0_device, seq_len * num_heads * sizeof(float)); + hipMalloc(&V_merged_1_device, seq_len * num_heads * head_dim * sizeof(T)); + hipMalloc(&S_merged_1_device, seq_len * num_heads * sizeof(float)); + + // Copy data from host to device + hipMemcpy(V_device, V_host.data(), V_host.size() * sizeof(T), hipMemcpyHostToDevice); + 
hipMemcpy(V_device_trans_f32, V_host_trans_f32.data(), V_host_trans_f32.size() * sizeof(float), + hipMemcpyHostToDevice); + hipMemcpy(S_device, S_host.data(), S_host.size() * sizeof(float), hipMemcpyHostToDevice); + hipMemcpy(S_device_trans, S_host_trans.data(), S_host_trans.size() * sizeof(float), + hipMemcpyHostToDevice); + + // Initialize merged arrays to zero + hipMemset(V_merged_0_device, 0, seq_len * num_heads * head_dim * sizeof(float)); + hipMemset(S_merged_0_device, 0, seq_len * num_heads * sizeof(float)); + hipMemset(V_merged_1_device, 0, seq_len * num_heads * head_dim * sizeof(T)); + hipMemset(S_merged_1_device, 0, seq_len * num_heads * sizeof(float)); + + if (num_index_sets > 1) { + // Method 0: use MergeState + MergeState(V_device_trans_f32, S_device_trans, + V_device_trans_f32 + seq_len * num_heads * head_dim, + S_device_trans + seq_len * num_heads, V_merged_0_device, S_merged_0_device, seq_len, + num_heads, head_dim); + for (uint i = 2; i < num_index_sets; ++i) { + MergeStateInPlace(V_merged_0_device, S_merged_0_device, + V_device_trans_f32 + i * seq_len * num_heads * head_dim, + S_device_trans + i * seq_len * num_heads, seq_len, num_heads, head_dim); } - float V_result_accuracy = - 1.0 - float(num_V_result_errors_atol_1e_3_rtol_1e_3) / - (seq_len * num_heads * head_dim); - float S_result_accuracy = - 1.0 - - float(num_S_result_errors_atol_1e_3_rtol_1e_3) / (seq_len * num_heads); - std::cout << "num_index_sets=" << num_index_sets << ", seq_len=" << seq_len - << ", num_heads=" << num_heads << ", head_dim=" << head_dim - << ", sparse_s=" << sparse_s - << ", V accuracy (atol=1e-3, rtol=1e-3)=" << V_result_accuracy - << ", S accuracy (atol=1e-3, rtol=1e-3)=" << S_result_accuracy - << std::endl; - EXPECT_GT(V_result_accuracy, 0.99) << "V result correctness test failed."; - EXPECT_GT(S_result_accuracy, 0.99) << "S result correctness test failed."; - - // Free device memory - hipFree(V_device); - hipFree(V_device_trans_f32); - hipFree(S_device); - 
hipFree(S_device_trans); - hipFree(V_merged_0_device); - hipFree(S_merged_0_device); - hipFree(V_merged_1_device); - hipFree(S_merged_1_device); + } else { + hipMemcpy(V_merged_0_device, V_device, seq_len * num_heads * head_dim * sizeof(T), + hipMemcpyDeviceToDevice); + hipMemcpy(S_merged_0_device, S_device, seq_len * num_heads * sizeof(float), + hipMemcpyDeviceToDevice); + } + + // Method 1: use MergeStates + MergeStates(V_device, S_device, V_merged_1_device, S_merged_1_device, num_index_sets, seq_len, + num_heads, head_dim); + + // Allocate host memory for results + std::vector V_merged_0_host(seq_len * num_heads * head_dim); + std::vector V_merged_1_host(seq_len * num_heads * head_dim); + std::vector S_merged_0_host(seq_len * num_heads); + std::vector S_merged_1_host(seq_len * num_heads); + + // Copy results from device to host + hipMemcpy(V_merged_0_host.data(), V_merged_0_device, + seq_len * num_heads * head_dim * sizeof(float), hipMemcpyDeviceToHost); + hipMemcpy(V_merged_1_host.data(), V_merged_1_device, seq_len * num_heads * head_dim * sizeof(T), + hipMemcpyDeviceToHost); + hipMemcpy(S_merged_0_host.data(), S_merged_0_device, seq_len * num_heads * sizeof(float), + hipMemcpyDeviceToHost); + hipMemcpy(S_merged_1_host.data(), S_merged_1_device, seq_len * num_heads * sizeof(float), + hipMemcpyDeviceToHost); + + // Compare results + size_t num_V_result_errors_atol_1e_3_rtol_1e_3 = 0, num_S_result_errors_atol_1e_3_rtol_1e_3 = 0; + for (size_t i = 0; i < seq_len * num_heads * head_dim; ++i) { + float V_merged_0_host_value = V_merged_0_host[i]; // V_merged_0_host is already float + float V_merged_1_host_value = fi::con::explicit_casting(V_merged_1_host[i]); + + EXPECT_FALSE(std::isnan(V_merged_0_host_value)) << "V_merged_0_host[" << i << "] is nan"; + EXPECT_FALSE(std::isnan(V_merged_1_host_value)) << "V_merged_1_host[" << i << "] is nan"; + num_V_result_errors_atol_1e_3_rtol_1e_3 += + (!utils::isclose(V_merged_0_host_value, V_merged_1_host_value, 1e-3, 1e-3)); + } 
+ for (size_t i = 0; i < seq_len * num_heads; ++i) { + EXPECT_FALSE(std::isnan(float(S_merged_0_host[i]))) << "S_merged_0_host[" << i << "] is nan"; + EXPECT_FALSE(std::isnan(float(S_merged_1_host[i]))) << "S_merged_1_host[" << i << "] is nan"; + num_S_result_errors_atol_1e_3_rtol_1e_3 += + (!utils::isclose(float(S_merged_0_host[i]), float(S_merged_1_host[i]), 1e-3, 1e-3)); + } + float V_result_accuracy = + 1.0 - float(num_V_result_errors_atol_1e_3_rtol_1e_3) / (seq_len * num_heads * head_dim); + float S_result_accuracy = + 1.0 - float(num_S_result_errors_atol_1e_3_rtol_1e_3) / (seq_len * num_heads); + std::cout << "num_index_sets=" << num_index_sets << ", seq_len=" << seq_len + << ", num_heads=" << num_heads << ", head_dim=" << head_dim << ", sparse_s=" << sparse_s + << ", V accuracy (atol=1e-3, rtol=1e-3)=" << V_result_accuracy + << ", S accuracy (atol=1e-3, rtol=1e-3)=" << S_result_accuracy << std::endl; + EXPECT_GT(V_result_accuracy, 0.99) << "V result correctness test failed."; + EXPECT_GT(S_result_accuracy, 0.99) << "S result correctness test failed."; + + // Free device memory + hipFree(V_device); + hipFree(V_device_trans_f32); + hipFree(S_device); + hipFree(S_device_trans); + hipFree(V_merged_0_device); + hipFree(S_merged_0_device); + hipFree(V_merged_1_device); + hipFree(S_merged_1_device); } -template void TestMergeKernelCorrectness() -{ - for (size_t num_index_sets : {2, 9, 81, 513}) { - for (size_t seq_len : {4, 16, 77}) { - for (size_t num_heads : {1, 21, 32}) { - for (size_t head_dim : {64, 128, 256}) { - for (bool sparse_s : {false, true}) { - _TestMergeKernelCorrectness(num_index_sets, seq_len, - num_heads, head_dim, - sparse_s); - } - } - } +template +void TestMergeKernelCorrectness() { + for (size_t num_index_sets : {2, 9, 81, 513}) { + for (size_t seq_len : {4, 16, 77}) { + for (size_t num_heads : {1, 21, 32}) { + for (size_t head_dim : {64, 128, 256}) { + for (bool sparse_s : {false, true}) { + _TestMergeKernelCorrectness(num_index_sets, 
seq_len, num_heads, head_dim, sparse_s); + } } + } } + } } -template void TestVariableLengthMergeKernelCorrectness() -{ - for (size_t seq_len : {1, 3, 77, 191}) { - for (size_t num_heads : {1, 4, 32}) { - for (size_t head_dim : {64, 128, 256}) { - for (bool sparse_s : {false, true}) { - _TestVariableLengthMergeKernelCorrectness( - seq_len, num_heads, head_dim, sparse_s); - } - } +template +void TestVariableLengthMergeKernelCorrectness() { + for (size_t seq_len : {1, 3, 77, 191}) { + for (size_t num_heads : {1, 4, 32}) { + for (size_t head_dim : {64, 128, 256}) { + for (bool sparse_s : {false, true}) { + _TestVariableLengthMergeKernelCorrectness(seq_len, num_heads, head_dim, sparse_s); } + } } + } } -template void TestVariableLengthMergeKernelPaddedCorrectness() -{ - _TestVariableLengthMergeKernelPaddedCorrectness(8, 1); - _TestVariableLengthMergeKernelPaddedCorrectness(128, 77); +template +void TestVariableLengthMergeKernelPaddedCorrectness() { + _TestVariableLengthMergeKernelPaddedCorrectness(8, 1); + _TestVariableLengthMergeKernelPaddedCorrectness(128, 77); } -TEST(FlashInferCorrectnessTest, MergeKernelCorrectnessTestFP16) -{ - TestMergeKernelCorrectness<__half>(); +TEST(FlashInferCorrectnessTest, MergeKernelCorrectnessTestFP16) { + TestMergeKernelCorrectness<__half>(); } -TEST(FlashInferCorrectnessTest, - VariableLengthMergeKernelPaddedCorrectnessTestFP16) -{ - TestVariableLengthMergeKernelPaddedCorrectness<__half>(); +TEST(FlashInferCorrectnessTest, VariableLengthMergeKernelPaddedCorrectnessTestFP16) { + TestVariableLengthMergeKernelPaddedCorrectness<__half>(); } -TEST(FlashInferCorrectnessTest, VariableLengthMergeKernelCorrectnessTestFP16) -{ - TestVariableLengthMergeKernelCorrectness<__half>(); +TEST(FlashInferCorrectnessTest, VariableLengthMergeKernelCorrectnessTestFP16) { + TestVariableLengthMergeKernelCorrectness<__half>(); } -int main(int argc, char **argv) -{ - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); +int main(int argc, char** 
argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); } diff --git a/libflashinfer/tests/hip/test_compute_sfm.cpp b/libflashinfer/tests/hip/test_compute_sfm.cpp index e0cda09e50..09ac6fddec 100644 --- a/libflashinfer/tests/hip/test_compute_sfm.cpp +++ b/libflashinfer/tests/hip/test_compute_sfm.cpp @@ -3,232 +3,187 @@ // // SPDX - License - Identifier : Apache 2.0 +#include + +#include + #include "../../utils/flashinfer_prefill_ops.hip.h" #include "../../utils/utils_hip.h" #include "flashinfer/attention/generic/prefill.cuh" #include "gpu_iface/gpu_runtime_compat.hpp" -#include - -#include - #define HIP_ENABLE_WARP_SYNC_BUILTINS 1 using namespace flashinfer; -namespace -{ +namespace { template std::vector test_compute_qk_and_softmax_cpu( - const std::vector &q, - const std::vector &k, - const std::vector &v, - size_t qo_len, - size_t kv_len, - size_t num_qo_heads, - size_t num_kv_heads, - size_t head_dim, - bool causal = true, - QKVLayout kv_layout = QKVLayout::kHND, - PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, - float rope_scale = 1.f, - float rope_theta = 1e4) -{ - assert(qo_len <= kv_len); - assert(num_qo_heads % num_kv_heads == 0); - float sm_scale = 1.f / std::sqrt(float(head_dim)); - std::vector o(qo_len * num_qo_heads * head_dim); - std::vector att(kv_len); - std::vector q_rotary_local(head_dim); - std::vector k_rotary_local(head_dim); - DISPATCH_head_dim(head_dim, HEAD_DIM, { - tensor_info_t info(qo_len, kv_len, num_qo_heads, num_kv_heads, - kv_layout, HEAD_DIM); - - for (size_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) - { - const size_t kv_head_idx = qo_head_idx / info.get_group_size(); - for (size_t q_idx = 0; q_idx < qo_len; ++q_idx) { - float max_val = -5e4; - - for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { - att[kv_idx] = 0.; - switch (pos_encoding_mode) { - case PosEncodingMode::kNone: - { - for (size_t feat_idx = 0; feat_idx < head_dim; - ++feat_idx) - { - att[kv_idx] += - 
fi::con::explicit_casting( - q[info.get_q_elem_offset(q_idx, qo_head_idx, - feat_idx)]) * - fi::con::explicit_casting( - k[info.get_kv_elem_offset( - kv_idx, kv_head_idx, feat_idx)]) * - sm_scale; - } - break; - } - default: - { - std::ostringstream err_msg; - err_msg << "Unsupported rotary mode."; - FLASHINFER_ERROR(err_msg.str()); - } - } - max_val = std::max(max_val, att[kv_idx]); - } - // exp minus max - float denom = 0; - for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { - att[kv_idx] = std::exp(att[kv_idx] - max_val); - denom += att[kv_idx]; - } - - // divide by denom - for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { - att[kv_idx] /= denom; - } + const std::vector& q, const std::vector& k, const std::vector& v, + size_t qo_len, size_t kv_len, size_t num_qo_heads, size_t num_kv_heads, size_t head_dim, + bool causal = true, QKVLayout kv_layout = QKVLayout::kHND, + PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, float rope_scale = 1.f, + float rope_theta = 1e4) { + assert(qo_len <= kv_len); + assert(num_qo_heads % num_kv_heads == 0); + float sm_scale = 1.f / std::sqrt(float(head_dim)); + std::vector o(qo_len * num_qo_heads * head_dim); + std::vector att(kv_len); + std::vector q_rotary_local(head_dim); + std::vector k_rotary_local(head_dim); + DISPATCH_head_dim(head_dim, HEAD_DIM, { + tensor_info_t info(qo_len, kv_len, num_qo_heads, num_kv_heads, kv_layout, HEAD_DIM); + + for (size_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) { + const size_t kv_head_idx = qo_head_idx / info.get_group_size(); + for (size_t q_idx = 0; q_idx < qo_len; ++q_idx) { + float max_val = -5e4; + + for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { + att[kv_idx] = 0.; + switch (pos_encoding_mode) { + case PosEncodingMode::kNone: { + for (size_t feat_idx = 0; feat_idx < head_dim; ++feat_idx) { + att[kv_idx] += fi::con::explicit_casting( + q[info.get_q_elem_offset(q_idx, qo_head_idx, feat_idx)]) * + fi::con::explicit_casting( + 
k[info.get_kv_elem_offset(kv_idx, kv_head_idx, feat_idx)]) * + sm_scale; + } + break; + } + default: { + std::ostringstream err_msg; + err_msg << "Unsupported rotary mode."; + FLASHINFER_ERROR(err_msg.str()); } + } + max_val = std::max(max_val, att[kv_idx]); } - }); - return std::move(att); -} -} // namespace - -template -void _TestComputeSFMCorrectness(size_t qo_len, - size_t kv_len, - size_t num_qo_heads, - size_t num_kv_heads, - size_t head_dim, - bool causal, - QKVLayout kv_layout, - PosEncodingMode pos_encoding_mode, - bool use_fp16_qk_reduction, - float rtol = 1e-3, - float atol = 1e-3) -{ - std::vector q(qo_len * num_qo_heads * head_dim); - std::vector k(kv_len * num_kv_heads * head_dim); - std::vector v(kv_len * num_kv_heads * head_dim); - std::vector o(qo_len * num_qo_heads * head_dim); - - utils::generate_data(q); - utils::generate_data(k); - utils::generate_data(v); - utils::generate_data(o); - - DTypeQ *q_d; - FI_GPU_CALL(hipMalloc(&q_d, q.size() * sizeof(DTypeQ))); - FI_GPU_CALL(hipMemcpy(q_d, q.data(), q.size() * sizeof(DTypeQ), - hipMemcpyHostToDevice)); - - DTypeKV *k_d; - FI_GPU_CALL(hipMalloc(&k_d, k.size() * sizeof(DTypeKV))); - FI_GPU_CALL(hipMemcpy(k_d, k.data(), k.size() * sizeof(DTypeKV), - hipMemcpyHostToDevice)); - - DTypeKV *v_d; - FI_GPU_CALL(hipMalloc(&v_d, v.size() * sizeof(DTypeKV))); - FI_GPU_CALL(hipMemcpy(v_d, v.data(), v.size() * sizeof(DTypeKV), - hipMemcpyHostToDevice)); - - DTypeO *o_d; - FI_GPU_CALL(hipMalloc(&o_d, o.size() * sizeof(DTypeO))); - FI_GPU_CALL(hipMemcpy(o_d, o.data(), o.size() * sizeof(DTypeO), - hipMemcpyHostToDevice)); - - DTypeO *tmp_d; - FI_GPU_CALL(hipMalloc(&tmp_d, 16 * 1024 * 1024 * sizeof(DTypeO))); - - hipError_t status = - flashinfer::SinglePrefillWithKVCache( - q_d, k_d, v_d, o_d, tmp_d, - /*lse=*/nullptr, num_qo_heads, num_kv_heads, qo_len, kv_len, - head_dim, causal, kv_layout, pos_encoding_mode, - use_fp16_qk_reduction); - - EXPECT_EQ(status, hipSuccess) - << "SinglePrefillWithKVCache kernel launch 
failed, error message: " - << hipGetErrorString(status); - - std::vector o_h(o.size()); - FI_GPU_CALL(hipMemcpy(o_h.data(), o_d, o_h.size() * sizeof(DTypeO), - hipMemcpyDeviceToHost)); - - // Print the first 10 elements of the output vector for debugging - // std::cout << "Output vector (first 10 elements):"; - // std::cout << "[" << std::endl; - // for (int i = 0; i < 10; ++i) { - // std::cout << fi::con::explicit_casting(o_h[i]) << " "; - // } - // std::cout << "]" << std::endl; - - bool isEmpty = o_h.empty(); - EXPECT_EQ(isEmpty, false) << "Output vector is empty"; - - std::vector o_ref = - test_compute_qk_and_softmax_cpu( - q, k, v, qo_len, kv_len, num_qo_heads, num_kv_heads, head_dim, - causal, kv_layout, pos_encoding_mode); - size_t num_results_error_atol = 0; - bool nan_detected = false; - - for (size_t i = 0; i < o_ref.size(); ++i) { - float o_h_val = fi::con::explicit_casting(o_h[i]); - float o_ref_val = fi::con::explicit_casting(o_ref[i]); - - if (isnan(o_h_val)) { - nan_detected = true; + // exp minus max + float denom = 0; + for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { + att[kv_idx] = std::exp(att[kv_idx] - max_val); + denom += att[kv_idx]; } - num_results_error_atol += - (!utils::isclose(o_ref_val, o_h_val, rtol, atol)); - if (!utils::isclose(o_ref_val, o_h_val, rtol, atol)) { - std::cout << "i=" << i << ", o_ref[i]=" << o_ref_val - << ", o_h[i]=" << o_h_val << std::endl; + // divide by denom + for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { + att[kv_idx] /= denom; } + } + } + }); + return std::move(att); +} +} // namespace + +template +void _TestComputeSFMCorrectness(size_t qo_len, size_t kv_len, size_t num_qo_heads, + size_t num_kv_heads, size_t head_dim, bool causal, + QKVLayout kv_layout, PosEncodingMode pos_encoding_mode, + bool use_fp16_qk_reduction, float rtol = 1e-3, float atol = 1e-3) { + std::vector q(qo_len * num_qo_heads * head_dim); + std::vector k(kv_len * num_kv_heads * head_dim); + std::vector v(kv_len * num_kv_heads * 
head_dim); + std::vector o(qo_len * num_qo_heads * head_dim); + + utils::generate_data(q); + utils::generate_data(k); + utils::generate_data(v); + utils::generate_data(o); + + DTypeQ* q_d; + FI_GPU_CALL(hipMalloc(&q_d, q.size() * sizeof(DTypeQ))); + FI_GPU_CALL(hipMemcpy(q_d, q.data(), q.size() * sizeof(DTypeQ), hipMemcpyHostToDevice)); + + DTypeKV* k_d; + FI_GPU_CALL(hipMalloc(&k_d, k.size() * sizeof(DTypeKV))); + FI_GPU_CALL(hipMemcpy(k_d, k.data(), k.size() * sizeof(DTypeKV), hipMemcpyHostToDevice)); + + DTypeKV* v_d; + FI_GPU_CALL(hipMalloc(&v_d, v.size() * sizeof(DTypeKV))); + FI_GPU_CALL(hipMemcpy(v_d, v.data(), v.size() * sizeof(DTypeKV), hipMemcpyHostToDevice)); + + DTypeO* o_d; + FI_GPU_CALL(hipMalloc(&o_d, o.size() * sizeof(DTypeO))); + FI_GPU_CALL(hipMemcpy(o_d, o.data(), o.size() * sizeof(DTypeO), hipMemcpyHostToDevice)); + + DTypeO* tmp_d; + FI_GPU_CALL(hipMalloc(&tmp_d, 16 * 1024 * 1024 * sizeof(DTypeO))); + + hipError_t status = flashinfer::SinglePrefillWithKVCache( + q_d, k_d, v_d, o_d, tmp_d, + /*lse=*/nullptr, num_qo_heads, num_kv_heads, qo_len, kv_len, head_dim, causal, kv_layout, + pos_encoding_mode, use_fp16_qk_reduction); + + EXPECT_EQ(status, hipSuccess) << "SinglePrefillWithKVCache kernel launch failed, error message: " + << hipGetErrorString(status); + + std::vector o_h(o.size()); + FI_GPU_CALL(hipMemcpy(o_h.data(), o_d, o_h.size() * sizeof(DTypeO), hipMemcpyDeviceToHost)); + + // Print the first 10 elements of the output vector for debugging + // std::cout << "Output vector (first 10 elements):"; + // std::cout << "[" << std::endl; + // for (int i = 0; i < 10; ++i) { + // std::cout << fi::con::explicit_casting(o_h[i]) << " "; + // } + // std::cout << "]" << std::endl; + + bool isEmpty = o_h.empty(); + EXPECT_EQ(isEmpty, false) << "Output vector is empty"; + + std::vector o_ref = test_compute_qk_and_softmax_cpu( + q, k, v, qo_len, kv_len, num_qo_heads, num_kv_heads, head_dim, causal, kv_layout, + pos_encoding_mode); + size_t 
num_results_error_atol = 0; + bool nan_detected = false; + + for (size_t i = 0; i < o_ref.size(); ++i) { + float o_h_val = fi::con::explicit_casting(o_h[i]); + float o_ref_val = fi::con::explicit_casting(o_ref[i]); + + if (isnan(o_h_val)) { + nan_detected = true; } - float result_accuracy = - 1. - float(num_results_error_atol) / float(o_ref.size()); - std::cout << "num_qo_heads=" << num_qo_heads - << ", num_kv_heads=" << num_kv_heads << ", qo_len=" << qo_len - << ", kv_len=" << kv_len << ", head_dim=" << head_dim - << ", causal=" << causal - << ", kv_layout=" << QKVLayoutToString(kv_layout) - << ", pos_encoding_mode=" - << PosEncodingModeToString(pos_encoding_mode) - << ", result_accuracy=" << result_accuracy << std::endl; - - EXPECT_GT(result_accuracy, 0.90) << "Result correctness test failed."; - EXPECT_FALSE(nan_detected) << "Nan detected in the result."; - - FI_GPU_CALL(hipFree(q_d)); - FI_GPU_CALL(hipFree(k_d)); - FI_GPU_CALL(hipFree(v_d)); - FI_GPU_CALL(hipFree(o_d)); - FI_GPU_CALL(hipFree(tmp_d)); + num_results_error_atol += (!utils::isclose(o_ref_val, o_h_val, rtol, atol)); + if (!utils::isclose(o_ref_val, o_h_val, rtol, atol)) { + std::cout << "i=" << i << ", o_ref[i]=" << o_ref_val << ", o_h[i]=" << o_h_val << std::endl; + } + } + + float result_accuracy = 1. 
- float(num_results_error_atol) / float(o_ref.size()); + std::cout << "num_qo_heads=" << num_qo_heads << ", num_kv_heads=" << num_kv_heads + << ", qo_len=" << qo_len << ", kv_len=" << kv_len << ", head_dim=" << head_dim + << ", causal=" << causal << ", kv_layout=" << QKVLayoutToString(kv_layout) + << ", pos_encoding_mode=" << PosEncodingModeToString(pos_encoding_mode) + << ", result_accuracy=" << result_accuracy << std::endl; + + EXPECT_GT(result_accuracy, 0.90) << "Result correctness test failed."; + EXPECT_FALSE(nan_detected) << "Nan detected in the result."; + + FI_GPU_CALL(hipFree(q_d)); + FI_GPU_CALL(hipFree(k_d)); + FI_GPU_CALL(hipFree(v_d)); + FI_GPU_CALL(hipFree(o_d)); + FI_GPU_CALL(hipFree(tmp_d)); } -int main(int argc, char **argv) -{ - - using DTypeIn = __half; - using DTypeO = __half; - bool use_fp16_qk_reduction = false; - size_t qo_len = 399; - size_t kv_len = 533; - size_t num_heads = 1; - size_t head_dim = 64; - bool causal = false; - size_t pos_encoding_mode = 0; - size_t kv_layout = 0; - - _TestComputeSFMCorrectness( - qo_len, kv_len, num_heads, num_heads, head_dim, causal, - QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode), - use_fp16_qk_reduction); +int main(int argc, char** argv) { + using DTypeIn = __half; + using DTypeO = __half; + bool use_fp16_qk_reduction = false; + size_t qo_len = 399; + size_t kv_len = 533; + size_t num_heads = 1; + size_t head_dim = 64; + bool causal = false; + size_t pos_encoding_mode = 0; + size_t kv_layout = 0; + + _TestComputeSFMCorrectness( + qo_len, kv_len, num_heads, num_heads, head_dim, causal, QKVLayout(kv_layout), + PosEncodingMode(pos_encoding_mode), use_fp16_qk_reduction); } diff --git a/libflashinfer/tests/hip/test_k_smem_read_pattern.cpp b/libflashinfer/tests/hip/test_k_smem_read_pattern.cpp index 1676c88285..748f042081 100644 --- a/libflashinfer/tests/hip/test_k_smem_read_pattern.cpp +++ b/libflashinfer/tests/hip/test_k_smem_read_pattern.cpp @@ -1,189 +1,167 @@ +#include + #include #include 
-#include #include #include // Constants for MI300 -constexpr uint32_t WARP_SIZE = 64; // 64 threads per wavefront -constexpr uint32_t HALF_ELEMS_PER_THREAD = - 4; // Each thread processes 4 half elements -constexpr uint32_t INT32_ELEMS_PER_THREAD = 2; // 2 int32 registers per thread +constexpr uint32_t WARP_SIZE = 64; // 64 threads per wavefront +constexpr uint32_t HALF_ELEMS_PER_THREAD = 4; // Each thread processes 4 half elements +constexpr uint32_t INT32_ELEMS_PER_THREAD = 2; // 2 int32 registers per thread // Simplified linear shared memory operations (CPU implementation) template -uint32_t get_permuted_offset_linear(uint32_t row, uint32_t col) -{ - return row * stride + col; +uint32_t get_permuted_offset_linear(uint32_t row, uint32_t col) { + return row * stride + col; } template -uint32_t advance_offset_by_column_linear(uint32_t offset, uint32_t step_idx) -{ - return offset + step_size; +uint32_t advance_offset_by_column_linear(uint32_t offset, uint32_t step_idx) { + return offset + step_size; } template -uint32_t advance_offset_by_row_linear(uint32_t offset) -{ - return offset + step_size * row_stride; +uint32_t advance_offset_by_row_linear(uint32_t offset) { + return offset + step_size * row_stride; } // CPU-based simulation of k-matrix access pattern in compute_qk template -void SimulateKReadPattern(std::vector &thread_ids_reading_offsets) -{ - // Constants derived from HEAD_DIM - constexpr uint32_t UPCAST_STRIDE_K = HEAD_DIM / HALF_ELEMS_PER_THREAD; - constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM / 16; - constexpr uint32_t grid_width = HEAD_DIM / HALF_ELEMS_PER_THREAD; - constexpr uint32_t grid_height = 16 * NUM_MMA_KV; - - constexpr uint32_t K_SMEM_COLUMN_ADVANCE = - 16 / HALF_ELEMS_PER_THREAD; // = 4 for MI300 - - // Initialize with -1 (unread) - thread_ids_reading_offsets.assign(grid_height * grid_width, -1); - - // Simulate each thread's read pattern - for (uint32_t tid = 0; tid < WARP_SIZE; tid++) { - // Map tid to kernel's lane_idx - uint32_t lane_idx 
= tid; - uint32_t warp_idx_kv = 0; // For simplicity, assuming one warp group - - // Exactly match the kernel's initial offset calculation - MI300 version - uint32_t k_smem_offset_r = get_permuted_offset_linear( - warp_idx_kv * NUM_MMA_KV * 16 + 4 * (lane_idx / 16) + lane_idx % 4, - (lane_idx % 16) / 4); - - // uint32_t k_smem_offset_r = - // get_permuted_offset_linear( - // warp_idx_kv * NUM_MMA_KV * 16 + - // 4 * (lane_idx / 16), - // (lane_idx % 16)); - - // Follow the same loop structure as in compute_qk - for (uint32_t mma_d = 0; mma_d < NUM_MMA_D_QK; ++mma_d) { - for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) { - // Mark grid positions accessed by ldmatrix_m8n8x4 / - // load_fragment - uint32_t read_row = k_smem_offset_r / UPCAST_STRIDE_K; - uint32_t read_col = k_smem_offset_r % UPCAST_STRIDE_K; - - if (tid == 0) { - std::cout << "Thread " << tid << " k_smem_offset_r " - << k_smem_offset_r << '\n'; - } - - // Simulate loading a matrix fragment - for (uint32_t reg_id = 0; reg_id < INT32_ELEMS_PER_THREAD; - reg_id++) - { - if (read_row < grid_height && read_col < grid_width) { - thread_ids_reading_offsets[read_row * grid_width + - read_col] = tid; - } - - // Each INT32_ELEMS_PER_THREAD register holds 2 half - // elements For simplicity, we're just recording the base - // offset - } - - // Advance to next row, exactly as in compute_qk - k_smem_offset_r = - advance_offset_by_row_linear<16, UPCAST_STRIDE_K>( - k_smem_offset_r); - } - - // Reset row position and advance to next column section, exactly as - // in compute_qk For MI300, advance by 4 columns (vs 2 for NVIDIA) - k_smem_offset_r = - advance_offset_by_column_linear( - k_smem_offset_r, mma_d) - - NUM_MMA_KV * 16 * UPCAST_STRIDE_K; +void SimulateKReadPattern(std::vector& thread_ids_reading_offsets) { + // Constants derived from HEAD_DIM + constexpr uint32_t UPCAST_STRIDE_K = HEAD_DIM / HALF_ELEMS_PER_THREAD; + constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM / 16; + constexpr uint32_t grid_width = 
HEAD_DIM / HALF_ELEMS_PER_THREAD; + constexpr uint32_t grid_height = 16 * NUM_MMA_KV; + + constexpr uint32_t K_SMEM_COLUMN_ADVANCE = 16 / HALF_ELEMS_PER_THREAD; // = 4 for MI300 + + // Initialize with -1 (unread) + thread_ids_reading_offsets.assign(grid_height * grid_width, -1); + + // Simulate each thread's read pattern + for (uint32_t tid = 0; tid < WARP_SIZE; tid++) { + // Map tid to kernel's lane_idx + uint32_t lane_idx = tid; + uint32_t warp_idx_kv = 0; // For simplicity, assuming one warp group + + // Exactly match the kernel's initial offset calculation - MI300 version + uint32_t k_smem_offset_r = get_permuted_offset_linear( + warp_idx_kv * NUM_MMA_KV * 16 + 4 * (lane_idx / 16) + lane_idx % 4, (lane_idx % 16) / 4); + + // uint32_t k_smem_offset_r = + // get_permuted_offset_linear( + // warp_idx_kv * NUM_MMA_KV * 16 + + // 4 * (lane_idx / 16), + // (lane_idx % 16)); + + // Follow the same loop structure as in compute_qk + for (uint32_t mma_d = 0; mma_d < NUM_MMA_D_QK; ++mma_d) { + for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) { + // Mark grid positions accessed by ldmatrix_m8n8x4 / + // load_fragment + uint32_t read_row = k_smem_offset_r / UPCAST_STRIDE_K; + uint32_t read_col = k_smem_offset_r % UPCAST_STRIDE_K; + + if (tid == 0) { + std::cout << "Thread " << tid << " k_smem_offset_r " << k_smem_offset_r << '\n'; } - } -} -// Helper function to run the test with configurable parameters -template void RunKReadPatternTest() -{ - constexpr uint32_t grid_width = HEAD_DIM / HALF_ELEMS_PER_THREAD; - constexpr uint32_t grid_height = 16 * NUM_MMA_KV; - - printf("\n=== Testing key read pattern with HEAD_DIM = %u, NUM_MMA_KV = %u " - "===\n", - HEAD_DIM, NUM_MMA_KV); - - // Host array to store thread IDs at each offset - std::vector thread_ids(grid_height * grid_width, -1); + // Simulate loading a matrix fragment + for (uint32_t reg_id = 0; reg_id < INT32_ELEMS_PER_THREAD; reg_id++) { + if (read_row < grid_height && read_col < grid_width) { + 
thread_ids_reading_offsets[read_row * grid_width + read_col] = tid; + } - // Run CPU simulation of read pattern - SimulateKReadPattern(thread_ids); + // Each INT32_ELEMS_PER_THREAD register holds 2 half + // elements For simplicity, we're just recording the base + // offset + } - // Print the grid of thread IDs - printf("Thread IDs reading from each offset (%dx%d grid):\n", grid_height, - grid_width); + // Advance to next row, exactly as in compute_qk + k_smem_offset_r = advance_offset_by_row_linear<16, UPCAST_STRIDE_K>(k_smem_offset_r); + } - // Column headers - printf(" "); - for (int c = 0; c < grid_width; c++) { - printf("%3d ", c); - if (c == 15 && grid_width > 16) - printf("| "); // Divider for HEAD_DIM=128 + // Reset row position and advance to next column section, exactly as + // in compute_qk For MI300, advance by 4 columns (vs 2 for NVIDIA) + k_smem_offset_r = + advance_offset_by_column_linear(k_smem_offset_r, mma_d) - + NUM_MMA_KV * 16 * UPCAST_STRIDE_K; } - printf("\n +"); + } +} + +// Helper function to run the test with configurable parameters +template +void RunKReadPatternTest() { + constexpr uint32_t grid_width = HEAD_DIM / HALF_ELEMS_PER_THREAD; + constexpr uint32_t grid_height = 16 * NUM_MMA_KV; + + printf( + "\n=== Testing key read pattern with HEAD_DIM = %u, NUM_MMA_KV = %u " + "===\n", + HEAD_DIM, NUM_MMA_KV); + + // Host array to store thread IDs at each offset + std::vector thread_ids(grid_height * grid_width, -1); + + // Run CPU simulation of read pattern + SimulateKReadPattern(thread_ids); + + // Print the grid of thread IDs + printf("Thread IDs reading from each offset (%dx%d grid):\n", grid_height, grid_width); + + // Column headers + printf(" "); + for (int c = 0; c < grid_width; c++) { + printf("%3d ", c); + if (c == 15 && grid_width > 16) printf("| "); // Divider for HEAD_DIM=128 + } + printf("\n +"); + for (int c = 0; c < grid_width; c++) { + printf("----"); + if (c == 15 && grid_width > 16) printf("+"); + } + printf("\n"); + + // 
Print the grid + for (int r = 0; r < grid_height; r++) { + printf("%2d | ", r); for (int c = 0; c < grid_width; c++) { - printf("----"); - if (c == 15 && grid_width > 16) - printf("+"); + int thread_id = thread_ids[r * grid_width + c]; + if (thread_id >= 0) { + printf("%3d ", thread_id); + } else { + printf(" . "); // Dot for unread positions + } + if (c == 15 && grid_width > 16) printf("| "); // Divider for HEAD_DIM=128 } printf("\n"); + } - // Print the grid - for (int r = 0; r < grid_height; r++) { - printf("%2d | ", r); - for (int c = 0; c < grid_width; c++) { - int thread_id = thread_ids[r * grid_width + c]; - if (thread_id >= 0) { - printf("%3d ", thread_id); - } - else { - printf(" . "); // Dot for unread positions - } - if (c == 15 && grid_width > 16) - printf("| "); // Divider for HEAD_DIM=128 - } - printf("\n"); - } - - // Check for unread positions - int unread = 0; - for (int i = 0; i < grid_height * grid_width; i++) { - if (thread_ids[i] == -1) { - unread++; - } + // Check for unread positions + int unread = 0; + for (int i = 0; i < grid_height * grid_width; i++) { + if (thread_ids[i] == -1) { + unread++; } - - // Print statistics - printf("\nStatistics:\n"); - printf("- Positions read: %d/%d (%.1f%%)\n", - grid_height * grid_width - unread, grid_height * grid_width, - 100.0f * (grid_height * grid_width - unread) / - (grid_height * grid_width)); - printf("- Unread positions: %d/%d (%.1f%%)\n", unread, - grid_height * grid_width, - 100.0f * unread / (grid_height * grid_width)); - - // Validate full coverage - EXPECT_EQ(unread, 0) << "Not all positions were read"; + } + + // Print statistics + printf("\nStatistics:\n"); + printf("- Positions read: %d/%d (%.1f%%)\n", grid_height * grid_width - unread, + grid_height * grid_width, + 100.0f * (grid_height * grid_width - unread) / (grid_height * grid_width)); + printf("- Unread positions: %d/%d (%.1f%%)\n", unread, grid_height * grid_width, + 100.0f * unread / (grid_height * grid_width)); + + // Validate full 
coverage + EXPECT_EQ(unread, 0) << "Not all positions were read"; } // Tests for different configurations -TEST(MI300KReadPatternTest, HeadDim64_NumMmaKV1) -{ - RunKReadPatternTest<64, 1>(); -} +TEST(MI300KReadPatternTest, HeadDim64_NumMmaKV1) { RunKReadPatternTest<64, 1>(); } // TEST(MI300KReadPatternTest, HeadDim128_NumMmaKV1) { // RunKReadPatternTest<128, 1>(); @@ -197,8 +175,7 @@ TEST(MI300KReadPatternTest, HeadDim64_NumMmaKV1) // RunKReadPatternTest<128, 2>(); // } -int main(int argc, char **argv) -{ - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); } diff --git a/libflashinfer/tests/hip/test_load_q_global_smem.cpp b/libflashinfer/tests/hip/test_load_q_global_smem.cpp index bedaf3afb7..d4e381aa62 100644 --- a/libflashinfer/tests/hip/test_load_q_global_smem.cpp +++ b/libflashinfer/tests/hip/test_load_q_global_smem.cpp @@ -2,10 +2,11 @@ // // SPDX - License - Identifier : Apache 2.0 -#include -#include #include #include + +#include +#include #include #include @@ -13,21 +14,19 @@ #include "flashinfer/attention/generic/prefill.cuh" #include "flashinfer/attention/generic/variants.cuh" #include "utils/cpu_reference_hip.h" -#include "utils/utils_hip.h" // vec_normal_ +#include "utils/utils_hip.h" // vec_normal_ -namespace -{ +namespace { constexpr uint32_t qo_len = 64; constexpr uint32_t num_qo_heads = 1; constexpr uint32_t head_dim = 64; -} // namespace +} // namespace // CPU reference implementation that creates a Q matrix with a kNHD layout and // initializes. -void initialize_cpu_q() -{ - std::vector q(qo_len * num_qo_heads * head_dim); - utils::vec_normal_(q); +void initialize_cpu_q() { + std::vector q(qo_len * num_qo_heads * head_dim); + utils::vec_normal_(q); } // Validates the original Q matrix on CPU with the copied over data from GPU. 
diff --git a/libflashinfer/tests/hip/test_load_q_global_smem_v1.cpp b/libflashinfer/tests/hip/test_load_q_global_smem_v1.cpp index c9d8c3840a..ffad3a5a47 100644 --- a/libflashinfer/tests/hip/test_load_q_global_smem_v1.cpp +++ b/libflashinfer/tests/hip/test_load_q_global_smem_v1.cpp @@ -1,179 +1,157 @@ +#include + #include #include -#include #include #include // Constants for MI300 -constexpr uint32_t WARP_STEP_SIZE = 16; // 16 threads per warp row -constexpr uint32_t QUERY_ELEMS_PER_THREAD = - 4; // Each thread loads 4 fp16 elements -constexpr uint32_t WARP_THREAD_ROWS = 4; // 4 rows of threads in a warp +constexpr uint32_t WARP_STEP_SIZE = 16; // 16 threads per warp row +constexpr uint32_t QUERY_ELEMS_PER_THREAD = 4; // Each thread loads 4 fp16 elements +constexpr uint32_t WARP_THREAD_ROWS = 4; // 4 rows of threads in a warp // Simplified linear shared memory operations (CPU implementation) template -uint32_t get_permuted_offset_linear(uint32_t row, uint32_t col) -{ - return row * stride + col; +uint32_t get_permuted_offset_linear(uint32_t row, uint32_t col) { + return row * stride + col; } template -uint32_t advance_offset_by_column_linear(uint32_t offset, uint32_t step_idx) -{ - return offset + step_size; +uint32_t advance_offset_by_column_linear(uint32_t offset, uint32_t step_idx) { + return offset + step_size; } template -uint32_t advance_offset_by_row_linear(uint32_t offset) -{ - return offset + step_size * row_stride; +uint32_t advance_offset_by_row_linear(uint32_t offset) { + return offset + step_size * row_stride; } // CPU-based offset pattern verification with configurable NUM_MMA_Q template -void SimulateOffsetPattern(std::vector &thread_ids_at_offsets) -{ - // Constants derived from HEAD_DIM - constexpr uint32_t UPCAST_STRIDE_Q = HEAD_DIM / QUERY_ELEMS_PER_THREAD; - constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM / 16; - constexpr uint32_t COLUMN_RESET_OFFSET = - (NUM_MMA_D_QK / 4) * WARP_STEP_SIZE; - constexpr uint32_t grid_width = - (HEAD_DIM / 
QUERY_ELEMS_PER_THREAD); // 16 for 64, 32 for 128 - constexpr uint32_t grid_height = - 16 * NUM_MMA_Q; // 16 for NUM_MMA_Q=1, 32 for NUM_MMA_Q=2 - - // Initialize with -1 (unwritten) - thread_ids_at_offsets.assign(grid_height * grid_width, -1); - - // Simulate each thread - for (uint32_t tid = 0; tid < 64; tid++) { - uint32_t row = tid / WARP_STEP_SIZE; // 0-3 for 64 threads - uint32_t col = tid % WARP_STEP_SIZE; // 0-15 - - // Calculate initial offset using linear addressing - uint32_t q_smem_offset_w = - get_permuted_offset_linear(row, col); - - // Main loop structure from load_q_global_smem - for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { - for (uint32_t j = 0; j < 4; ++j) { - // Calculate sequence index - const uint32_t seq_idx = row + mma_q * 16 + j; - - for (uint32_t mma_do = 0; mma_do < NUM_MMA_D_QK / 4; ++mma_do) { - // Record which thread wrote to this offset - if (q_smem_offset_w < grid_height * grid_width) - { // Safety check - thread_ids_at_offsets[q_smem_offset_w] = tid; - } - else { - printf("ERROR by tid: %d, offset: %d\n", tid, - q_smem_offset_w); - } - - // Advance to next column within same row - q_smem_offset_w = - advance_offset_by_column_linear( - q_smem_offset_w, mma_do); - } - - // Advance to next sequence (row) with adjustment back to first - // column - q_smem_offset_w = advance_offset_by_row_linear( - q_smem_offset_w) - - COLUMN_RESET_OFFSET; - } +void SimulateOffsetPattern(std::vector& thread_ids_at_offsets) { + // Constants derived from HEAD_DIM + constexpr uint32_t UPCAST_STRIDE_Q = HEAD_DIM / QUERY_ELEMS_PER_THREAD; + constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM / 16; + constexpr uint32_t COLUMN_RESET_OFFSET = (NUM_MMA_D_QK / 4) * WARP_STEP_SIZE; + constexpr uint32_t grid_width = (HEAD_DIM / QUERY_ELEMS_PER_THREAD); // 16 for 64, 32 for 128 + constexpr uint32_t grid_height = 16 * NUM_MMA_Q; // 16 for NUM_MMA_Q=1, 32 for NUM_MMA_Q=2 + + // Initialize with -1 (unwritten) + thread_ids_at_offsets.assign(grid_height * grid_width, -1); 
+ + // Simulate each thread + for (uint32_t tid = 0; tid < 64; tid++) { + uint32_t row = tid / WARP_STEP_SIZE; // 0-3 for 64 threads + uint32_t col = tid % WARP_STEP_SIZE; // 0-15 + + // Calculate initial offset using linear addressing + uint32_t q_smem_offset_w = get_permuted_offset_linear(row, col); + + // Main loop structure from load_q_global_smem + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { + for (uint32_t j = 0; j < 4; ++j) { + // Calculate sequence index + const uint32_t seq_idx = row + mma_q * 16 + j; + + for (uint32_t mma_do = 0; mma_do < NUM_MMA_D_QK / 4; ++mma_do) { + // Record which thread wrote to this offset + if (q_smem_offset_w < grid_height * grid_width) { // Safety check + thread_ids_at_offsets[q_smem_offset_w] = tid; + } else { + printf("ERROR by tid: %d, offset: %d\n", tid, q_smem_offset_w); + } + + // Advance to next column within same row + q_smem_offset_w = + advance_offset_by_column_linear(q_smem_offset_w, mma_do); } + + // Advance to next sequence (row) with adjustment back to first + // column + q_smem_offset_w = + advance_offset_by_row_linear(q_smem_offset_w) - + COLUMN_RESET_OFFSET; + } } + } } // Helper function to run the test with configurable NUM_MMA_Q -template void RunOffsetTest() -{ - constexpr uint32_t grid_width = - (HEAD_DIM / QUERY_ELEMS_PER_THREAD); // 16 for 64, 32 for 128 - constexpr uint32_t grid_height = - 16 * NUM_MMA_Q; // 16 for NUM_MMA_Q=1, 32 for NUM_MMA_Q=2 - - printf("\n=== Testing offset calculations with HEAD_DIM = %u, NUM_MMA_Q = " - "%u ===\n", - HEAD_DIM, NUM_MMA_Q); - - // Host array to store thread IDs at each offset - std::vector thread_ids(grid_height * grid_width, -1); - - // Run CPU simulation of offset pattern - SimulateOffsetPattern(thread_ids); - - // Print the grid of thread IDs (potentially truncated for readability) - printf("Thread IDs writing to each offset (%dx%d grid):\n", grid_height, - grid_width); - - // Column headers - printf(" "); +template +void RunOffsetTest() { + constexpr 
uint32_t grid_width = (HEAD_DIM / QUERY_ELEMS_PER_THREAD); // 16 for 64, 32 for 128 + constexpr uint32_t grid_height = 16 * NUM_MMA_Q; // 16 for NUM_MMA_Q=1, 32 for NUM_MMA_Q=2 + + printf( + "\n=== Testing offset calculations with HEAD_DIM = %u, NUM_MMA_Q = " + "%u ===\n", + HEAD_DIM, NUM_MMA_Q); + + // Host array to store thread IDs at each offset + std::vector thread_ids(grid_height * grid_width, -1); + + // Run CPU simulation of offset pattern + SimulateOffsetPattern(thread_ids); + + // Print the grid of thread IDs (potentially truncated for readability) + printf("Thread IDs writing to each offset (%dx%d grid):\n", grid_height, grid_width); + + // Column headers + printf(" "); + for (int c = 0; c < grid_width; c++) { + printf("%3d ", c); + if (c == 15 && grid_width > 16) printf("| "); // Divider between first and second half + } + printf("\n +"); + for (int c = 0; c < grid_width; c++) { + printf("----"); + if (c == 15 && grid_width > 16) printf("+"); // Divider between first and second half + } + printf("\n"); + + // Print quadrants with clear separation + for (int r = 0; r < grid_height; r++) { + printf("%2d | ", r); for (int c = 0; c < grid_width; c++) { - printf("%3d ", c); - if (c == 15 && grid_width > 16) - printf("| "); // Divider between first and second half - } - printf("\n +"); - for (int c = 0; c < grid_width; c++) { - printf("----"); - if (c == 15 && grid_width > 16) - printf("+"); // Divider between first and second half + int thread_id = thread_ids[r * grid_width + c]; + if (thread_id >= 0) { + printf("%3d ", thread_id); + } else { + printf(" . 
"); // Dot for unwritten positions + } + if (c == 15 && grid_width > 16) printf("| "); // Divider between first and second half } printf("\n"); - // Print quadrants with clear separation - for (int r = 0; r < grid_height; r++) { - printf("%2d | ", r); - for (int c = 0; c < grid_width; c++) { - int thread_id = thread_ids[r * grid_width + c]; - if (thread_id >= 0) { - printf("%3d ", thread_id); - } - else { - printf(" . "); // Dot for unwritten positions - } - if (c == 15 && grid_width > 16) - printf("| "); // Divider between first and second half - } - printf("\n"); - - // Add horizontal divider between first and second block of sequences - if (r == 15 && NUM_MMA_Q > 1) { - printf(" +"); - for (int c = 0; c < grid_width; c++) { - printf("----"); - if (c == 15 && grid_width > 16) - printf("+"); // Intersection divider - } - printf("\n"); - } + // Add horizontal divider between first and second block of sequences + if (r == 15 && NUM_MMA_Q > 1) { + printf(" +"); + for (int c = 0; c < grid_width; c++) { + printf("----"); + if (c == 15 && grid_width > 16) printf("+"); // Intersection divider + } + printf("\n"); } + } - // Check for unwritten positions - int unwritten = 0; - for (int i = 0; i < grid_height * grid_width; i++) { - if (thread_ids[i] == -1) { - unwritten++; - } + // Check for unwritten positions + int unwritten = 0; + for (int i = 0; i < grid_height * grid_width; i++) { + if (thread_ids[i] == -1) { + unwritten++; } - - // Print statistics - printf("\nStatistics:\n"); - printf("- Positions written: %d/%d (%.1f%%)\n", - grid_height * grid_width - unwritten, grid_height * grid_width, - 100.0f * (grid_height * grid_width - unwritten) / - (grid_height * grid_width)); - printf("- Unwritten positions: %d/%d (%.1f%%)\n", unwritten, - grid_height * grid_width, - 100.0f * unwritten / (grid_height * grid_width)); - - // Validate full coverage - EXPECT_EQ(unwritten, 0) << "Not all positions were written"; + } + + // Print statistics + printf("\nStatistics:\n"); + 
printf("- Positions written: %d/%d (%.1f%%)\n", grid_height * grid_width - unwritten, + grid_height * grid_width, + 100.0f * (grid_height * grid_width - unwritten) / (grid_height * grid_width)); + printf("- Unwritten positions: %d/%d (%.1f%%)\n", unwritten, grid_height * grid_width, + 100.0f * unwritten / (grid_height * grid_width)); + + // Validate full coverage + EXPECT_EQ(unwritten, 0) << "Not all positions were written"; } // Original tests with NUM_MMA_Q = 1 @@ -186,8 +164,7 @@ TEST(MI300OffsetTest, HeadDim64_NumMmaQ2) { RunOffsetTest<64, 2>(); } TEST(MI300OffsetTest, HeadDim128_NumMmaQ2) { RunOffsetTest<128, 2>(); } -int main(int argc, char **argv) -{ - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); } diff --git a/libflashinfer/tests/hip/test_load_q_global_smem_v2.cpp b/libflashinfer/tests/hip/test_load_q_global_smem_v2.cpp index c57628fbf7..dc05a86b4c 100644 --- a/libflashinfer/tests/hip/test_load_q_global_smem_v2.cpp +++ b/libflashinfer/tests/hip/test_load_q_global_smem_v2.cpp @@ -1,8 +1,9 @@ // test_load_q_global_smem.cpp -#include -#include #include #include + +#include +#include #include #include @@ -17,273 +18,230 @@ using namespace flashinfer; // CPU Reference Implementation for Q Loading template -std::vector -cpu_reference_q_smem_layout(const std::vector &q_global, - size_t qo_len, - size_t num_qo_heads, - size_t head_dim, - size_t q_stride_n, - size_t q_stride_h, - size_t qo_packed_idx_base, - uint32_t group_size, - size_t smem_height, - size_t smem_width) -{ - std::vector q_smem_expected(smem_height * smem_width, DTypeQ(0)); - - // Simulate the loading pattern that load_q_global_smem should follow - for (size_t smem_row = 0; smem_row < smem_height; ++smem_row) { - uint32_t q_packed_idx = qo_packed_idx_base + smem_row; - uint32_t q_idx = q_packed_idx / group_size; // Sequence position - uint32_t r = q_packed_idx % group_size; // 
Head offset within group - - if (q_idx < qo_len) { - for (size_t feat_idx = 0; feat_idx < head_dim; ++feat_idx) { - // Calculate global memory offset - size_t global_offset = - q_idx * q_stride_n + r * q_stride_h + feat_idx; - - // Place in shared memory layout (assuming linear layout for - // test) - size_t smem_offset = smem_row * smem_width + feat_idx; - if (global_offset < q_global.size()) { - q_smem_expected[smem_offset] = q_global[global_offset]; - } - } +std::vector cpu_reference_q_smem_layout(const std::vector& q_global, size_t qo_len, + size_t num_qo_heads, size_t head_dim, + size_t q_stride_n, size_t q_stride_h, + size_t qo_packed_idx_base, uint32_t group_size, + size_t smem_height, size_t smem_width) { + std::vector q_smem_expected(smem_height * smem_width, DTypeQ(0)); + + // Simulate the loading pattern that load_q_global_smem should follow + for (size_t smem_row = 0; smem_row < smem_height; ++smem_row) { + uint32_t q_packed_idx = qo_packed_idx_base + smem_row; + uint32_t q_idx = q_packed_idx / group_size; // Sequence position + uint32_t r = q_packed_idx % group_size; // Head offset within group + + if (q_idx < qo_len) { + for (size_t feat_idx = 0; feat_idx < head_dim; ++feat_idx) { + // Calculate global memory offset + size_t global_offset = q_idx * q_stride_n + r * q_stride_h + feat_idx; + + // Place in shared memory layout (assuming linear layout for + // test) + size_t smem_offset = smem_row * smem_width + feat_idx; + if (global_offset < q_global.size()) { + q_smem_expected[smem_offset] = q_global[global_offset]; } + } } + } - return q_smem_expected; + return q_smem_expected; } -uint_fastdiv create_group_size_div(uint32_t group_size) -{ - return uint_fastdiv(group_size); -} +uint_fastdiv create_group_size_div(uint32_t group_size) { return uint_fastdiv(group_size); } // Test kernel for Q loading template -__global__ void test_q_loading_kernel(typename KTraits::DTypeQ *q_global, - typename KTraits::DTypeQ *q_smem_output, - uint32_t qo_packed_idx_base, 
- uint32_t qo_len, - uint32_t q_stride_n, - uint32_t q_stride_h, - uint_fastdiv group_size_div) -{ - // Set up shared memory - extern __shared__ uint8_t smem[]; - typename KTraits::SharedStorage &smem_storage = - reinterpret_cast(smem); - - smem_t q_smem( - smem_storage.q_smem); - - // Call the function we're testing - load_q_global_smem(qo_packed_idx_base, qo_len, q_global, - q_stride_n, q_stride_h, group_size_div, &q_smem, - threadIdx); - - // Synchronize to ensure loading is complete - __syncthreads(); - - if (threadIdx.y == 0 && threadIdx.z == 0) { - const uint32_t lane_idx = threadIdx.x; - constexpr uint32_t smem_height = KTraits::CTA_TILE_Q; // 16 - constexpr uint32_t smem_width = KTraits::HEAD_DIM_QK; // 64 - constexpr uint32_t total_elements = smem_height * smem_width; - - // Each thread copies using proper swizzled access - for (uint32_t linear_idx = lane_idx; linear_idx < total_elements; - linear_idx += KTraits::NUM_THREADS) - { - if (linear_idx < total_elements) { - uint32_t row = linear_idx / smem_width; - uint32_t col = linear_idx % smem_width; - uint32_t swizzled_offset = q_smem.template get_permuted_offset< - smem_width / upcast_size()>( - row, col / upcast_size()); - uint32_t element_idx = - col % upcast_size(); - typename KTraits::DTypeQ *smem_ptr = - reinterpret_cast( - q_smem.base + swizzled_offset); - q_smem_output[linear_idx] = smem_ptr[element_idx]; - } - } +__global__ void test_q_loading_kernel(typename KTraits::DTypeQ* q_global, + typename KTraits::DTypeQ* q_smem_output, + uint32_t qo_packed_idx_base, uint32_t qo_len, + uint32_t q_stride_n, uint32_t q_stride_h, + uint_fastdiv group_size_div) { + // Set up shared memory + extern __shared__ uint8_t smem[]; + typename KTraits::SharedStorage& smem_storage = + reinterpret_cast(smem); + + smem_t q_smem(smem_storage.q_smem); + + // Call the function we're testing + load_q_global_smem(qo_packed_idx_base, qo_len, q_global, q_stride_n, q_stride_h, + group_size_div, &q_smem, threadIdx); + + // 
Synchronize to ensure loading is complete + __syncthreads(); + + if (threadIdx.y == 0 && threadIdx.z == 0) { + const uint32_t lane_idx = threadIdx.x; + constexpr uint32_t smem_height = KTraits::CTA_TILE_Q; // 16 + constexpr uint32_t smem_width = KTraits::HEAD_DIM_QK; // 64 + constexpr uint32_t total_elements = smem_height * smem_width; + + // Each thread copies using proper swizzled access + for (uint32_t linear_idx = lane_idx; linear_idx < total_elements; + linear_idx += KTraits::NUM_THREADS) { + if (linear_idx < total_elements) { + uint32_t row = linear_idx / smem_width; + uint32_t col = linear_idx % smem_width; + uint32_t swizzled_offset = q_smem.template get_permuted_offset< + smem_width / upcast_size()>( + row, col / upcast_size()); + uint32_t element_idx = + col % upcast_size(); + typename KTraits::DTypeQ* smem_ptr = + reinterpret_cast(q_smem.base + swizzled_offset); + q_smem_output[linear_idx] = smem_ptr[element_idx]; + } } + } } // Main test function -template bool test_q_loading_correctness() -{ - std::cout << "Testing Q loading correctness with " << sizeof(DTypeQ) * 8 - << "-bit precision..." << std::endl; - - // Test parameters - small sizes for initial validation - constexpr size_t qo_len = 8; - constexpr size_t num_qo_heads = 8; - constexpr size_t num_kv_heads = 2; - constexpr size_t head_dim = 64; - constexpr uint32_t group_size = num_qo_heads / num_kv_heads; - - // Create test data with known pattern for easier debugging - const size_t q_size = qo_len * num_qo_heads * head_dim; - std::vector q_host(q_size); - - // Fill with simple pattern: row*1000 + col for easier validation - for (size_t i = 0; i < q_size; ++i) { - float val = float(i % 100) / 10.0f; // Values 0.0, 0.1, 0.2, ... 
9.9 - q_host[i] = fi::con::explicit_casting(val); - } - - // GPU memory allocation - DTypeQ *q_device, *q_smem_output; - const size_t smem_elements = 16 * head_dim; // Single MMA block - FI_GPU_CALL(hipMalloc(&q_device, q_size * sizeof(DTypeQ))); - FI_GPU_CALL(hipMalloc(&q_smem_output, smem_elements * sizeof(DTypeQ))); - - FI_GPU_CALL(hipMemcpy(q_device, q_host.data(), q_size * sizeof(DTypeQ), - hipMemcpyHostToDevice)); - - // Define kernel traits for CDNA3 - using KTraits = - KernelTraits>; - - // Launch parameters - dim3 block_size(64, 1, 1); // CDNA3: 64 threads per wavefront - dim3 grid_size(1, 1, 1); - size_t shared_mem_size = sizeof(typename KTraits::SharedStorage); - - // Test parameters - const uint32_t qo_packed_idx_base = 0; // Start from beginning - const uint32_t q_stride_n = num_qo_heads * head_dim; - const uint32_t q_stride_h = head_dim; - - std::cout << "Launching kernel with:" << std::endl; - std::cout << " Block size: " << block_size.x << "x" << block_size.y << "x" - << block_size.z << std::endl; - std::cout << " Shared memory: " << shared_mem_size << " bytes" - << std::endl; - std::cout << " Q size: " << q_size << " elements" << std::endl; - - uint_fastdiv group_size_div = create_group_size_div(group_size); - - // Launch test kernel - test_q_loading_kernel<<>>( - q_device, q_smem_output, qo_packed_idx_base, qo_len, q_stride_n, - q_stride_h, group_size_div); - - FI_GPU_CALL(hipDeviceSynchronize()); - - // Get results back - std::vector q_smem_actual(smem_elements); - FI_GPU_CALL(hipMemcpy(q_smem_actual.data(), q_smem_output, - smem_elements * sizeof(DTypeQ), - hipMemcpyDeviceToHost)); - - // Generate CPU reference - std::vector q_smem_expected = cpu_reference_q_smem_layout( - q_host, qo_len, num_qo_heads, head_dim, q_stride_n, q_stride_h, - qo_packed_idx_base, group_size, 16, head_dim); - - // Compare results - bool passed = true; - float max_diff = 0.0f; - size_t mismatch_count = 0; - - std::cout << "\nValidation results:" << std::endl; - 
std::cout << "Comparing " << q_smem_actual.size() << " elements..." - << std::endl; - - for (size_t i = 0; - i < std::min(q_smem_actual.size(), q_smem_expected.size()); ++i) - { - float actual = - fi::con::explicit_casting(q_smem_actual[i]); - float expected = - fi::con::explicit_casting(q_smem_expected[i]); - float diff = std::abs(actual - expected); - max_diff = std::max(max_diff, diff); - - if (!utils::isclose(q_smem_actual[i], q_smem_expected[i], 1e-3f, 1e-4f)) - { - if (mismatch_count < 10) { // Show first 10 mismatches - size_t row = i / head_dim; - size_t col = i % head_dim; - std::cout << "Mismatch at [" << row << "][" << col - << "] (index " << i << "): " - << "expected " << expected << ", got " << actual - << ", diff " << diff << std::endl; - } - mismatch_count++; - passed = false; - } +template +bool test_q_loading_correctness() { + std::cout << "Testing Q loading correctness with " << sizeof(DTypeQ) * 8 << "-bit precision..." + << std::endl; + + // Test parameters - small sizes for initial validation + constexpr size_t qo_len = 8; + constexpr size_t num_qo_heads = 8; + constexpr size_t num_kv_heads = 2; + constexpr size_t head_dim = 64; + constexpr uint32_t group_size = num_qo_heads / num_kv_heads; + + // Create test data with known pattern for easier debugging + const size_t q_size = qo_len * num_qo_heads * head_dim; + std::vector q_host(q_size); + + // Fill with simple pattern: row*1000 + col for easier validation + for (size_t i = 0; i < q_size; ++i) { + float val = float(i % 100) / 10.0f; // Values 0.0, 0.1, 0.2, ... 
9.9 + q_host[i] = fi::con::explicit_casting(val); + } + + // GPU memory allocation + DTypeQ *q_device, *q_smem_output; + const size_t smem_elements = 16 * head_dim; // Single MMA block + FI_GPU_CALL(hipMalloc(&q_device, q_size * sizeof(DTypeQ))); + FI_GPU_CALL(hipMalloc(&q_smem_output, smem_elements * sizeof(DTypeQ))); + + FI_GPU_CALL(hipMemcpy(q_device, q_host.data(), q_size * sizeof(DTypeQ), hipMemcpyHostToDevice)); + + // Define kernel traits for CDNA3 + using KTraits = + KernelTraits>; + + // Launch parameters + dim3 block_size(64, 1, 1); // CDNA3: 64 threads per wavefront + dim3 grid_size(1, 1, 1); + size_t shared_mem_size = sizeof(typename KTraits::SharedStorage); + + // Test parameters + const uint32_t qo_packed_idx_base = 0; // Start from beginning + const uint32_t q_stride_n = num_qo_heads * head_dim; + const uint32_t q_stride_h = head_dim; + + std::cout << "Launching kernel with:" << std::endl; + std::cout << " Block size: " << block_size.x << "x" << block_size.y << "x" << block_size.z + << std::endl; + std::cout << " Shared memory: " << shared_mem_size << " bytes" << std::endl; + std::cout << " Q size: " << q_size << " elements" << std::endl; + + uint_fastdiv group_size_div = create_group_size_div(group_size); + + // Launch test kernel + test_q_loading_kernel<<>>( + q_device, q_smem_output, qo_packed_idx_base, qo_len, q_stride_n, q_stride_h, group_size_div); + + FI_GPU_CALL(hipDeviceSynchronize()); + + // Get results back + std::vector q_smem_actual(smem_elements); + FI_GPU_CALL(hipMemcpy(q_smem_actual.data(), q_smem_output, smem_elements * sizeof(DTypeQ), + hipMemcpyDeviceToHost)); + + // Generate CPU reference + std::vector q_smem_expected = + cpu_reference_q_smem_layout(q_host, qo_len, num_qo_heads, head_dim, q_stride_n, q_stride_h, + qo_packed_idx_base, group_size, 16, head_dim); + + // Compare results + bool passed = true; + float max_diff = 0.0f; + size_t mismatch_count = 0; + + std::cout << "\nValidation results:" << std::endl; + std::cout << 
"Comparing " << q_smem_actual.size() << " elements..." << std::endl; + + for (size_t i = 0; i < std::min(q_smem_actual.size(), q_smem_expected.size()); ++i) { + float actual = fi::con::explicit_casting(q_smem_actual[i]); + float expected = fi::con::explicit_casting(q_smem_expected[i]); + float diff = std::abs(actual - expected); + max_diff = std::max(max_diff, diff); + + if (!utils::isclose(q_smem_actual[i], q_smem_expected[i], 1e-3f, 1e-4f)) { + if (mismatch_count < 10) { // Show first 10 mismatches + size_t row = i / head_dim; + size_t col = i % head_dim; + std::cout << "Mismatch at [" << row << "][" << col << "] (index " << i << "): " + << "expected " << expected << ", got " << actual << ", diff " << diff + << std::endl; + } + mismatch_count++; + passed = false; } - - std::cout << "Max difference: " << max_diff << std::endl; - std::cout << "Total mismatches: " << mismatch_count << " / " - << q_smem_actual.size() << std::endl; - std::cout << "Q loading test: " << (passed ? "PASSED" : "FAILED") - << std::endl; - - // Show some sample values for debugging - if (!passed) { - std::cout << "\nFirst 10 expected vs actual values:" << std::endl; - for (size_t i = 0; i < std::min(size_t(10), q_smem_actual.size()); ++i) - { - float actual = - fi::con::explicit_casting(q_smem_actual[i]); - float expected = - fi::con::explicit_casting(q_smem_expected[i]); - std::cout << "[" << i << "] expected: " << expected - << ", actual: " << actual << std::endl; - } + } + + std::cout << "Max difference: " << max_diff << std::endl; + std::cout << "Total mismatches: " << mismatch_count << " / " << q_smem_actual.size() << std::endl; + std::cout << "Q loading test: " << (passed ? 
"PASSED" : "FAILED") << std::endl; + + // Show some sample values for debugging + if (!passed) { + std::cout << "\nFirst 10 expected vs actual values:" << std::endl; + for (size_t i = 0; i < std::min(size_t(10), q_smem_actual.size()); ++i) { + float actual = fi::con::explicit_casting(q_smem_actual[i]); + float expected = fi::con::explicit_casting(q_smem_expected[i]); + std::cout << "[" << i << "] expected: " << expected << ", actual: " << actual << std::endl; } + } - // Cleanup - FI_GPU_CALL(hipFree(q_device)); - FI_GPU_CALL(hipFree(q_smem_output)); + // Cleanup + FI_GPU_CALL(hipFree(q_device)); + FI_GPU_CALL(hipFree(q_smem_output)); - return passed; + return passed; } // Main function -int main() -{ - std::cout << "=== FlashInfer Q Loading Component Test ===" << std::endl; - std::cout << "Testing load_q_global_smem function for CDNA3 architecture" - << std::endl; - - // Initialize HIP - hipError_t err = hipSetDevice(0); - if (err != hipSuccess) { - std::cout << "Failed to set HIP device: " << hipGetErrorString(err) - << std::endl; - return 1; - } - - hipDeviceProp_t prop; - FI_GPU_CALL(hipGetDeviceProperties(&prop, 0)); - std::cout << "Running on: " << prop.name << std::endl; - - bool all_passed = true; - - // Test with half precision - std::cout << "\n--- Testing with FP16 ---" << std::endl; - all_passed &= test_q_loading_correctness<__half>(); - - if (all_passed) { - std::cout << "\n✅ All Q loading tests PASSED!" << std::endl; - return 0; - } - else { - std::cout << "\n❌ Some Q loading tests FAILED!" 
<< std::endl; - return 1; - } +int main() { + std::cout << "=== FlashInfer Q Loading Component Test ===" << std::endl; + std::cout << "Testing load_q_global_smem function for CDNA3 architecture" << std::endl; + + // Initialize HIP + hipError_t err = hipSetDevice(0); + if (err != hipSuccess) { + std::cout << "Failed to set HIP device: " << hipGetErrorString(err) << std::endl; + return 1; + } + + hipDeviceProp_t prop; + FI_GPU_CALL(hipGetDeviceProperties(&prop, 0)); + std::cout << "Running on: " << prop.name << std::endl; + + bool all_passed = true; + + // Test with half precision + std::cout << "\n--- Testing with FP16 ---" << std::endl; + all_passed &= test_q_loading_correctness<__half>(); + + if (all_passed) { + std::cout << "\n✅ All Q loading tests PASSED!" << std::endl; + return 0; + } else { + std::cout << "\n❌ Some Q loading tests FAILED!" << std::endl; + return 1; + } } diff --git a/libflashinfer/tests/hip/test_math.cpp b/libflashinfer/tests/hip/test_math.cpp index 0174b89474..c55fe0f28f 100644 --- a/libflashinfer/tests/hip/test_math.cpp +++ b/libflashinfer/tests/hip/test_math.cpp @@ -2,287 +2,257 @@ // // SPDX - License - Identifier : Apache 2.0 -#include "gpu_iface/math_ops.hpp" - #include +#include "gpu_iface/math_ops.hpp" + using namespace flashinfer::math; -#define CHECK_HIP_ERROR(call) \ - { \ - hipError_t err = call; \ - if (err != hipSuccess) { \ - std::cerr << "HIP error at " << __FILE__ << " : " << __LINE__ \ - << " -> " << hipGetErrorString(err) << std::endl; \ - exit(1); \ - } \ - } +#define CHECK_HIP_ERROR(call) \ + { \ + hipError_t err = call; \ + if (err != hipSuccess) { \ + std::cerr << "HIP error at " << __FILE__ << " : " << __LINE__ << " -> " \ + << hipGetErrorString(err) << std::endl; \ + exit(1); \ + } \ + } constexpr int NUM_VALUES = 5; constexpr size_t BLOCK_SIZE = 256; template -__global__ void test_ptx_exp2_kernel(T *x_values, T *results) -{ - int idx = threadIdx.x + blockIdx.x * blockDim.x; +__global__ void test_ptx_exp2_kernel(T* 
x_values, T* results) { + int idx = threadIdx.x + blockIdx.x * blockDim.x; - if (idx < NUM_VALUES) { - results[idx] = ptx_exp2(x_values[idx]); - } + if (idx < NUM_VALUES) { + results[idx] = ptx_exp2(x_values[idx]); + } } -__global__ void test_ptx_log2_kernel(float *x_values, float *results) -{ - int idx = threadIdx.x + blockIdx.x * blockDim.x; +__global__ void test_ptx_log2_kernel(float* x_values, float* results) { + int idx = threadIdx.x + blockIdx.x * blockDim.x; - if (idx < NUM_VALUES) { - results[idx] = ptx_log2(x_values[idx]); - } + if (idx < NUM_VALUES) { + results[idx] = ptx_log2(x_values[idx]); + } } -__global__ void test_ptx_rcp_kernel(float *x_values, float *results) -{ - int idx = threadIdx.x + blockIdx.x * blockDim.x; +__global__ void test_ptx_rcp_kernel(float* x_values, float* results) { + int idx = threadIdx.x + blockIdx.x * blockDim.x; - if (idx < NUM_VALUES) { - results[idx] = ptx_rcp(x_values[idx]); - } + if (idx < NUM_VALUES) { + results[idx] = ptx_rcp(x_values[idx]); + } } -__global__ void test_rsqrt_kernel(float *x_values, float *results) -{ - int idx = threadIdx.x + blockIdx.x * blockDim.x; +__global__ void test_rsqrt_kernel(float* x_values, float* results) { + int idx = threadIdx.x + blockIdx.x * blockDim.x; - if (idx < NUM_VALUES) { - results[idx] = rsqrt(x_values[idx]); - } + if (idx < NUM_VALUES) { + results[idx] = rsqrt(x_values[idx]); + } } -template __global__ void test_tanh_kernel(T *x_values, T *results) -{ - int idx = threadIdx.x + blockIdx.x * blockDim.x; +template +__global__ void test_tanh_kernel(T* x_values, T* results) { + int idx = threadIdx.x + blockIdx.x * blockDim.x; - if (idx < NUM_VALUES) { - results[idx] = tanh(x_values[idx]); - } + if (idx < NUM_VALUES) { + results[idx] = tanh(x_values[idx]); + } } -__global__ void test_shfl_xor_sync(float *input, float *output, int lane_mask) -{ - int lane = threadIdx.x % 64; - float val = input[lane]; - float result = shfl_xor_sync(val, lane_mask); - output[lane] = result; +__global__ 
void test_shfl_xor_sync(float* input, float* output, int lane_mask) { + int lane = threadIdx.x % 64; + float val = input[lane]; + float result = shfl_xor_sync(val, lane_mask); + output[lane] = result; } -TEST(hipFunctionsTest, TestPtxExp2Float) -{ - - float x_host[NUM_VALUES] = {0.0f, 1.0f, -1.0f, 2.0f, -2.0f}; - float results_host[NUM_VALUES]; +TEST(hipFunctionsTest, TestPtxExp2Float) { + float x_host[NUM_VALUES] = {0.0f, 1.0f, -1.0f, 2.0f, -2.0f}; + float results_host[NUM_VALUES]; - float *x_device, *results_device; + float *x_device, *results_device; - CHECK_HIP_ERROR(hipMalloc((void **)&x_device, NUM_VALUES * sizeof(float))); - CHECK_HIP_ERROR( - hipMalloc((void **)&results_device, NUM_VALUES * sizeof(float))); + CHECK_HIP_ERROR(hipMalloc((void**)&x_device, NUM_VALUES * sizeof(float))); + CHECK_HIP_ERROR(hipMalloc((void**)&results_device, NUM_VALUES * sizeof(float))); - CHECK_HIP_ERROR(hipMemcpy(x_device, x_host, NUM_VALUES * sizeof(float), - hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(x_device, x_host, NUM_VALUES * sizeof(float), hipMemcpyHostToDevice)); - int grid_size = (NUM_VALUES + BLOCK_SIZE - 1) / BLOCK_SIZE; + int grid_size = (NUM_VALUES + BLOCK_SIZE - 1) / BLOCK_SIZE; - test_ptx_exp2_kernel<<>>(x_device, results_device); + test_ptx_exp2_kernel<<>>(x_device, results_device); - CHECK_HIP_ERROR(hipMemcpy(results_host, results_device, - NUM_VALUES * sizeof(float), - hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR( + hipMemcpy(results_host, results_device, NUM_VALUES * sizeof(float), hipMemcpyDeviceToHost)); - for (size_t i = 0; i < NUM_VALUES; ++i) { - x_host[i] = std::pow(2, x_host[i]); - } + for (size_t i = 0; i < NUM_VALUES; ++i) { + x_host[i] = std::pow(2, x_host[i]); + } - for (int i = 0; i < NUM_VALUES; ++i) { - EXPECT_NEAR(x_host[i], results_host[i], 1e-5); - } + for (int i = 0; i < NUM_VALUES; ++i) { + EXPECT_NEAR(x_host[i], results_host[i], 1e-5); + } - CHECK_HIP_ERROR(hipFree(x_device)); - CHECK_HIP_ERROR(hipFree(results_device)); + 
CHECK_HIP_ERROR(hipFree(x_device)); + CHECK_HIP_ERROR(hipFree(results_device)); } -TEST(hipFunctionsTest, TestPtxLog2) -{ - float x_host[NUM_VALUES] = {100.8, 37.85, 8.12f, 15.63, 29.0f}; - float results_host[NUM_VALUES]; +TEST(hipFunctionsTest, TestPtxLog2) { + float x_host[NUM_VALUES] = {100.8, 37.85, 8.12f, 15.63, 29.0f}; + float results_host[NUM_VALUES]; - float *x_device, *results_device; + float *x_device, *results_device; - CHECK_HIP_ERROR(hipMalloc((void **)&x_device, NUM_VALUES * sizeof(float))); - CHECK_HIP_ERROR( - hipMalloc((void **)&results_device, NUM_VALUES * sizeof(float))); + CHECK_HIP_ERROR(hipMalloc((void**)&x_device, NUM_VALUES * sizeof(float))); + CHECK_HIP_ERROR(hipMalloc((void**)&results_device, NUM_VALUES * sizeof(float))); - CHECK_HIP_ERROR(hipMemcpy(x_device, x_host, NUM_VALUES * sizeof(float), - hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(x_device, x_host, NUM_VALUES * sizeof(float), hipMemcpyHostToDevice)); - int grid_size = (NUM_VALUES + BLOCK_SIZE - 1) / BLOCK_SIZE; + int grid_size = (NUM_VALUES + BLOCK_SIZE - 1) / BLOCK_SIZE; - test_ptx_log2_kernel<<>>(x_device, results_device); + test_ptx_log2_kernel<<>>(x_device, results_device); - CHECK_HIP_ERROR(hipMemcpy(results_host, results_device, - NUM_VALUES * sizeof(float), - hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR( + hipMemcpy(results_host, results_device, NUM_VALUES * sizeof(float), hipMemcpyDeviceToHost)); - for (size_t i = 0; i < NUM_VALUES; ++i) { - x_host[i] = std::log2f(x_host[i]); - } + for (size_t i = 0; i < NUM_VALUES; ++i) { + x_host[i] = std::log2f(x_host[i]); + } - for (int i = 0; i < NUM_VALUES; ++i) { - EXPECT_NEAR(x_host[i], results_host[i], 1e-5); - } + for (int i = 0; i < NUM_VALUES; ++i) { + EXPECT_NEAR(x_host[i], results_host[i], 1e-5); + } - CHECK_HIP_ERROR(hipFree(x_device)); - CHECK_HIP_ERROR(hipFree(results_device)); + CHECK_HIP_ERROR(hipFree(x_device)); + CHECK_HIP_ERROR(hipFree(results_device)); } -TEST(hipFunctionsTest, TestPtxRcp) -{ - float 
x_host[NUM_VALUES] = {10.23f, 5.56f, 8.2f, 3.141f, 9.81f}; - float results_host[NUM_VALUES]; +TEST(hipFunctionsTest, TestPtxRcp) { + float x_host[NUM_VALUES] = {10.23f, 5.56f, 8.2f, 3.141f, 9.81f}; + float results_host[NUM_VALUES]; - float *x_device, *results_device; + float *x_device, *results_device; - CHECK_HIP_ERROR(hipMalloc((void **)&x_device, NUM_VALUES * sizeof(float))); - CHECK_HIP_ERROR( - hipMalloc((void **)&results_device, NUM_VALUES * sizeof(float))); + CHECK_HIP_ERROR(hipMalloc((void**)&x_device, NUM_VALUES * sizeof(float))); + CHECK_HIP_ERROR(hipMalloc((void**)&results_device, NUM_VALUES * sizeof(float))); - CHECK_HIP_ERROR(hipMemcpy(x_device, x_host, NUM_VALUES * sizeof(float), - hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(x_device, x_host, NUM_VALUES * sizeof(float), hipMemcpyHostToDevice)); - int grid_size = (NUM_VALUES + BLOCK_SIZE - 1) / BLOCK_SIZE; + int grid_size = (NUM_VALUES + BLOCK_SIZE - 1) / BLOCK_SIZE; - test_ptx_rcp_kernel<<>>(x_device, results_device); + test_ptx_rcp_kernel<<>>(x_device, results_device); - CHECK_HIP_ERROR(hipMemcpy(results_host, results_device, - NUM_VALUES * sizeof(float), - hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR( + hipMemcpy(results_host, results_device, NUM_VALUES * sizeof(float), hipMemcpyDeviceToHost)); - for (size_t i = 0; i < NUM_VALUES; ++i) { - x_host[i] = 1.0f / x_host[i]; - } + for (size_t i = 0; i < NUM_VALUES; ++i) { + x_host[i] = 1.0f / x_host[i]; + } - for (int i = 0; i < NUM_VALUES; ++i) { - EXPECT_NEAR(x_host[i], results_host[i], 1e-5); - } + for (int i = 0; i < NUM_VALUES; ++i) { + EXPECT_NEAR(x_host[i], results_host[i], 1e-5); + } - CHECK_HIP_ERROR(hipFree(x_device)); - CHECK_HIP_ERROR(hipFree(results_device)); + CHECK_HIP_ERROR(hipFree(x_device)); + CHECK_HIP_ERROR(hipFree(results_device)); } -TEST(hipFunctionsTest, TestRsqrt) -{ - float x_host[NUM_VALUES] = {10.23f, 5.56f, 8.2f, 3.141f, 9.81f}; - float results_host[NUM_VALUES]; +TEST(hipFunctionsTest, TestRsqrt) { + float 
x_host[NUM_VALUES] = {10.23f, 5.56f, 8.2f, 3.141f, 9.81f}; + float results_host[NUM_VALUES]; - float *x_device, *results_device; + float *x_device, *results_device; - CHECK_HIP_ERROR(hipMalloc((void **)&x_device, NUM_VALUES * sizeof(float))); - CHECK_HIP_ERROR( - hipMalloc((void **)&results_device, NUM_VALUES * sizeof(float))); + CHECK_HIP_ERROR(hipMalloc((void**)&x_device, NUM_VALUES * sizeof(float))); + CHECK_HIP_ERROR(hipMalloc((void**)&results_device, NUM_VALUES * sizeof(float))); - CHECK_HIP_ERROR(hipMemcpy(x_device, x_host, NUM_VALUES * sizeof(float), - hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(x_device, x_host, NUM_VALUES * sizeof(float), hipMemcpyHostToDevice)); - int grid_size = (NUM_VALUES + BLOCK_SIZE - 1) / BLOCK_SIZE; + int grid_size = (NUM_VALUES + BLOCK_SIZE - 1) / BLOCK_SIZE; - test_rsqrt_kernel<<>>(x_device, results_device); + test_rsqrt_kernel<<>>(x_device, results_device); - CHECK_HIP_ERROR(hipMemcpy(results_host, results_device, - NUM_VALUES * sizeof(float), - hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR( + hipMemcpy(results_host, results_device, NUM_VALUES * sizeof(float), hipMemcpyDeviceToHost)); - for (size_t i = 0; i < NUM_VALUES; ++i) { - x_host[i] = 1.0f / std::sqrtf(x_host[i]); - } + for (size_t i = 0; i < NUM_VALUES; ++i) { + x_host[i] = 1.0f / std::sqrtf(x_host[i]); + } - for (int i = 0; i < NUM_VALUES; ++i) { - EXPECT_NEAR(x_host[i], results_host[i], 1e-5); - } + for (int i = 0; i < NUM_VALUES; ++i) { + EXPECT_NEAR(x_host[i], results_host[i], 1e-5); + } - CHECK_HIP_ERROR(hipFree(x_device)); - CHECK_HIP_ERROR(hipFree(results_device)); + CHECK_HIP_ERROR(hipFree(x_device)); + CHECK_HIP_ERROR(hipFree(results_device)); } -TEST(hipFunctionsTest, TestTanh) -{ - float x_host[NUM_VALUES] = {3.5f, -2.2f, 1.5f, 1.83f, 0.87f}; - float results_host[NUM_VALUES]; +TEST(hipFunctionsTest, TestTanh) { + float x_host[NUM_VALUES] = {3.5f, -2.2f, 1.5f, 1.83f, 0.87f}; + float results_host[NUM_VALUES]; - float *x_device, *results_device; + float 
*x_device, *results_device; - CHECK_HIP_ERROR(hipMalloc((void **)&x_device, NUM_VALUES * sizeof(float))); - CHECK_HIP_ERROR( - hipMalloc((void **)&results_device, NUM_VALUES * sizeof(float))); + CHECK_HIP_ERROR(hipMalloc((void**)&x_device, NUM_VALUES * sizeof(float))); + CHECK_HIP_ERROR(hipMalloc((void**)&results_device, NUM_VALUES * sizeof(float))); - CHECK_HIP_ERROR(hipMemcpy(x_device, x_host, NUM_VALUES * sizeof(float), - hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(x_device, x_host, NUM_VALUES * sizeof(float), hipMemcpyHostToDevice)); - int grid_size = (NUM_VALUES + BLOCK_SIZE - 1) / BLOCK_SIZE; + int grid_size = (NUM_VALUES + BLOCK_SIZE - 1) / BLOCK_SIZE; - test_tanh_kernel<<>>(x_device, results_device); + test_tanh_kernel<<>>(x_device, results_device); - CHECK_HIP_ERROR(hipMemcpy(results_host, results_device, - NUM_VALUES * sizeof(float), - hipMemcpyDeviceToHost)); + CHECK_HIP_ERROR( + hipMemcpy(results_host, results_device, NUM_VALUES * sizeof(float), hipMemcpyDeviceToHost)); - for (size_t i = 0; i < NUM_VALUES; ++i) { - x_host[i] = std::tanhf(x_host[i]); - } + for (size_t i = 0; i < NUM_VALUES; ++i) { + x_host[i] = std::tanhf(x_host[i]); + } - for (int i = 0; i < NUM_VALUES; ++i) { - EXPECT_NEAR(x_host[i], results_host[i], 1e-5); - } + for (int i = 0; i < NUM_VALUES; ++i) { + EXPECT_NEAR(x_host[i], results_host[i], 1e-5); + } - CHECK_HIP_ERROR(hipFree(x_device)); - CHECK_HIP_ERROR(hipFree(results_device)); + CHECK_HIP_ERROR(hipFree(x_device)); + CHECK_HIP_ERROR(hipFree(results_device)); } -TEST(hipFunctionsTest, TestShflXorSync) -{ - - const int WARP_SIZE = 64; - float h_input[WARP_SIZE], h_output[WARP_SIZE]; +TEST(hipFunctionsTest, TestShflXorSync) { + const int WARP_SIZE = 64; + float h_input[WARP_SIZE], h_output[WARP_SIZE]; - float *d_input, *d_output; - int lane_mask = 1; + float *d_input, *d_output; + int lane_mask = 1; - for (int i = 0; i < WARP_SIZE; ++i) { - h_input[i] = static_cast(i); - } + for (int i = 0; i < WARP_SIZE; ++i) { + 
h_input[i] = static_cast(i); + } - size_t BYTES = WARP_SIZE * sizeof(float); + size_t BYTES = WARP_SIZE * sizeof(float); - CHECK_HIP_ERROR(hipMalloc((void **)&d_input, BYTES)); - CHECK_HIP_ERROR(hipMalloc((void **)&d_output, BYTES)); + CHECK_HIP_ERROR(hipMalloc((void**)&d_input, BYTES)); + CHECK_HIP_ERROR(hipMalloc((void**)&d_output, BYTES)); - CHECK_HIP_ERROR(hipMemcpy(d_input, h_input, BYTES, hipMemcpyHostToDevice)); + CHECK_HIP_ERROR(hipMemcpy(d_input, h_input, BYTES, hipMemcpyHostToDevice)); - test_shfl_xor_sync<<<1, WARP_SIZE>>>(d_input, d_output, lane_mask); - CHECK_HIP_ERROR( - hipMemcpy(h_output, d_output, BYTES, hipMemcpyDeviceToHost)); + test_shfl_xor_sync<<<1, WARP_SIZE>>>(d_input, d_output, lane_mask); + CHECK_HIP_ERROR(hipMemcpy(h_output, d_output, BYTES, hipMemcpyDeviceToHost)); - for (int i = 0; i < WARP_SIZE; ++i) { - int expected_idx = i ^ lane_mask; - if (expected_idx < WARP_SIZE) { - ASSERT_EQ(h_output[i], h_input[expected_idx]); - } + for (int i = 0; i < WARP_SIZE; ++i) { + int expected_idx = i ^ lane_mask; + if (expected_idx < WARP_SIZE) { + ASSERT_EQ(h_output[i], h_input[expected_idx]); } + } - CHECK_HIP_ERROR(hipFree(d_input)); - CHECK_HIP_ERROR(hipFree(d_output)); + CHECK_HIP_ERROR(hipFree(d_input)); + CHECK_HIP_ERROR(hipFree(d_output)); } -int main(int argc, char **argv) -{ - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); } diff --git a/libflashinfer/tests/hip/test_mfma_fp32_16x16x16fp16.cpp b/libflashinfer/tests/hip/test_mfma_fp32_16x16x16fp16.cpp index e448ccd02d..24eb5219af 100644 --- a/libflashinfer/tests/hip/test_mfma_fp32_16x16x16fp16.cpp +++ b/libflashinfer/tests/hip/test_mfma_fp32_16x16x16fp16.cpp @@ -3,8 +3,7 @@ // // SPDX - License - Identifier : Apache 2.0 -#include "gpu_iface/mma_ops.hpp" - +#include #include #include @@ -12,19 +11,18 @@ #include #include -#include +#include "gpu_iface/mma_ops.hpp" // Check 
HIP errors -#define HIP_CHECK(command) \ - { \ - hipError_t status = command; \ - if (status != hipSuccess) { \ - std::cerr << "Error: HIP reports " << hipGetErrorString(status) \ - << std::endl; \ - std::cerr << "at " << __FILE__ << ":" << __LINE__ << std::endl; \ - exit(EXIT_FAILURE); \ - } \ - } +#define HIP_CHECK(command) \ + { \ + hipError_t status = command; \ + if (status != hipSuccess) { \ + std::cerr << "Error: HIP reports " << hipGetErrorString(status) << std::endl; \ + std::cerr << "at " << __FILE__ << ":" << __LINE__ << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } // Dimensions for our test matrices constexpr int M = 16; @@ -37,152 +35,130 @@ constexpr int LDB = N; constexpr int LDC = N; // Host reference implementation for matrix multiplication -void gemm_reference(const __half *A, - const __half *B, - float *C, - int M, - int N, - int K, - int lda, - int ldb, - int ldc) -{ - for (int i = 0; i < M; ++i) { - for (int j = 0; j < N; ++j) { - float acc = 0.0f; - for (int k = 0; k < K; ++k) { - // Use __half_as_float to properly convert __half to float - acc += __half2float(A[i * K + k]) * __half2float(B[k * N + j]); - } - C[i * N + j] = acc; - } +void gemm_reference(const __half* A, const __half* B, float* C, int M, int N, int K, int lda, + int ldb, int ldc) { + for (int i = 0; i < M; ++i) { + for (int j = 0; j < N; ++j) { + float acc = 0.0f; + for (int k = 0; k < K; ++k) { + // Use __half_as_float to properly convert __half to float + acc += __half2float(A[i * K + k]) * __half2float(B[k * N + j]); + } + C[i * N + j] = acc; } + } } -__global__ void test_mfma_kernel(const __half *A, const __half *B, float *C) -{ - uint32_t a_reg[2]; - uint32_t b_reg[2]; - float c_reg[4] = {0.0f, 0.0f, 0.0f, 0.0f}; +__global__ void test_mfma_kernel(const __half* A, const __half* B, float* C) { + uint32_t a_reg[2]; + uint32_t b_reg[2]; + float c_reg[4] = {0.0f, 0.0f, 0.0f, 0.0f}; - // A Matrix is read row wise. 
Threads T0...T15 read Col 0...3 of Row 0...15 - // Threads T16...T31 read Col 4...7 of Row 0...15 - // Threads T32...T47 read Col 8...11 of Row 0...15 - // Threads T48...T63 read Col 12...15 of Row 0...15 + // A Matrix is read row wise. Threads T0...T15 read Col 0...3 of Row 0...15 + // Threads T16...T31 read Col 4...7 of Row 0...15 + // Threads T32...T47 read Col 8...11 of Row 0...15 + // Threads T48...T63 read Col 12...15 of Row 0...15 - // B Matrix is read column wise. Threads T0...T15 read Row 0...3 of Col - // 0...15 (Each thread reads 1 column per 4 rows) Threads T16...T31 read - // Row 4...7 of Col 0...15 Threads T32...T47 read Row 8...11 of Col 0...15 - // Threads T48...T63 read Row 12...15 of Col 0...15 - int a_idx = (threadIdx.x / 16) * 4 + threadIdx.x % 16 * LDA; - int b_idx = (threadIdx.x / 16) * LDB * 4 + threadIdx.x % 16; + // B Matrix is read column wise. Threads T0...T15 read Row 0...3 of Col + // 0...15 (Each thread reads 1 column per 4 rows) Threads T16...T31 read + // Row 4...7 of Col 0...15 Threads T32...T47 read Row 8...11 of Col 0...15 + // Threads T48...T63 read Row 12...15 of Col 0...15 + int a_idx = (threadIdx.x / 16) * 4 + threadIdx.x % 16 * LDA; + int b_idx = (threadIdx.x / 16) * LDB * 4 + threadIdx.x % 16; - flashinfer::gpu_iface::mma::load_fragment<__half>(a_reg, &A[a_idx]); - flashinfer::gpu_iface::mma::load_fragment_transpose<__half>(b_reg, - &B[b_idx], LDB); + flashinfer::gpu_iface::mma::load_fragment<__half>(a_reg, &A[a_idx]); + flashinfer::gpu_iface::mma::load_fragment_transpose<__half>(b_reg, &B[b_idx], LDB); - flashinfer::gpu_iface::mma::mma_sync_m16n16k16_row_col_f16f16f32<__half>( - c_reg, a_reg, b_reg); + flashinfer::gpu_iface::mma::mma_sync_m16n16k16_row_col_f16f16f32<__half>(c_reg, a_reg, b_reg); - for (int i = 0; i < 4; ++i) { - const int d_idx = - threadIdx.x % 16 + i * LDC + (threadIdx.x / 16) * 4 * LDC; + for (int i = 0; i < 4; ++i) { + const int d_idx = threadIdx.x % 16 + i * LDC + (threadIdx.x / 16) * 4 * LDC; - 
C[d_idx] = c_reg[i]; - } + C[d_idx] = c_reg[i]; + } } // Test class -class MfmaTest : public ::testing::Test -{ -protected: - std::vector<__half> A_host; - std::vector<__half> B_host; - std::vector C_host; - std::vector C_ref; - - __half *A_dev = nullptr; - __half *B_dev = nullptr; - float *C_dev = nullptr; - - void SetUp() override - { - // Initialize host data - A_host.resize(M * K); - B_host.resize(K * N); - C_host.resize(M * N, 0.0f); - C_ref.resize(M * N, 0.0f); - - // Fill with deterministic values - std::mt19937 gen(42); - std::uniform_real_distribution dist(-1.0f, 1.0f); - - for (int i = 0; i < M * K; ++i) { - A_host[i] = __float2half(dist(gen)); - } - - for (int i = 0; i < K * N; ++i) { - B_host[i] = __float2half(dist(gen)); - } - - // Calculate reference result - gemm_reference(A_host.data(), B_host.data(), C_ref.data(), M, N, K, LDA, - LDB, LDC); - - // Allocate device memory - HIP_CHECK(hipMalloc(&A_dev, M * K * sizeof(__half))); - HIP_CHECK(hipMalloc(&B_dev, K * N * sizeof(__half))); - HIP_CHECK(hipMalloc(&C_dev, M * N * sizeof(float))); - - // Copy input data to device - HIP_CHECK(hipMemcpy(A_dev, A_host.data(), M * K * sizeof(__half), - hipMemcpyHostToDevice)); - HIP_CHECK(hipMemcpy(B_dev, B_host.data(), K * N * sizeof(__half), - hipMemcpyHostToDevice)); - HIP_CHECK(hipMemset(C_dev, 0, M * N * sizeof(float))); +class MfmaTest : public ::testing::Test { + protected: + std::vector<__half> A_host; + std::vector<__half> B_host; + std::vector C_host; + std::vector C_ref; + + __half* A_dev = nullptr; + __half* B_dev = nullptr; + float* C_dev = nullptr; + + void SetUp() override { + // Initialize host data + A_host.resize(M * K); + B_host.resize(K * N); + C_host.resize(M * N, 0.0f); + C_ref.resize(M * N, 0.0f); + + // Fill with deterministic values + std::mt19937 gen(42); + std::uniform_real_distribution dist(-1.0f, 1.0f); + + for (int i = 0; i < M * K; ++i) { + A_host[i] = __float2half(dist(gen)); } - void TearDown() override - { - // Free device memory - 
HIP_CHECK(hipFree(A_dev)); - HIP_CHECK(hipFree(B_dev)); - HIP_CHECK(hipFree(C_dev)); + for (int i = 0; i < K * N; ++i) { + B_host[i] = __float2half(dist(gen)); } + + // Calculate reference result + gemm_reference(A_host.data(), B_host.data(), C_ref.data(), M, N, K, LDA, LDB, LDC); + + // Allocate device memory + HIP_CHECK(hipMalloc(&A_dev, M * K * sizeof(__half))); + HIP_CHECK(hipMalloc(&B_dev, K * N * sizeof(__half))); + HIP_CHECK(hipMalloc(&C_dev, M * N * sizeof(float))); + + // Copy input data to device + HIP_CHECK(hipMemcpy(A_dev, A_host.data(), M * K * sizeof(__half), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(B_dev, B_host.data(), K * N * sizeof(__half), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemset(C_dev, 0, M * N * sizeof(float))); + } + + void TearDown() override { + // Free device memory + HIP_CHECK(hipFree(A_dev)); + HIP_CHECK(hipFree(B_dev)); + HIP_CHECK(hipFree(C_dev)); + } }; // Test that verifies mfma_fp32_16x16x16fp16 calculates correct results -TEST_F(MfmaTest, CorrectResults) -{ - // Launch kernel with one block of 64 threads (one wavefront) - dim3 gridDim(1); - dim3 blockDim(64); - test_mfma_kernel<<>>(A_dev, B_dev, C_dev); - - // Copy results back to host - HIP_CHECK(hipMemcpy(C_host.data(), C_dev, M * N * sizeof(float), - hipMemcpyDeviceToHost)); - - // Verify results with small tolerance for floating point differences - const float tolerance = 1e-3f; - bool all_pass = true; - for (int i = 0; i < M * N; ++i) { - float diff = std::abs(C_host[i] - C_ref[i]); - if (diff > tolerance) { - std::cout << "Mismatch at index " << i << ": " - << "Actual=" << C_host[i] << ", Expected=" << C_ref[i] - << ", Diff=" << diff << std::endl; - all_pass = false; - } +TEST_F(MfmaTest, CorrectResults) { + // Launch kernel with one block of 64 threads (one wavefront) + dim3 gridDim(1); + dim3 blockDim(64); + test_mfma_kernel<<>>(A_dev, B_dev, C_dev); + + // Copy results back to host + HIP_CHECK(hipMemcpy(C_host.data(), C_dev, M * N * sizeof(float), 
hipMemcpyDeviceToHost)); + + // Verify results with small tolerance for floating point differences + const float tolerance = 1e-3f; + bool all_pass = true; + for (int i = 0; i < M * N; ++i) { + float diff = std::abs(C_host[i] - C_ref[i]); + if (diff > tolerance) { + std::cout << "Mismatch at index " << i << ": " + << "Actual=" << C_host[i] << ", Expected=" << C_ref[i] << ", Diff=" << diff + << std::endl; + all_pass = false; } + } - EXPECT_TRUE(all_pass) - << "Matrix multiplication results don't match reference implementation"; + EXPECT_TRUE(all_pass) << "Matrix multiplication results don't match reference implementation"; } // Main function that runs all tests -int main(int argc, char **argv) -{ - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); } diff --git a/libflashinfer/tests/hip/test_page.cpp b/libflashinfer/tests/hip/test_page.cpp index 770a442793..31e4842f43 100644 --- a/libflashinfer/tests/hip/test_page.cpp +++ b/libflashinfer/tests/hip/test_page.cpp @@ -3,420 +3,365 @@ // // SPDX - License - Identifier : Apache 2.0 -#include "flashinfer/attention/generic/page.cuh" - #include #include #include #include -namespace utils -{ +#include "flashinfer/attention/generic/page.cuh" + +namespace utils { template -void vec_normal_(std::vector &vec, float mean = 0.f, float std = 1.f) -{ - std::random_device rd{}; - std::mt19937 gen{rd()}; - std::normal_distribution d{mean, std}; - for (size_t i = 0; i < vec.size(); ++i) { - vec[i] = T(d(gen)); - } +void vec_normal_(std::vector& vec, float mean = 0.f, float std = 1.f) { + std::random_device rd{}; + std::mt19937 gen{rd()}; + std::normal_distribution d{mean, std}; + for (size_t i = 0; i < vec.size(); ++i) { + vec[i] = T(d(gen)); + } } template -void vec_uniform_(std::vector &vec, float a = 0.f, float b = 1.f) -{ - std::random_device rd{}; - std::mt19937 gen{rd()}; - std::uniform_real_distribution d{a, b}; 
- for (size_t i = 0; i < vec.size(); ++i) { - vec[i] = T(d(gen)); - } +void vec_uniform_(std::vector& vec, float a = 0.f, float b = 1.f) { + std::random_device rd{}; + std::mt19937 gen{rd()}; + std::uniform_real_distribution d{a, b}; + for (size_t i = 0; i < vec.size(); ++i) { + vec[i] = T(d(gen)); + } } -template void vec_zero_(std::vector &vec) -{ - std::fill(vec.begin(), vec.end(), T(0)); +template +void vec_zero_(std::vector& vec) { + std::fill(vec.begin(), vec.end(), T(0)); } -template void vec_fill_(std::vector &vec, T val) -{ - std::fill(vec.begin(), vec.end(), val); +template +void vec_fill_(std::vector& vec, T val) { + std::fill(vec.begin(), vec.end(), val); } -template void vec_randint_(std::vector &vec, int low, int high) -{ - std::random_device rd{}; - std::mt19937 gen{rd()}; - std::uniform_int_distribution d{low, high}; - for (size_t i = 0; i < vec.size(); ++i) { - vec[i] = T(d(gen)); - } +template +void vec_randint_(std::vector& vec, int low, int high) { + std::random_device rd{}; + std::mt19937 gen{rd()}; + std::uniform_int_distribution d{low, high}; + for (size_t i = 0; i < vec.size(); ++i) { + vec[i] = T(d(gen)); + } } -template size_t vec_bytes(const T &vec) -{ - return vec.size() * sizeof(typename T::value_type); +template +size_t vec_bytes(const T& vec) { + return vec.size() * sizeof(typename T::value_type); } template -bool isclose(T a, T b, float rtol = 1e-5, float atol = 1e-8) -{ - return fabs(a - b) <= (atol + rtol * fabs(b)); +bool isclose(T a, T b, float rtol = 1e-5, float atol = 1e-8) { + return fabs(a - b) <= (atol + rtol * fabs(b)); } -} // namespace utils +} // namespace utils using namespace flashinfer; template -void append_paged_kv_cache(paged_kv_t page_cpu, - const std::vector> &keys, - const std::vector> &values, - const std::vector &append_indptr) -{ - size_t batch_size = page_cpu.batch_size; - size_t num_heads = page_cpu.num_heads; - size_t head_dim = page_cpu.head_dim; - size_t page_size = page_cpu.page_size; - for (size_t i = 
0; i < batch_size; ++i) { - const std::vector &ki = keys[i]; - const std::vector &vi = values[i]; - size_t append_seq_len = append_indptr[i + 1] - append_indptr[i]; - size_t num_pages_i = page_cpu.indptr[i + 1] - page_cpu.indptr[i]; - size_t seq_len = - (num_pages_i - 1) * page_size + page_cpu.last_page_len[i]; - assert(append_seq_len <= seq_len); - size_t append_start = seq_len - append_seq_len; - - for (size_t j = 0; j < append_seq_len; ++j) { - size_t page_seq_idx = j + append_start; - size_t page_idx = - page_cpu.indices[page_cpu.indptr[i] + page_seq_idx / page_size]; - size_t entry_idx = page_seq_idx % page_size; - for (size_t h = 0; h < num_heads; ++h) { - std::copy(ki.begin() + (j * num_heads + h) * head_dim, - ki.begin() + (j * num_heads + h + 1) * head_dim, - page_cpu.k_data + page_cpu.get_elem_offset( - page_idx, h, entry_idx, 0)); - std::copy(vi.begin() + (j * num_heads + h) * head_dim, - vi.begin() + (j * num_heads + h + 1) * head_dim, - page_cpu.v_data + page_cpu.get_elem_offset( - page_idx, h, entry_idx, 0)); - } - } +void append_paged_kv_cache(paged_kv_t page_cpu, const std::vector>& keys, + const std::vector>& values, + const std::vector& append_indptr) { + size_t batch_size = page_cpu.batch_size; + size_t num_heads = page_cpu.num_heads; + size_t head_dim = page_cpu.head_dim; + size_t page_size = page_cpu.page_size; + for (size_t i = 0; i < batch_size; ++i) { + const std::vector& ki = keys[i]; + const std::vector& vi = values[i]; + size_t append_seq_len = append_indptr[i + 1] - append_indptr[i]; + size_t num_pages_i = page_cpu.indptr[i + 1] - page_cpu.indptr[i]; + size_t seq_len = (num_pages_i - 1) * page_size + page_cpu.last_page_len[i]; + assert(append_seq_len <= seq_len); + size_t append_start = seq_len - append_seq_len; + + for (size_t j = 0; j < append_seq_len; ++j) { + size_t page_seq_idx = j + append_start; + size_t page_idx = page_cpu.indices[page_cpu.indptr[i] + page_seq_idx / page_size]; + size_t entry_idx = page_seq_idx % page_size; + for 
(size_t h = 0; h < num_heads; ++h) { + std::copy(ki.begin() + (j * num_heads + h) * head_dim, + ki.begin() + (j * num_heads + h + 1) * head_dim, + page_cpu.k_data + page_cpu.get_elem_offset(page_idx, h, entry_idx, 0)); + std::copy(vi.begin() + (j * num_heads + h) * head_dim, + vi.begin() + (j * num_heads + h + 1) * head_dim, + page_cpu.v_data + page_cpu.get_elem_offset(page_idx, h, entry_idx, 0)); + } } + } } -class PagedKVTest : public ::testing::Test -{ -protected: - void SetUp() override - { - // Ensure CUDA is available - ASSERT_TRUE(torch::cuda::is_available()); - } +class PagedKVTest : public ::testing::Test { + protected: + void SetUp() override { + // Ensure CUDA is available + ASSERT_TRUE(torch::cuda::is_available()); + } }; // Helper function to check for NaN values in a tensor -bool hasNaN(const torch::Tensor &tensor) -{ - return torch::isnan(tensor).any().item(); -} +bool hasNaN(const torch::Tensor& tensor) { return torch::isnan(tensor).any().item(); } // Helper function to convert vector to tensor -template torch::Tensor vectorToTensor(const std::vector &vec) -{ - torch::Tensor tensor = - torch::from_blob( - const_cast(vec.data()), {static_cast(vec.size())}, - torch::TensorOptions().dtype( - std::is_same::value ? torch::kFloat32 - : std::is_same::value ? torch::kFloat16 - : std::is_same::value ? torch::kInt32 - : torch::kFloat32)) - .clone(); - return tensor; +template +torch::Tensor vectorToTensor(const std::vector& vec) { + torch::Tensor tensor = + torch::from_blob( + const_cast(vec.data()), {static_cast(vec.size())}, + torch::TensorOptions().dtype(std::is_same::value ? torch::kFloat32 + : std::is_same::value ? torch::kFloat16 + : std::is_same::value ? 
torch::kInt32 + : torch::kFloat32)) + .clone(); + return tensor; } // Helper function to check tensor closeness -bool tensorIsClose(const torch::Tensor &a, - const torch::Tensor &b, - float atol = 1e-3, - float rtol = 1e-3) -{ - return torch::isclose(a, b, atol, rtol).all().item(); +bool tensorIsClose(const torch::Tensor& a, const torch::Tensor& b, float atol = 1e-3, + float rtol = 1e-3) { + return torch::isclose(a, b, atol, rtol).all().item(); } template -void _TestAppendPagedKVKernelCorrectness(size_t page_size, - size_t batch_size, - size_t num_heads, - size_t head_dim, - QKVLayout kv_layout) -{ - // number of conversation rounds - size_t num_conv_rounds = 3; - size_t max_decode_len = 1; - size_t max_prefill_len = 128; - size_t max_num_pages = num_conv_rounds * batch_size * - ((max_decode_len + max_prefill_len) / page_size + 1); - - // Define tensor options based on the type T - torch::TensorOptions tensor_options = - torch::TensorOptions() - .dtype(std::is_same::value ? torch::kFloat32 - : std::is_same::value - ? 
torch::kFloat16 - : torch::kFloat32) // Default to float32 if type is not - // recognized - .device(torch::kCUDA); - - torch::TensorOptions int_options = - torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA); - - // Create CPU tensors for reference - std::vector k_data_cpu_vec( - max_num_pages * page_size * num_heads * head_dim, 0); - std::vector v_data_cpu_vec( - max_num_pages * page_size * num_heads * head_dim, 0); - - // Create GPU tensors - torch::Tensor k_data_gpu = - torch::zeros({static_cast(max_num_pages * page_size * - num_heads * head_dim)}, - tensor_options); - - torch::Tensor v_data_gpu = - torch::zeros({static_cast(max_num_pages * page_size * - num_heads * head_dim)}, - tensor_options); - - std::vector seq_len(batch_size, 0); - std::vector> page_indices(batch_size); - std::vector last_page_len(batch_size, 0); - size_t page_counter = 0; - - for (size_t round = 0; round < 2 * num_conv_rounds; ++round) { - std::vector append_len(batch_size); - std::vector append_indptr{0}; - std::vector batch_indices; - std::vector positions; - std::vector> keys; - std::vector> values; - - // Generate random lengths for prefill rounds, fixed for decode rounds - if (round % 2 == 0) { - utils::vec_randint_(append_len, 1, max_prefill_len + 1); - } - else { - std::fill(append_len.begin(), append_len.end(), max_decode_len); - } +void _TestAppendPagedKVKernelCorrectness(size_t page_size, size_t batch_size, size_t num_heads, + size_t head_dim, QKVLayout kv_layout) { + // number of conversation rounds + size_t num_conv_rounds = 3; + size_t max_decode_len = 1; + size_t max_prefill_len = 128; + size_t max_num_pages = + num_conv_rounds * batch_size * ((max_decode_len + max_prefill_len) / page_size + 1); + + // Define tensor options based on the type T + torch::TensorOptions tensor_options = + torch::TensorOptions() + .dtype(std::is_same::value ? torch::kFloat32 + : std::is_same::value ? 
torch::kFloat16 + : torch::kFloat32) // Default to float32 if type is + // not recognized + .device(torch::kCUDA); + + torch::TensorOptions int_options = + torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA); + + // Create CPU tensors for reference + std::vector k_data_cpu_vec(max_num_pages * page_size * num_heads * head_dim, 0); + std::vector v_data_cpu_vec(max_num_pages * page_size * num_heads * head_dim, 0); + + // Create GPU tensors + torch::Tensor k_data_gpu = torch::zeros( + {static_cast(max_num_pages * page_size * num_heads * head_dim)}, tensor_options); + + torch::Tensor v_data_gpu = torch::zeros( + {static_cast(max_num_pages * page_size * num_heads * head_dim)}, tensor_options); + + std::vector seq_len(batch_size, 0); + std::vector> page_indices(batch_size); + std::vector last_page_len(batch_size, 0); + size_t page_counter = 0; + + for (size_t round = 0; round < 2 * num_conv_rounds; ++round) { + std::vector append_len(batch_size); + std::vector append_indptr{0}; + std::vector batch_indices; + std::vector positions; + std::vector> keys; + std::vector> values; + + // Generate random lengths for prefill rounds, fixed for decode rounds + if (round % 2 == 0) { + utils::vec_randint_(append_len, 1, max_prefill_len + 1); + } else { + std::fill(append_len.begin(), append_len.end(), max_decode_len); + } - for (size_t i = 0; i < batch_size; ++i) { - append_indptr.push_back(append_indptr.back() + append_len[i]); - seq_len[i] += append_len[i]; - for (size_t j = 0; j < append_len[i]; ++j) { - if (last_page_len[i] % page_size == 0) { - page_indices[i].push_back(page_counter++); - last_page_len[i] = 1; - } - else { - last_page_len[i] += 1; - } - batch_indices.push_back(i); - positions.push_back(seq_len[i] - append_len[i] + j); - } - - // Generate random keys and values - std::vector ki(append_len[i] * num_heads * head_dim); - std::vector vi(append_len[i] * num_heads * head_dim); - utils::vec_normal_(ki); - utils::vec_normal_(vi); - keys.push_back(ki); - 
values.push_back(vi); + for (size_t i = 0; i < batch_size; ++i) { + append_indptr.push_back(append_indptr.back() + append_len[i]); + seq_len[i] += append_len[i]; + for (size_t j = 0; j < append_len[i]; ++j) { + if (last_page_len[i] % page_size == 0) { + page_indices[i].push_back(page_counter++); + last_page_len[i] = 1; + } else { + last_page_len[i] += 1; } + batch_indices.push_back(i); + positions.push_back(seq_len[i] - append_len[i] + j); + } + + // Generate random keys and values + std::vector ki(append_len[i] * num_heads * head_dim); + std::vector vi(append_len[i] * num_heads * head_dim); + utils::vec_normal_(ki); + utils::vec_normal_(vi); + keys.push_back(ki); + values.push_back(vi); + } - // Create CPU paged KV cache - std::vector indptr_cpu{0}; - std::vector indices_cpu; - for (size_t i = 0; i < batch_size; ++i) { - for (size_t j = 0; j < page_indices[i].size(); ++j) { - indices_cpu.push_back(page_indices[i][j]); - } - indptr_cpu.push_back(indptr_cpu.back() + page_indices[i].size()); - } + // Create CPU paged KV cache + std::vector indptr_cpu{0}; + std::vector indices_cpu; + for (size_t i = 0; i < batch_size; ++i) { + for (size_t j = 0; j < page_indices[i].size(); ++j) { + indices_cpu.push_back(page_indices[i][j]); + } + indptr_cpu.push_back(indptr_cpu.back() + page_indices[i].size()); + } - paged_kv_t paged_kv_cpu( - num_heads, page_size, head_dim, batch_size, kv_layout, - /*k_data=*/k_data_cpu_vec.data(), - /*v_data=*/v_data_cpu_vec.data(), indices_cpu.data(), - indptr_cpu.data(), last_page_len.data()); - - // Apply CPU reference implementation - append_paged_kv_cache(paged_kv_cpu, keys, values, append_indptr); - - // Create GPU tensors for indices, indptr, and last_page_len - torch::Tensor indptr_gpu = - torch::from_blob(indptr_cpu.data(), - {static_cast(indptr_cpu.size())}, - torch::kInt32) - .clone() - .to(torch::kCUDA); - - torch::Tensor indices_gpu = - torch::from_blob(indices_cpu.data(), - {static_cast(indices_cpu.size())}, - torch::kInt32) - .clone() 
- .to(torch::kCUDA); - - torch::Tensor last_page_len_gpu = - torch::from_blob(last_page_len.data(), - {static_cast(batch_size)}, torch::kInt32) - .clone() - .to(torch::kCUDA); - - // Create GPU paged KV cache - paged_kv_t paged_kv_gpu( - num_heads, page_size, head_dim, batch_size, kv_layout, - /*k_data=*/static_cast(k_data_gpu.data_ptr()), - /*v_data=*/static_cast(v_data_gpu.data_ptr()), - static_cast(indices_gpu.data_ptr()), - static_cast(indptr_gpu.data_ptr()), - static_cast(last_page_len_gpu.data_ptr())); - - // Create batch indices and positions tensors - torch::Tensor batch_indices_gpu = - torch::from_blob(batch_indices.data(), - {static_cast(batch_indices.size())}, - torch::kInt32) - .clone() - .to(torch::kCUDA); - - torch::Tensor positions_gpu = - torch::from_blob(positions.data(), - {static_cast(positions.size())}, - torch::kInt32) - .clone() - .to(torch::kCUDA); - - // Create keys and values tensors - torch::Tensor keys_gpu = torch::zeros( - {static_cast(append_indptr.back() * num_heads * head_dim)}, - tensor_options); - - torch::Tensor values_gpu = torch::zeros( - {static_cast(append_indptr.back() * num_heads * head_dim)}, - tensor_options); - - // Copy keys and values to GPU - for (size_t i = 0; i < batch_size; ++i) { - torch::Tensor ki = - torch::from_blob(keys[i].data(), - {static_cast(keys[i].size())}, - tensor_options.device(torch::kCPU)) - .clone() - .to(torch::kCUDA); - - torch::Tensor vi = - torch::from_blob(values[i].data(), - {static_cast(values[i].size())}, - tensor_options.device(torch::kCPU)) - .clone() - .to(torch::kCUDA); - - keys_gpu - .slice(0, append_indptr[i] * num_heads * head_dim, - append_indptr[i + 1] * num_heads * head_dim) - .copy_(ki); - - values_gpu - .slice(0, append_indptr[i] * num_heads * head_dim, - append_indptr[i + 1] * num_heads * head_dim) - .copy_(vi); - } + paged_kv_t paged_kv_cpu(num_heads, page_size, head_dim, batch_size, kv_layout, + /*k_data=*/k_data_cpu_vec.data(), + /*v_data=*/v_data_cpu_vec.data(), 
indices_cpu.data(), + indptr_cpu.data(), last_page_len.data()); + + // Apply CPU reference implementation + append_paged_kv_cache(paged_kv_cpu, keys, values, append_indptr); + + // Create GPU tensors for indices, indptr, and last_page_len + torch::Tensor indptr_gpu = + torch::from_blob(indptr_cpu.data(), {static_cast(indptr_cpu.size())}, + torch::kInt32) + .clone() + .to(torch::kCUDA); + + torch::Tensor indices_gpu = + torch::from_blob(indices_cpu.data(), {static_cast(indices_cpu.size())}, + torch::kInt32) + .clone() + .to(torch::kCUDA); + + torch::Tensor last_page_len_gpu = + torch::from_blob(last_page_len.data(), {static_cast(batch_size)}, torch::kInt32) + .clone() + .to(torch::kCUDA); + + // Create GPU paged KV cache + paged_kv_t paged_kv_gpu(num_heads, page_size, head_dim, batch_size, kv_layout, + /*k_data=*/static_cast(k_data_gpu.data_ptr()), + /*v_data=*/static_cast(v_data_gpu.data_ptr()), + static_cast(indices_gpu.data_ptr()), + static_cast(indptr_gpu.data_ptr()), + static_cast(last_page_len_gpu.data_ptr())); + + // Create batch indices and positions tensors + torch::Tensor batch_indices_gpu = + torch::from_blob(batch_indices.data(), {static_cast(batch_indices.size())}, + torch::kInt32) + .clone() + .to(torch::kCUDA); + + torch::Tensor positions_gpu = + torch::from_blob(positions.data(), {static_cast(positions.size())}, torch::kInt32) + .clone() + .to(torch::kCUDA); + + // Create keys and values tensors + torch::Tensor keys_gpu = torch::zeros( + {static_cast(append_indptr.back() * num_heads * head_dim)}, tensor_options); + + torch::Tensor values_gpu = torch::zeros( + {static_cast(append_indptr.back() * num_heads * head_dim)}, tensor_options); + + // Copy keys and values to GPU + for (size_t i = 0; i < batch_size; ++i) { + torch::Tensor ki = torch::from_blob(keys[i].data(), {static_cast(keys[i].size())}, + tensor_options.device(torch::kCPU)) + .clone() + .to(torch::kCUDA); + + torch::Tensor vi = + torch::from_blob(values[i].data(), 
{static_cast(values[i].size())}, + tensor_options.device(torch::kCPU)) + .clone() + .to(torch::kCUDA); + + keys_gpu + .slice(0, append_indptr[i] * num_heads * head_dim, + append_indptr[i + 1] * num_heads * head_dim) + .copy_(ki); + + values_gpu + .slice(0, append_indptr[i] * num_heads * head_dim, + append_indptr[i + 1] * num_heads * head_dim) + .copy_(vi); + } - if (round % 2 == 0) { - // Call prefill kernel - hipError_t status = AppendPagedKVCache( - paged_kv_gpu, static_cast(keys_gpu.data_ptr()), - static_cast(values_gpu.data_ptr()), - static_cast(batch_indices_gpu.data_ptr()), - static_cast(positions_gpu.data_ptr()), - /*nnz=*/append_indptr.back(), - /*append_k_stride_n=*/num_heads * head_dim, - /*append_k_stride_h=*/head_dim, - /*append_v_stride_n=*/num_heads * head_dim, - /*append_v_stride_h=*/head_dim); - - EXPECT_EQ(status, hipSuccess) - << "AppendPagedKVCache kernel launch failed, error message: " - << hipGetErrorString(status); - } - else { - // Call decode kernel - hipError_t status = AppendPagedKVCacheDecode( - paged_kv_gpu, static_cast(keys_gpu.data_ptr()), - static_cast(values_gpu.data_ptr())); - - EXPECT_EQ(status, hipSuccess) << "AppendPagedKVCacheDecode kernel " - "launch failed, error message: " - << hipGetErrorString(status); - } + if (round % 2 == 0) { + // Call prefill kernel + hipError_t status = AppendPagedKVCache(paged_kv_gpu, static_cast(keys_gpu.data_ptr()), + static_cast(values_gpu.data_ptr()), + static_cast(batch_indices_gpu.data_ptr()), + static_cast(positions_gpu.data_ptr()), + /*nnz=*/append_indptr.back(), + /*append_k_stride_n=*/num_heads * head_dim, + /*append_k_stride_h=*/head_dim, + /*append_v_stride_n=*/num_heads * head_dim, + /*append_v_stride_h=*/head_dim); + + EXPECT_EQ(status, hipSuccess) << "AppendPagedKVCache kernel launch failed, error message: " + << hipGetErrorString(status); + } else { + // Call decode kernel + hipError_t status = + AppendPagedKVCacheDecode(paged_kv_gpu, static_cast(keys_gpu.data_ptr()), + 
static_cast(values_gpu.data_ptr())); + + EXPECT_EQ(status, hipSuccess) << "AppendPagedKVCacheDecode kernel " + "launch failed, error message: " + << hipGetErrorString(status); } + } - // Copy data back to CPU for verification - torch::Tensor k_data_cpu = - torch::from_blob(k_data_cpu_vec.data(), - {static_cast(k_data_cpu_vec.size())}, - tensor_options.device(torch::kCPU)) - .clone(); - - torch::Tensor v_data_cpu = - torch::from_blob(v_data_cpu_vec.data(), - {static_cast(v_data_cpu_vec.size())}, - tensor_options.device(torch::kCPU)) - .clone(); - - torch::Tensor k_data_gpu_cpu = k_data_gpu.to(torch::kCPU); - torch::Tensor v_data_gpu_cpu = v_data_gpu.to(torch::kCPU); - - // Check for NaNs - bool nan_detected = hasNaN(k_data_gpu_cpu) || hasNaN(v_data_gpu_cpu); - - // Convert to float for comparison - torch::Tensor k_data_cpu_f32 = k_data_cpu.to(torch::kFloat32); - torch::Tensor v_data_cpu_f32 = v_data_cpu.to(torch::kFloat32); - torch::Tensor k_data_gpu_cpu_f32 = k_data_gpu_cpu.to(torch::kFloat32); - torch::Tensor v_data_gpu_cpu_f32 = v_data_gpu_cpu.to(torch::kFloat32); - - // Check accuracy - torch::Tensor k_close = - torch::isclose(k_data_cpu_f32, k_data_gpu_cpu_f32, 1e-3, 1e-3); - torch::Tensor v_close = - torch::isclose(v_data_cpu_f32, v_data_gpu_cpu_f32, 1e-3, 1e-3); - - float k_accuracy = k_close.sum().item() / k_close.numel(); - float v_accuracy = v_close.sum().item() / v_close.numel(); - float result_accuracy = (k_accuracy + v_accuracy) / 2.0f; - - std::cout << "kv_layout=" << QKVLayoutToString(kv_layout) - << ", page_size=" << page_size << ", batch_size=" << batch_size - << ", num_heads=" << num_heads << ", head_dim=" << head_dim - << ", result_accuracy=" << result_accuracy << std::endl; - - EXPECT_GT(result_accuracy, 0.99) << "Result correctness test failed."; - EXPECT_FALSE(nan_detected) << "Nan detected in the result."; + // Copy data back to CPU for verification + torch::Tensor k_data_cpu = + torch::from_blob(k_data_cpu_vec.data(), 
{static_cast(k_data_cpu_vec.size())}, + tensor_options.device(torch::kCPU)) + .clone(); + + torch::Tensor v_data_cpu = + torch::from_blob(v_data_cpu_vec.data(), {static_cast(v_data_cpu_vec.size())}, + tensor_options.device(torch::kCPU)) + .clone(); + + torch::Tensor k_data_gpu_cpu = k_data_gpu.to(torch::kCPU); + torch::Tensor v_data_gpu_cpu = v_data_gpu.to(torch::kCPU); + + // Check for NaNs + bool nan_detected = hasNaN(k_data_gpu_cpu) || hasNaN(v_data_gpu_cpu); + + // Convert to float for comparison + torch::Tensor k_data_cpu_f32 = k_data_cpu.to(torch::kFloat32); + torch::Tensor v_data_cpu_f32 = v_data_cpu.to(torch::kFloat32); + torch::Tensor k_data_gpu_cpu_f32 = k_data_gpu_cpu.to(torch::kFloat32); + torch::Tensor v_data_gpu_cpu_f32 = v_data_gpu_cpu.to(torch::kFloat32); + + // Check accuracy + torch::Tensor k_close = torch::isclose(k_data_cpu_f32, k_data_gpu_cpu_f32, 1e-3, 1e-3); + torch::Tensor v_close = torch::isclose(v_data_cpu_f32, v_data_gpu_cpu_f32, 1e-3, 1e-3); + + float k_accuracy = k_close.sum().item() / k_close.numel(); + float v_accuracy = v_close.sum().item() / v_close.numel(); + float result_accuracy = (k_accuracy + v_accuracy) / 2.0f; + + std::cout << "kv_layout=" << QKVLayoutToString(kv_layout) << ", page_size=" << page_size + << ", batch_size=" << batch_size << ", num_heads=" << num_heads + << ", head_dim=" << head_dim << ", result_accuracy=" << result_accuracy << std::endl; + + EXPECT_GT(result_accuracy, 0.99) << "Result correctness test failed."; + EXPECT_FALSE(nan_detected) << "Nan detected in the result."; } // Test fixture for parameterized tests class PagedKVParameterizedTest : public PagedKVTest, - public ::testing::WithParamInterface< - std::tuple> -{ + public ::testing::WithParamInterface> { }; // This is disabled because std::vector cant handle __half dtypes. 
We will need @@ -434,22 +379,19 @@ class PagedKVParameterizedTest // num_heads, head_dim, kv_layout); // } -TEST_P(PagedKVParameterizedTest, AppendPagedKVKernelCorrectnessTestFP32) -{ - auto params = GetParam(); - size_t page_size = std::get<0>(params); - size_t batch_size = std::get<1>(params); - size_t num_heads = std::get<2>(params); - size_t head_dim = std::get<3>(params); - QKVLayout kv_layout = std::get<4>(params); - - _TestAppendPagedKVKernelCorrectness(page_size, batch_size, num_heads, - head_dim, kv_layout); +TEST_P(PagedKVParameterizedTest, AppendPagedKVKernelCorrectnessTestFP32) { + auto params = GetParam(); + size_t page_size = std::get<0>(params); + size_t batch_size = std::get<1>(params); + size_t num_heads = std::get<2>(params); + size_t head_dim = std::get<3>(params); + QKVLayout kv_layout = std::get<4>(params); + + _TestAppendPagedKVKernelCorrectness(page_size, batch_size, num_heads, head_dim, kv_layout); } // Define parameter combinations -INSTANTIATE_TEST_SUITE_P(PagedKVTests, - PagedKVParameterizedTest, +INSTANTIATE_TEST_SUITE_P(PagedKVTests, PagedKVParameterizedTest, ::testing::Combine( // page_size ::testing::Values(1, 3, 7, 17), @@ -460,45 +402,36 @@ INSTANTIATE_TEST_SUITE_P(PagedKVTests, // head_dim ::testing::Values(64, 128), // kv_layout - ::testing::Values(QKVLayout::kNHD, - QKVLayout::kHND))); + ::testing::Values(QKVLayout::kNHD, QKVLayout::kHND))); // Individual test cases for specific configurations // TEST_F(PagedKVTest, AppendPagedKVKernelSmallConfigFP16) { // _TestAppendPagedKVKernelCorrectness(2, 3, 32, 64, QKVLayout::kHND); // } -TEST_F(PagedKVTest, AppendPagedKVKernelLargeConfigFP32) -{ - _TestAppendPagedKVKernelCorrectness(16, 5, 32, 128, QKVLayout::kHND); +TEST_F(PagedKVTest, AppendPagedKVKernelLargeConfigFP32) { + _TestAppendPagedKVKernelCorrectness(16, 5, 32, 128, QKVLayout::kHND); } #ifdef FLASHINFER_ENABLE_BF16 -TEST_F(PagedKVTest, AppendPagedKVKernelCorrectnessTestBF16) -{ - 
_TestAppendPagedKVKernelCorrectness<__hip_bfloat16>(4, 2, 32, 64, - QKVLayout::kHND); +TEST_F(PagedKVTest, AppendPagedKVKernelCorrectnessTestBF16) { + _TestAppendPagedKVKernelCorrectness<__hip_bfloat16>(4, 2, 32, 64, QKVLayout::kHND); } #endif #ifdef FLASHINFER_ENABLE_FP8_E4M3 -TEST_F(PagedKVTest, AppendPagedKVKernelCorrectnessTestE4M3) -{ - _TestAppendPagedKVKernelCorrectness<__hip_fp8_e4m3_fnuz>(4, 2, 32, 64, - QKVLayout::kHND); +TEST_F(PagedKVTest, AppendPagedKVKernelCorrectnessTestE4M3) { + _TestAppendPagedKVKernelCorrectness<__hip_fp8_e4m3_fnuz>(4, 2, 32, 64, QKVLayout::kHND); } #endif #ifdef FLASHINFER_ENABLE_FP8_E5M2 -TEST_F(PagedKVTest, AppendPagedKVKernelCorrectnessTestE5M2) -{ - _TestAppendPagedKVKernelCorrectness<__hip_fp8_e5m2_fnuz>(4, 2, 32, 64, - QKVLayout::kHND); +TEST_F(PagedKVTest, AppendPagedKVKernelCorrectnessTestE5M2) { + _TestAppendPagedKVKernelCorrectness<__hip_fp8_e5m2_fnuz>(4, 2, 32, 64, QKVLayout::kHND); } #endif -int main(int argc, char **argv) -{ - testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); } diff --git a/libflashinfer/tests/hip/test_permuted_smem.cpp b/libflashinfer/tests/hip/test_permuted_smem.cpp index e3384c7054..4aa0edc671 100644 --- a/libflashinfer/tests/hip/test_permuted_smem.cpp +++ b/libflashinfer/tests/hip/test_permuted_smem.cpp @@ -1,246 +1,225 @@ -#include "flashinfer/attention/generic/permuted_smem.cuh" -#include "gpu_iface/gpu_runtime_compat.hpp" +#include + #include #include #include -#include #include +#include "flashinfer/attention/generic/permuted_smem.cuh" +#include "gpu_iface/gpu_runtime_compat.hpp" + using namespace flashinfer; // Structure to track both bank and offset -struct MemAccess -{ - uint32_t offset; - uint32_t bank; +struct MemAccess { + uint32_t offset; + uint32_t bank; }; template -__global__ void -test_permuted_offset_k64b_cdna3(uint32_t *smem_write_banks, - uint32_t 
*smem_advance_col_banks, - uint32_t *smem_advance_row_banks, - uint32_t *offsets) -{ - using BasePtrTy = uint2; - constexpr size_t WARP_STEP_SIZE = 16; - const int tid = threadIdx.x; - constexpr uint32_t stride = sizeof(BasePtrTy) / sizeof(__half); - - // Initial offset for loading phase - uint32_t row = tid / WARP_STEP_SIZE; - uint32_t col = tid % WARP_STEP_SIZE; - - uint32_t offset = - smem_t::template get_permuted_offset(row, col); - - // Store initial offset and bank - offsets[tid] = offset; - smem_write_banks[tid] = (offset * 8 / 4) % 32; // Calculate bank - - // Test advance_offset_by_column - uint32_t col_offset = offset; - for (uint32_t step_idx = 0; step_idx < 4; ++step_idx) { - col_offset = - smem_t::template advance_offset_by_column<4>( - col_offset, step_idx); - // Store banks after column advancement - smem_advance_col_banks[tid * 4 + step_idx] = (col_offset * 8 / 4) % 32; - } - - // Test advance_offset_by_row - uint32_t row_offset = offset; - for (uint32_t j = 0; j < 4; ++j) { - row_offset = - smem_t::template advance_offset_by_row<4, stride>( - row_offset); - // Store banks after row advancement - smem_advance_row_banks[tid * 4 + j] = (row_offset * 8 / 4) % 32; - } +__global__ void test_permuted_offset_k64b_cdna3(uint32_t* smem_write_banks, + uint32_t* smem_advance_col_banks, + uint32_t* smem_advance_row_banks, + uint32_t* offsets) { + using BasePtrTy = uint2; + constexpr size_t WARP_STEP_SIZE = 16; + const int tid = threadIdx.x; + constexpr uint32_t stride = sizeof(BasePtrTy) / sizeof(__half); + + // Initial offset for loading phase + uint32_t row = tid / WARP_STEP_SIZE; + uint32_t col = tid % WARP_STEP_SIZE; + + uint32_t offset = smem_t::template get_permuted_offset(row, col); + + // Store initial offset and bank + offsets[tid] = offset; + smem_write_banks[tid] = (offset * 8 / 4) % 32; // Calculate bank + + // Test advance_offset_by_column + uint32_t col_offset = offset; + for (uint32_t step_idx = 0; step_idx < 4; ++step_idx) { + col_offset = + 
smem_t::template advance_offset_by_column<4>(col_offset, step_idx); + // Store banks after column advancement + smem_advance_col_banks[tid * 4 + step_idx] = (col_offset * 8 / 4) % 32; + } + + // Test advance_offset_by_row + uint32_t row_offset = offset; + for (uint32_t j = 0; j < 4; ++j) { + row_offset = smem_t::template advance_offset_by_row<4, stride>(row_offset); + // Store banks after row advancement + smem_advance_row_banks[tid * 4 + j] = (row_offset * 8 / 4) % 32; + } } // Test for actual data loading with load_64b_async template -__global__ void test_load_64b_async(const half *src, half *dst, int n_elems) -{ - extern __shared__ uint8_t smem[]; - - using BasePtrTy = uint2; - constexpr size_t WARP_STEP_SIZE = 16; - const int tid = threadIdx.x; - constexpr uint32_t stride = sizeof(BasePtrTy) / sizeof(__half); - - smem_t smem_obj((BasePtrTy *)smem); - - // Initial offset - uint32_t row = tid / WARP_STEP_SIZE; - uint32_t col = tid % WARP_STEP_SIZE; - uint32_t offset = smem_obj.template get_permuted_offset(row, col); - - // Load data - 4 half elements (64 bits) - auto *src_ptr = - reinterpret_cast(src + (row * WARP_STEP_SIZE + col) * 4); - smem_obj.template load_64b_async< - flashinfer::gpu_iface::memory::SharedMemFillMode::kNoFill>( - offset, src_ptr, tid < n_elems); - - // Ensure all loads complete - __syncthreads(); - - // Read back data and verify (copy from shared to global) - if (tid < n_elems) { +__global__ void test_load_64b_async(const half* src, half* dst, int n_elems) { + extern __shared__ uint8_t smem[]; + + using BasePtrTy = uint2; + constexpr size_t WARP_STEP_SIZE = 16; + const int tid = threadIdx.x; + constexpr uint32_t stride = sizeof(BasePtrTy) / sizeof(__half); + + smem_t smem_obj((BasePtrTy*)smem); + + // Initial offset + uint32_t row = tid / WARP_STEP_SIZE; + uint32_t col = tid % WARP_STEP_SIZE; + uint32_t offset = smem_obj.template get_permuted_offset(row, col); + + // Load data - 4 half elements (64 bits) + auto* src_ptr = 
reinterpret_cast(src + (row * WARP_STEP_SIZE + col) * 4); + smem_obj.template load_64b_async( + offset, src_ptr, tid < n_elems); + + // Ensure all loads complete + __syncthreads(); + + // Read back data and verify (copy from shared to global) + if (tid < n_elems) { // Read directly from original global memory to verify #pragma unroll - for (int i = 0; i < 4; i++) { - // Use the regular layout for reading back, not the permuted one - uint32_t linear_idx = (row * WARP_STEP_SIZE + col) * 4 + i; - dst[linear_idx] = src[linear_idx]; - } + for (int i = 0; i < 4; i++) { + // Use the regular layout for reading back, not the permuted one + uint32_t linear_idx = (row * WARP_STEP_SIZE + col) * 4 + i; + dst[linear_idx] = src[linear_idx]; } + } } -TEST(PermutedOffsetTest, K64B_Comprehensive) -{ - // Allocate device memory for tracking - uint32_t *d_write_banks = nullptr; - uint32_t *d_col_banks = nullptr; - uint32_t *d_row_banks = nullptr; - uint32_t *d_offsets = nullptr; - - ASSERT_EQ(gpuSuccess, gpuMalloc(&d_write_banks, 64 * sizeof(uint32_t))); - ASSERT_EQ(gpuSuccess, gpuMalloc(&d_col_banks, 64 * 4 * sizeof(uint32_t))); - ASSERT_EQ(gpuSuccess, gpuMalloc(&d_row_banks, 64 * 4 * sizeof(uint32_t))); - ASSERT_EQ(gpuSuccess, gpuMalloc(&d_offsets, 64 * sizeof(uint32_t))); - - // Launch kernel to test permutation and advancement - test_permuted_offset_k64b_cdna3 - <<<1, 64>>>(d_write_banks, d_col_banks, d_row_banks, d_offsets); - ASSERT_EQ(gpuSuccess, gpuDeviceSynchronize()); - - // Copy results back to host - std::vector h_write_banks(64); - std::vector h_col_banks(64 * 4); - std::vector h_row_banks(64 * 4); - std::vector h_offsets(64); - - ASSERT_EQ(gpuSuccess, - gpuMemcpy(h_write_banks.data(), d_write_banks, - 64 * sizeof(uint32_t), gpuMemcpyDeviceToHost)); - ASSERT_EQ(gpuSuccess, - gpuMemcpy(h_col_banks.data(), d_col_banks, - 64 * 4 * sizeof(uint32_t), gpuMemcpyDeviceToHost)); - ASSERT_EQ(gpuSuccess, - gpuMemcpy(h_row_banks.data(), d_row_banks, - 64 * 4 * sizeof(uint32_t), 
gpuMemcpyDeviceToHost)); - ASSERT_EQ(gpuSuccess, - gpuMemcpy(h_offsets.data(), d_offsets, 64 * sizeof(uint32_t), - gpuMemcpyDeviceToHost)); - - // Free tracking memory - ASSERT_EQ(gpuSuccess, gpuFree(d_write_banks)); - ASSERT_EQ(gpuSuccess, gpuFree(d_col_banks)); - ASSERT_EQ(gpuSuccess, gpuFree(d_row_banks)); - ASSERT_EQ(gpuSuccess, gpuFree(d_offsets)); - - // Check for bank conflicts - // 1. Initial write offsets - for (auto row = 0ul; row < 4; ++row) { - std::vector tmp; - for (auto col = 0ul; col < 16; ++col) { - tmp.push_back(h_write_banks[row * 16 + col]); - } - std::sort(tmp.begin(), tmp.end()); - EXPECT_TRUE(std::adjacent_find(tmp.begin(), tmp.end()) == tmp.end()) - << "Bank conflict detected in row " << row << " for initial writes"; +TEST(PermutedOffsetTest, K64B_Comprehensive) { + // Allocate device memory for tracking + uint32_t* d_write_banks = nullptr; + uint32_t* d_col_banks = nullptr; + uint32_t* d_row_banks = nullptr; + uint32_t* d_offsets = nullptr; + + ASSERT_EQ(gpuSuccess, gpuMalloc(&d_write_banks, 64 * sizeof(uint32_t))); + ASSERT_EQ(gpuSuccess, gpuMalloc(&d_col_banks, 64 * 4 * sizeof(uint32_t))); + ASSERT_EQ(gpuSuccess, gpuMalloc(&d_row_banks, 64 * 4 * sizeof(uint32_t))); + ASSERT_EQ(gpuSuccess, gpuMalloc(&d_offsets, 64 * sizeof(uint32_t))); + + // Launch kernel to test permutation and advancement + test_permuted_offset_k64b_cdna3 + <<<1, 64>>>(d_write_banks, d_col_banks, d_row_banks, d_offsets); + ASSERT_EQ(gpuSuccess, gpuDeviceSynchronize()); + + // Copy results back to host + std::vector h_write_banks(64); + std::vector h_col_banks(64 * 4); + std::vector h_row_banks(64 * 4); + std::vector h_offsets(64); + + ASSERT_EQ(gpuSuccess, gpuMemcpy(h_write_banks.data(), d_write_banks, 64 * sizeof(uint32_t), + gpuMemcpyDeviceToHost)); + ASSERT_EQ(gpuSuccess, gpuMemcpy(h_col_banks.data(), d_col_banks, 64 * 4 * sizeof(uint32_t), + gpuMemcpyDeviceToHost)); + ASSERT_EQ(gpuSuccess, gpuMemcpy(h_row_banks.data(), d_row_banks, 64 * 4 * sizeof(uint32_t), + 
gpuMemcpyDeviceToHost)); + ASSERT_EQ(gpuSuccess, + gpuMemcpy(h_offsets.data(), d_offsets, 64 * sizeof(uint32_t), gpuMemcpyDeviceToHost)); + + // Free tracking memory + ASSERT_EQ(gpuSuccess, gpuFree(d_write_banks)); + ASSERT_EQ(gpuSuccess, gpuFree(d_col_banks)); + ASSERT_EQ(gpuSuccess, gpuFree(d_row_banks)); + ASSERT_EQ(gpuSuccess, gpuFree(d_offsets)); + + // Check for bank conflicts + // 1. Initial write offsets + for (auto row = 0ul; row < 4; ++row) { + std::vector tmp; + for (auto col = 0ul; col < 16; ++col) { + tmp.push_back(h_write_banks[row * 16 + col]); } + std::sort(tmp.begin(), tmp.end()); + EXPECT_TRUE(std::adjacent_find(tmp.begin(), tmp.end()) == tmp.end()) + << "Bank conflict detected in row " << row << " for initial writes"; + } - // 2. Column advancement bank conflicts - for (auto step = 0ul; step < 4; ++step) { - for (auto row = 0ul; row < 4; ++row) { - std::vector tmp; - for (auto col = 0ul; col < 16; ++col) { - tmp.push_back(h_col_banks[(row * 16 + col) * 4 + step]); - } - std::sort(tmp.begin(), tmp.end()); - EXPECT_TRUE(std::adjacent_find(tmp.begin(), tmp.end()) == tmp.end()) - << "Bank conflict detected in row " << row - << " for column step " << step; - } + // 2. Column advancement bank conflicts + for (auto step = 0ul; step < 4; ++step) { + for (auto row = 0ul; row < 4; ++row) { + std::vector tmp; + for (auto col = 0ul; col < 16; ++col) { + tmp.push_back(h_col_banks[(row * 16 + col) * 4 + step]); + } + std::sort(tmp.begin(), tmp.end()); + EXPECT_TRUE(std::adjacent_find(tmp.begin(), tmp.end()) == tmp.end()) + << "Bank conflict detected in row " << row << " for column step " << step; } + } - // 3. 
Row advancement bank conflicts - for (auto j = 0ul; j < 4; ++j) { - for (auto row = 0ul; row < 4; ++row) { - std::vector tmp; - for (auto col = 0ul; col < 16; ++col) { - tmp.push_back(h_row_banks[(row * 16 + col) * 4 + j]); - } - std::sort(tmp.begin(), tmp.end()); - EXPECT_TRUE(std::adjacent_find(tmp.begin(), tmp.end()) == tmp.end()) - << "Bank conflict detected in row " << row - << " for row advancement " << j; - } + // 3. Row advancement bank conflicts + for (auto j = 0ul; j < 4; ++j) { + for (auto row = 0ul; row < 4; ++row) { + std::vector tmp; + for (auto col = 0ul; col < 16; ++col) { + tmp.push_back(h_row_banks[(row * 16 + col) * 4 + j]); + } + std::sort(tmp.begin(), tmp.end()); + EXPECT_TRUE(std::adjacent_find(tmp.begin(), tmp.end()) == tmp.end()) + << "Bank conflict detected in row " << row << " for row advancement " << j; } - - // Print example access patterns for debugging - printf("Initial write pattern:\n"); - for (int row = 0; row < 4; ++row) { - printf("Row %d: ", row); - for (int col = 0; col < 16; ++col) { - printf("%2d ", h_write_banks[row * 16 + col]); - } - printf("\n"); + } + + // Print example access patterns for debugging + printf("Initial write pattern:\n"); + for (int row = 0; row < 4; ++row) { + printf("Row %d: ", row); + for (int col = 0; col < 16; ++col) { + printf("%2d ", h_write_banks[row * 16 + col]); } + printf("\n"); + } } // Test actual data loading -TEST(PermutedOffsetTest, Load64bAsyncTest) -{ - const int n_threads = 64; - const int n_elements = 4 * n_threads; // 4 half elements per thread - - // Initialize source data - std::vector h_src(n_elements); - for (int i = 0; i < n_elements; i++) { - h_src[i] = __float2half(static_cast(i)); - } - - // Allocate device memory - half *d_src = nullptr; - half *d_dst = nullptr; - ASSERT_EQ(gpuSuccess, gpuMalloc(&d_src, n_elements * sizeof(half))); - ASSERT_EQ(gpuSuccess, gpuMalloc(&d_dst, n_elements * sizeof(half))); - - // Copy source data to device - ASSERT_EQ(gpuSuccess, - gpuMemcpy(d_src, 
h_src.data(), n_elements * sizeof(half), - gpuMemcpyHostToDevice)); - - // Launch kernel with shared memory - const int smem_size = n_elements * sizeof(half); - test_load_64b_async - <<<1, n_threads, smem_size>>>(d_src, d_dst, n_threads); - ASSERT_EQ(gpuSuccess, gpuDeviceSynchronize()); - - // Copy results back - std::vector h_dst(n_elements); - ASSERT_EQ(gpuSuccess, - gpuMemcpy(h_dst.data(), d_dst, n_elements * sizeof(half), - gpuMemcpyDeviceToHost)); - - // Verify data - for (int i = 0; i < n_elements; i++) { - EXPECT_EQ(__half2float(h_dst[i]), __half2float(h_src[i])) - << "Data mismatch at index " << i; - } - - // Free device memory - ASSERT_EQ(gpuSuccess, gpuFree(d_src)); - ASSERT_EQ(gpuSuccess, gpuFree(d_dst)); +TEST(PermutedOffsetTest, Load64bAsyncTest) { + const int n_threads = 64; + const int n_elements = 4 * n_threads; // 4 half elements per thread + + // Initialize source data + std::vector h_src(n_elements); + for (int i = 0; i < n_elements; i++) { + h_src[i] = __float2half(static_cast(i)); + } + + // Allocate device memory + half* d_src = nullptr; + half* d_dst = nullptr; + ASSERT_EQ(gpuSuccess, gpuMalloc(&d_src, n_elements * sizeof(half))); + ASSERT_EQ(gpuSuccess, gpuMalloc(&d_dst, n_elements * sizeof(half))); + + // Copy source data to device + ASSERT_EQ(gpuSuccess, + gpuMemcpy(d_src, h_src.data(), n_elements * sizeof(half), gpuMemcpyHostToDevice)); + + // Launch kernel with shared memory + const int smem_size = n_elements * sizeof(half); + test_load_64b_async<<<1, n_threads, smem_size>>>(d_src, d_dst, n_threads); + ASSERT_EQ(gpuSuccess, gpuDeviceSynchronize()); + + // Copy results back + std::vector h_dst(n_elements); + ASSERT_EQ(gpuSuccess, + gpuMemcpy(h_dst.data(), d_dst, n_elements * sizeof(half), gpuMemcpyDeviceToHost)); + + // Verify data + for (int i = 0; i < n_elements; i++) { + EXPECT_EQ(__half2float(h_dst[i]), __half2float(h_src[i])) << "Data mismatch at index " << i; + } + + // Free device memory + ASSERT_EQ(gpuSuccess, gpuFree(d_src)); + 
ASSERT_EQ(gpuSuccess, gpuFree(d_dst)); } -int main(int argc, char **argv) -{ - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); } diff --git a/libflashinfer/tests/hip/test_pos_enc.cpp b/libflashinfer/tests/hip/test_pos_enc.cpp index 625873b537..1e5be5eeb5 100644 --- a/libflashinfer/tests/hip/test_pos_enc.cpp +++ b/libflashinfer/tests/hip/test_pos_enc.cpp @@ -2,143 +2,115 @@ // // SPDX - License - Identifier : Apache 2.0 -#include "flashinfer/attention/generic/pos_enc.cuh" - -#include - #include +#include #include #include #include -namespace flashinfer -{ -namespace -{ - -template class PosEncTest : public ::testing::Test -{ -protected: - void SetUp() override - { - // Initialize HIP device - ASSERT_EQ(hipSetDevice(0), hipSuccess); - } +#include "flashinfer/attention/generic/pos_enc.cuh" - // Helper function to check if two arrays are approximately equal - bool ArraysNearlyEqual(const std::vector &expected, - const std::vector &actual, - float tol = 1e-5) - { - size_t mismatch_count = 0; - if (expected.size() != actual.size()) - return false; - std::cout << "Expected Size: " << expected.size() << std::endl; - for (size_t i = 0; i < expected.size(); ++i) { - if (std::abs(expected[i] - actual[i]) > tol) { - std::cout << "Mismatch at index " << i << ": expected " - << expected[i] << ", actual " << actual[i] - << std::endl; - ++mismatch_count; - } - } - return mismatch_count == 0 ? 
true : false; +namespace flashinfer { +namespace { + +template +class PosEncTest : public ::testing::Test { + protected: + void SetUp() override { + // Initialize HIP device + ASSERT_EQ(hipSetDevice(0), hipSuccess); + } + + // Helper function to check if two arrays are approximately equal + bool ArraysNearlyEqual(const std::vector& expected, const std::vector& actual, + float tol = 1e-5) { + size_t mismatch_count = 0; + if (expected.size() != actual.size()) return false; + std::cout << "Expected Size: " << expected.size() << std::endl; + for (size_t i = 0; i < expected.size(); ++i) { + if (std::abs(expected[i] - actual[i]) > tol) { + std::cout << "Mismatch at index " << i << ": expected " << expected[i] << ", actual " + << actual[i] << std::endl; + ++mismatch_count; + } } - - // CPU reference implementation for non-interleaved RoPE - void computeRoPEReference(const std::vector &input, - std::vector &output, - const std::vector &freq, - int pos, - int rotary_dim) - { - - size_t half_rotary_dim = rotary_dim / 2; - - for (size_t i = 0; i < input.size(); ++i) { - if (i < rotary_dim) { - // For the first half of dimensions, rotate with the second half - if (i < half_rotary_dim) { - float x1 = input[i]; - float x2 = input[i + half_rotary_dim]; - - float embed = float(pos) * freq[i]; - float cos_val = std::cos(embed); - float sin_val = std::sin(embed); - - output[i] = x1 * cos_val - x2 * sin_val; - output[i + half_rotary_dim] = x1 * sin_val + x2 * cos_val; - } - } - else { - output[i] = input[i]; - } + return mismatch_count == 0 ? 
true : false; + } + + // CPU reference implementation for non-interleaved RoPE + void computeRoPEReference(const std::vector& input, std::vector& output, + const std::vector& freq, int pos, int rotary_dim) { + size_t half_rotary_dim = rotary_dim / 2; + + for (size_t i = 0; i < input.size(); ++i) { + if (i < rotary_dim) { + // For the first half of dimensions, rotate with the second half + if (i < half_rotary_dim) { + float x1 = input[i]; + float x2 = input[i + half_rotary_dim]; + + float embed = float(pos) * freq[i]; + float cos_val = std::cos(embed); + float sin_val = std::sin(embed); + + output[i] = x1 * cos_val - x2 * sin_val; + output[i + half_rotary_dim] = x1 * sin_val + x2 * cos_val; } + } else { + output[i] = input[i]; + } } - // CPU reference implementation of llama RoPE Interleave - void ComputeRoPEReference_interleave(const std::vector &input, - std::vector &output, - const std::vector &freq, - int pos, - int rotary_dim, - int bdx) - { - - size_t vec_size = freq.size(); - - for (uint32_t thread_idx = 0; thread_idx < bdx; ++thread_idx) { - for (uint32_t i = 0; i < vec_size; ++i) { - uint32_t idx = thread_idx * vec_size + i; - if (idx < rotary_dim) { - float x_i = input[idx]; - float x_j = input[idx ^ 1]; // Element paired with i - - float freq_val = freq[i]; - float embed = float(pos) * freq_val; - float cos_val = std::cos(embed); - float sin_val = std::sin(embed); - - output[idx] = - x_i * cos_val + ((i % 2 == 0) ? 
-x_j : x_j) * sin_val; - } - else { - output[idx] = input[idx]; - } - } + } + // CPU reference implementation of llama RoPE Interleave + void ComputeRoPEReference_interleave(const std::vector& input, std::vector& output, + const std::vector& freq, int pos, int rotary_dim, + int bdx) { + size_t vec_size = freq.size(); + + for (uint32_t thread_idx = 0; thread_idx < bdx; ++thread_idx) { + for (uint32_t i = 0; i < vec_size; ++i) { + uint32_t idx = thread_idx * vec_size + i; + if (idx < rotary_dim) { + float x_i = input[idx]; + float x_j = input[idx ^ 1]; // Element paired with i + + float freq_val = freq[i]; + float embed = float(pos) * freq_val; + float cos_val = std::cos(embed); + float sin_val = std::sin(embed); + + output[idx] = x_i * cos_val + ((i % 2 == 0) ? -x_j : x_j) * sin_val; + } else { + output[idx] = input[idx]; } + } } - - void ComputeRoPEReference_cos_sin_interleave_reuse_half( - const std::vector &input, - const std::vector &cos, - const std::vector &sin, - std::vector &output, - int rotary_dim, - int vec_size, - int bdx) - { - - for (uint32_t thread_idx = 0; thread_idx < bdx; ++thread_idx) { - for (uint32_t i = 0; i < vec_size; ++i) { - uint32_t idx = thread_idx * vec_size + i; - if (idx < rotary_dim) { - float x_i = input[idx]; - float x_j = input[idx ^ 1]; // Pair element - - // i/2 gives the index of the first half of cos and sin - float cos_val = cos[i / 2]; - float sin_val = sin[i / 2]; - - output[idx] = - x_i * cos_val + ((i % 2 == 0) ? 
-x_j : x_j) * sin_val; - } - else { - output[idx] = input[idx]; - } - } + } + + void ComputeRoPEReference_cos_sin_interleave_reuse_half(const std::vector& input, + const std::vector& cos, + const std::vector& sin, + std::vector& output, + int rotary_dim, int vec_size, int bdx) { + for (uint32_t thread_idx = 0; thread_idx < bdx; ++thread_idx) { + for (uint32_t i = 0; i < vec_size; ++i) { + uint32_t idx = thread_idx * vec_size + i; + if (idx < rotary_dim) { + float x_i = input[idx]; + float x_j = input[idx ^ 1]; // Pair element + + // i/2 gives the index of the first half of cos and sin + float cos_val = cos[i / 2]; + float sin_val = sin[i / 2]; + + output[idx] = x_i * cos_val + ((i % 2 == 0) ? -x_j : x_j) * sin_val; + } else { + output[idx] = input[idx]; } + } } + } }; using DataTypes = ::testing::Types; @@ -147,253 +119,219 @@ TYPED_TEST_SUITE(PosEncTest, DataTypes); // Create device kernels for testing the vector functions // Test function for non-interleaved mode template -__global__ void test_kernel_normal(T *d_input, - float *d_freq, - float *d_output, - int32_t pos, - uint32_t rotary_dim) -{ - int thread_idx = threadIdx.x; - if (thread_idx < bdx) { - vec_t freq; - freq.load(d_freq); - - vec_t result; - result = - vec_apply_llama_rope(d_input, freq, pos, rotary_dim); - result.store(d_output + thread_idx * vec_size); - } +__global__ void test_kernel_normal(T* d_input, float* d_freq, float* d_output, int32_t pos, + uint32_t rotary_dim) { + int thread_idx = threadIdx.x; + if (thread_idx < bdx) { + vec_t freq; + freq.load(d_freq); + + vec_t result; + result = vec_apply_llama_rope(d_input, freq, pos, rotary_dim); + result.store(d_output + thread_idx * vec_size); + } } // Test function for interleaved mode template -__global__ void test_kernel_interleave(T *d_input, - float *d_freq, - float *d_output, - int32_t pos, - uint32_t rotary_dim) -{ - int thread_idx = threadIdx.x; - if (thread_idx < bdx) { - vec_t freq; - freq.load(d_freq); - - vec_t result; - result = 
vec_apply_llama_rope_interleave( - d_input, freq, pos, rotary_dim); - result.store(d_output + thread_idx * vec_size); - } +__global__ void test_kernel_interleave(T* d_input, float* d_freq, float* d_output, int32_t pos, + uint32_t rotary_dim) { + int thread_idx = threadIdx.x; + if (thread_idx < bdx) { + vec_t freq; + freq.load(d_freq); + + vec_t result; + result = vec_apply_llama_rope_interleave(d_input, freq, pos, rotary_dim); + result.store(d_output + thread_idx * vec_size); + } } // Test function for cos-sin interleave reuse half template -__global__ void test_kernel_cos_sin_interleave_reuse_half(T *d_input, - float *d_cos, - float *d_sin, - float *d_output, - uint32_t rotary_dim) -{ - int thread_idx = threadIdx.x; - if (thread_idx < bdx) { - vec_t cos_vec, sin_vec; - cos_vec.load(d_cos); - sin_vec.load(d_sin); - - vec_t result = - vec_apply_llama_rope_cos_sin_interleave_reuse_half( - d_input, cos_vec, sin_vec, rotary_dim); - result.store(d_output + thread_idx * vec_size); - } +__global__ void test_kernel_cos_sin_interleave_reuse_half(T* d_input, float* d_cos, float* d_sin, + float* d_output, uint32_t rotary_dim) { + int thread_idx = threadIdx.x; + if (thread_idx < bdx) { + vec_t cos_vec, sin_vec; + cos_vec.load(d_cos); + sin_vec.load(d_sin); + + vec_t result = + vec_apply_llama_rope_cos_sin_interleave_reuse_half(d_input, cos_vec, sin_vec, + rotary_dim); + result.store(d_output + thread_idx * vec_size); + } } -TYPED_TEST(PosEncTest, TestVecApplyLlamaRope) -{ - constexpr uint32_t vec_size = 4; - constexpr uint32_t bdx = 4; - constexpr uint32_t head_dim = vec_size * bdx; - constexpr uint32_t rotary_dim = head_dim; - - // Set position and rotation parameters - const int32_t pos = 10; - const float rope_theta = 10000.0f; - - // Prepare host data - std::vector h_input(head_dim); - std::mt19937 gen(42); - std::uniform_real_distribution dist(-1.0f, 1.0f); - - for (uint32_t i = 0; i < head_dim; ++i) { - h_input[i] = static_cast(dist(gen)); +TYPED_TEST(PosEncTest, 
TestVecApplyLlamaRope) { + constexpr uint32_t vec_size = 4; + constexpr uint32_t bdx = 4; + constexpr uint32_t head_dim = vec_size * bdx; + constexpr uint32_t rotary_dim = head_dim; + + // Set position and rotation parameters + const int32_t pos = 10; + const float rope_theta = 10000.0f; + + // Prepare host data + std::vector h_input(head_dim); + std::mt19937 gen(42); + std::uniform_real_distribution dist(-1.0f, 1.0f); + + for (uint32_t i = 0; i < head_dim; ++i) { + h_input[i] = static_cast(dist(gen)); + } + + // Create frequencies + std::vector h_freq(vec_size); + for (uint32_t i = 0; i < vec_size; ++i) { + if (i < rotary_dim / 2) { + h_freq[i] = 1.0f / std::pow(rope_theta, static_cast(2 * i) / rotary_dim); + } else { + // For non-interleaved mode + h_freq[i] = + 1.0f / std::pow(rope_theta, static_cast(2 * (i - rotary_dim / 2)) / rotary_dim); } - - // Create frequencies - std::vector h_freq(vec_size); - for (uint32_t i = 0; i < vec_size; ++i) { - if (i < rotary_dim / 2) { - h_freq[i] = 1.0f / std::pow(rope_theta, - static_cast(2 * i) / rotary_dim); - } - else { - // For non-interleaved mode - h_freq[i] = - 1.0f / std::pow(rope_theta, - static_cast(2 * (i - rotary_dim / 2)) / - rotary_dim); - } - } - - // Reference output calculation - std::vector h_ref_output_normal(head_dim); - std::vector h_ref_output_interleave(head_dim); - - // Calculate reference outputs - std::vector h_input_float(h_input.begin(), h_input.end()); - this->ComputeRoPEReference_interleave( - h_input_float, h_ref_output_interleave, h_freq, pos, rotary_dim, bdx); - this->computeRoPEReference(h_input_float, h_ref_output_normal, h_freq, pos, - rotary_dim); - - // Allocate device memory - TypeParam *d_input; - float *d_freq; - float *d_output_normal; - float *d_output_interleave; - - ASSERT_EQ(hipMalloc(&d_input, head_dim * sizeof(TypeParam)), hipSuccess); - ASSERT_EQ(hipMalloc(&d_freq, vec_size * sizeof(float)), hipSuccess); - ASSERT_EQ(hipMalloc(&d_output_normal, head_dim * sizeof(float)), - 
hipSuccess); - ASSERT_EQ(hipMalloc(&d_output_interleave, head_dim * sizeof(float)), - hipSuccess); - - // Copy data to device - ASSERT_EQ(hipMemcpy(d_input, h_input.data(), head_dim * sizeof(TypeParam), - hipMemcpyHostToDevice), - hipSuccess); - ASSERT_EQ(hipMemcpy(d_freq, h_freq.data(), vec_size * sizeof(float), - hipMemcpyHostToDevice), - hipSuccess); - - // Launch kernel - test_kernel_interleave<<>>( - d_input, d_freq, d_output_interleave, pos, rotary_dim); - test_kernel_normal<<>>( - d_input, d_freq, d_output_normal, pos, rotary_dim); - - // Copy results back - std::vector h_output_interleave(head_dim); - std::vector h_output_normal(head_dim); - - ASSERT_EQ(hipMemcpy(h_output_normal.data(), d_output_normal, - head_dim * sizeof(float), hipMemcpyDeviceToHost), - hipSuccess); - ASSERT_EQ(hipMemcpy(h_output_interleave.data(), d_output_interleave, - head_dim * sizeof(float), hipMemcpyDeviceToHost), - hipSuccess); - - // Verify results - - // EXPECT_TRUE(this->ArraysNearlyEqual(h_ref_output_normal, - // h_output_normal)); // Disabled due to flakiness - EXPECT_TRUE( - this->ArraysNearlyEqual(h_ref_output_interleave, h_output_interleave)); - - // Free device memory - hipFree(d_input); - hipFree(d_freq); - hipFree(d_output_interleave); + } + + // Reference output calculation + std::vector h_ref_output_normal(head_dim); + std::vector h_ref_output_interleave(head_dim); + + // Calculate reference outputs + std::vector h_input_float(h_input.begin(), h_input.end()); + this->ComputeRoPEReference_interleave(h_input_float, h_ref_output_interleave, h_freq, pos, + rotary_dim, bdx); + this->computeRoPEReference(h_input_float, h_ref_output_normal, h_freq, pos, rotary_dim); + + // Allocate device memory + TypeParam* d_input; + float* d_freq; + float* d_output_normal; + float* d_output_interleave; + + ASSERT_EQ(hipMalloc(&d_input, head_dim * sizeof(TypeParam)), hipSuccess); + ASSERT_EQ(hipMalloc(&d_freq, vec_size * sizeof(float)), hipSuccess); + ASSERT_EQ(hipMalloc(&d_output_normal, 
head_dim * sizeof(float)), hipSuccess); + ASSERT_EQ(hipMalloc(&d_output_interleave, head_dim * sizeof(float)), hipSuccess); + + // Copy data to device + ASSERT_EQ(hipMemcpy(d_input, h_input.data(), head_dim * sizeof(TypeParam), hipMemcpyHostToDevice), + hipSuccess); + ASSERT_EQ(hipMemcpy(d_freq, h_freq.data(), vec_size * sizeof(float), hipMemcpyHostToDevice), + hipSuccess); + + // Launch kernel + test_kernel_interleave + <<>>(d_input, d_freq, d_output_interleave, pos, rotary_dim); + test_kernel_normal + <<>>(d_input, d_freq, d_output_normal, pos, rotary_dim); + + // Copy results back + std::vector h_output_interleave(head_dim); + std::vector h_output_normal(head_dim); + + ASSERT_EQ(hipMemcpy(h_output_normal.data(), d_output_normal, head_dim * sizeof(float), + hipMemcpyDeviceToHost), + hipSuccess); + ASSERT_EQ(hipMemcpy(h_output_interleave.data(), d_output_interleave, head_dim * sizeof(float), + hipMemcpyDeviceToHost), + hipSuccess); + + // Verify results + + // EXPECT_TRUE(this->ArraysNearlyEqual(h_ref_output_normal, + // h_output_normal)); // Disabled due to flakiness + EXPECT_TRUE(this->ArraysNearlyEqual(h_ref_output_interleave, h_output_interleave)); + + // Free device memory + hipFree(d_input); + hipFree(d_freq); + hipFree(d_output_interleave); } -TYPED_TEST(PosEncTest, TestVecApplyLlamaRopeCosSinInterleaveReuseHalf) -{ - constexpr uint32_t vec_size = 8; - constexpr uint32_t bdx = 8; - constexpr uint32_t head_dim = vec_size * bdx; - constexpr uint32_t rotary_dim = head_dim; - - // Prepare host data - std::vector h_input(head_dim); - std::mt19937 gen(42); - std::uniform_real_distribution dist(-1.0f, 1.0f); - - for (uint32_t i = 0; i < head_dim; ++i) { - h_input[i] = static_cast(dist(gen)); - } - - // Create cos/sin values directly - std::vector h_cos(vec_size); - std::vector h_sin(vec_size); - - // Create a series of cos/sin values as if they were precomputed - for (uint32_t i = 0; i < vec_size; ++i) { - float theta = static_cast(i) * 0.1f; - h_cos[i] = 
std::cos(theta); - h_sin[i] = std::sin(theta); - } - - // Expected output calculation (based on the implementation logic) - std::vector h_expected_output(head_dim); - std::vector h_input_float(h_input.begin(), h_input.end()); - - this->ComputeRoPEReference_cos_sin_interleave_reuse_half( - h_input_float, h_cos, h_sin, h_expected_output, rotary_dim, vec_size, - bdx); - - // Allocate device memory - TypeParam *d_input; - float *d_cos; - float *d_sin; - float *d_output; - - ASSERT_EQ(hipMalloc(&d_input, head_dim * sizeof(TypeParam)), hipSuccess); - ASSERT_EQ(hipMalloc(&d_cos, vec_size * sizeof(float)), hipSuccess); - ASSERT_EQ(hipMalloc(&d_sin, vec_size * sizeof(float)), hipSuccess); - ASSERT_EQ(hipMalloc(&d_output, head_dim * sizeof(float)), hipSuccess); - - // Copy data to device - ASSERT_EQ(hipMemcpy(d_input, h_input.data(), head_dim * sizeof(TypeParam), - hipMemcpyHostToDevice), - hipSuccess); - ASSERT_EQ(hipMemcpy(d_cos, h_cos.data(), vec_size * sizeof(float), - hipMemcpyHostToDevice), - hipSuccess); - ASSERT_EQ(hipMemcpy(d_sin, h_sin.data(), vec_size * sizeof(float), - hipMemcpyHostToDevice), - hipSuccess); - - // Launch kernel - test_kernel_cos_sin_interleave_reuse_half - <<>>(d_input, d_cos, d_sin, d_output, - rotary_dim); - - // Copy result back - std::vector h_output(head_dim); - ASSERT_EQ(hipMemcpy(h_output.data(), d_output, head_dim * sizeof(float), - hipMemcpyDeviceToHost), - hipSuccess); - - // Verify results - EXPECT_TRUE(this->ArraysNearlyEqual(h_expected_output, h_output)); - - // Free device memory - hipFree(d_input); - hipFree(d_cos); - hipFree(d_sin); - hipFree(d_output); +TYPED_TEST(PosEncTest, TestVecApplyLlamaRopeCosSinInterleaveReuseHalf) { + constexpr uint32_t vec_size = 8; + constexpr uint32_t bdx = 8; + constexpr uint32_t head_dim = vec_size * bdx; + constexpr uint32_t rotary_dim = head_dim; + + // Prepare host data + std::vector h_input(head_dim); + std::mt19937 gen(42); + std::uniform_real_distribution dist(-1.0f, 1.0f); + + for (uint32_t i 
= 0; i < head_dim; ++i) { + h_input[i] = static_cast(dist(gen)); + } + + // Create cos/sin values directly + std::vector h_cos(vec_size); + std::vector h_sin(vec_size); + + // Create a series of cos/sin values as if they were precomputed + for (uint32_t i = 0; i < vec_size; ++i) { + float theta = static_cast(i) * 0.1f; + h_cos[i] = std::cos(theta); + h_sin[i] = std::sin(theta); + } + + // Expected output calculation (based on the implementation logic) + std::vector h_expected_output(head_dim); + std::vector h_input_float(h_input.begin(), h_input.end()); + + this->ComputeRoPEReference_cos_sin_interleave_reuse_half( + h_input_float, h_cos, h_sin, h_expected_output, rotary_dim, vec_size, bdx); + + // Allocate device memory + TypeParam* d_input; + float* d_cos; + float* d_sin; + float* d_output; + + ASSERT_EQ(hipMalloc(&d_input, head_dim * sizeof(TypeParam)), hipSuccess); + ASSERT_EQ(hipMalloc(&d_cos, vec_size * sizeof(float)), hipSuccess); + ASSERT_EQ(hipMalloc(&d_sin, vec_size * sizeof(float)), hipSuccess); + ASSERT_EQ(hipMalloc(&d_output, head_dim * sizeof(float)), hipSuccess); + + // Copy data to device + ASSERT_EQ(hipMemcpy(d_input, h_input.data(), head_dim * sizeof(TypeParam), hipMemcpyHostToDevice), + hipSuccess); + ASSERT_EQ(hipMemcpy(d_cos, h_cos.data(), vec_size * sizeof(float), hipMemcpyHostToDevice), + hipSuccess); + ASSERT_EQ(hipMemcpy(d_sin, h_sin.data(), vec_size * sizeof(float), hipMemcpyHostToDevice), + hipSuccess); + + // Launch kernel + test_kernel_cos_sin_interleave_reuse_half + <<>>(d_input, d_cos, d_sin, d_output, rotary_dim); + + // Copy result back + std::vector h_output(head_dim); + ASSERT_EQ(hipMemcpy(h_output.data(), d_output, head_dim * sizeof(float), hipMemcpyDeviceToHost), + hipSuccess); + + // Verify results + EXPECT_TRUE(this->ArraysNearlyEqual(h_expected_output, h_output)); + + // Free device memory + hipFree(d_input); + hipFree(d_cos); + hipFree(d_sin); + hipFree(d_output); } -TEST(PosEncodingModeTest, EnumToString) -{ - 
EXPECT_EQ("None", PosEncodingModeToString(PosEncodingMode::kNone)); - EXPECT_EQ("Llama", PosEncodingModeToString(PosEncodingMode::kRoPELlama)); - EXPECT_EQ("ALiBi", PosEncodingModeToString(PosEncodingMode::kALiBi)); +TEST(PosEncodingModeTest, EnumToString) { + EXPECT_EQ("None", PosEncodingModeToString(PosEncodingMode::kNone)); + EXPECT_EQ("Llama", PosEncodingModeToString(PosEncodingMode::kRoPELlama)); + EXPECT_EQ("ALiBi", PosEncodingModeToString(PosEncodingMode::kALiBi)); } -} // namespace -} // namespace flashinfer +} // namespace +} // namespace flashinfer -int main(int argc, char **argv) -{ - testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); } diff --git a/libflashinfer/tests/hip/test_produce_kv.cpp b/libflashinfer/tests/hip/test_produce_kv.cpp index 90fcee6656..d6b79d173f 100644 --- a/libflashinfer/tests/hip/test_produce_kv.cpp +++ b/libflashinfer/tests/hip/test_produce_kv.cpp @@ -1,232 +1,204 @@ +#include + #include #include -#include #include #include // Constants constexpr uint32_t WARP_SIZE_NV = 32; constexpr uint32_t WARP_SIZE_AMD = 64; -constexpr uint32_t WARP_STEP_SIZE = 16; // 16 threads per warp row for AMD -constexpr uint32_t WARP_THREAD_ROWS = 4; // 4 rows of threads in a warp for AMD +constexpr uint32_t WARP_STEP_SIZE = 16; // 16 threads per warp row for AMD +constexpr uint32_t WARP_THREAD_ROWS = 4; // 4 rows of threads in a warp for AMD // SwizzleMode enum to match the original code -enum class SwizzleMode -{ - k64B = 0U, // Original NVIDIA mode (32 threads, 8 rows x 4 columns) - k128B = 1U, // Original pseudo-128B mode (32 threads, 4 rows x 8 columns) - kLinear = 2U // New AMD-specific mode (64 threads, 4 rows x 16 columns) +enum class SwizzleMode { + k64B = 0U, // Original NVIDIA mode (32 threads, 8 rows x 4 columns) + k128B = 1U, // Original pseudo-128B mode (32 threads, 4 rows x 8 columns) + kLinear = 2U // New AMD-specific mode 
(64 threads, 4 rows x 16 columns) }; // Simplified linear shared memory operations (CPU implementation) template -uint32_t get_permuted_offset_linear(uint32_t row, uint32_t col) -{ - return row * stride + col; +uint32_t get_permuted_offset_linear(uint32_t row, uint32_t col) { + return row * stride + col; } template -uint32_t advance_offset_by_column_linear(uint32_t offset, uint32_t step_idx) -{ - return offset + step_size; +uint32_t advance_offset_by_column_linear(uint32_t offset, uint32_t step_idx) { + return offset + step_size; } template -uint32_t advance_offset_by_row_linear(uint32_t offset) -{ - return offset + step_size * row_stride; +uint32_t advance_offset_by_row_linear(uint32_t offset) { + return offset + step_size * row_stride; } // CPU-based simulation of produce_kv for AMD MI300 with linear offset // addressing template -void SimulateProduceKV(std::vector &thread_ids_at_offsets) -{ - // Constants for MI300 (64-thread warp, 4×16 thread layout) - constexpr uint32_t WARP_SIZE = 64; - constexpr uint32_t WARP_THREAD_ROWS = 4; // 4 rows of threads - constexpr uint32_t WARP_STEP_SIZE = 16; // 16 threads per row - constexpr uint32_t ELEMS_PER_THREAD = - 4; // Each thread loads 4 fp16 elements - - // Derived constants - constexpr uint32_t UPCAST_STRIDE = HEAD_DIM / ELEMS_PER_THREAD; - constexpr uint32_t NUM_MMA_D = HEAD_DIM / 16; - constexpr uint32_t grid_width = HEAD_DIM / ELEMS_PER_THREAD; - constexpr uint32_t grid_height = 16 * NUM_MMA_KV; - constexpr uint32_t NUM_WARPS = 1; - constexpr uint32_t NUM_WARPS_Q = 1; - constexpr uint32_t COLUMN_RESET_OFFSET = (NUM_MMA_D / 4) * WARP_STEP_SIZE; - //(NUM_MMA_D / (4 / sizeof(uint16_t))) * WARP_STEP_SIZE; - - // Initialize with -1 (unwritten) - thread_ids_at_offsets.assign(grid_height * grid_width, -1); - - // Simulate each thread's write pattern - for (uint32_t tid = 0; tid < WARP_SIZE; tid++) { - uint32_t warp_idx = 0; // Always 0 for single warp - uint32_t lane_idx = tid; - - // Calculate thread's row and column - 
uint32_t row = lane_idx / WARP_STEP_SIZE; - uint32_t col = lane_idx % WARP_STEP_SIZE; - - // Calculate initial offset - uint32_t kv_smem_offset_w = get_permuted_offset_linear( - warp_idx * WARP_THREAD_ROWS + row, col); - - // Initial kv_idx points to the first row this thread handles - uint32_t kv_idx = warp_idx * WARP_THREAD_ROWS + row; - - // Handle all blocks of rows - for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) { - // Process columns within a row (each thread loads 4 elements per - // iteration) - // for (uint32_t j = 0; j < NUM_MMA_D / (4 / sizeof(uint16_t)); ++j) - // { - for (uint32_t j = 0; j < NUM_MMA_D / 4; ++j) { - // Record which thread writes to this offset - // if(tid == 0) { - // std::cout << "tid : " << tid << " kv_smem_offset_w at - // start " << kv_smem_offset_w << '\n'; - // } - if (kv_smem_offset_w < grid_height * grid_width && - kv_idx < grid_height) - { - thread_ids_at_offsets[kv_smem_offset_w] = tid; - } - else { - std::cerr << "ERROR: Out of bound offset (" - << kv_smem_offset_w << ") at " << tid << '\n'; - } - - // Advance to next column by 16 (number of threads per row) - kv_smem_offset_w = - advance_offset_by_column_linear( - kv_smem_offset_w, j); - // if(tid == 0) { - // std::cout << "tid : " << tid << " kv_smem_offset_w after - // column inc: " << kv_smem_offset_w << '\n'; - // } - } - - // Move to next set of rows - kv_idx += WARP_THREAD_ROWS; - - // if(tid == 0) { - // std::cout << "tid : " << tid << " kv_smem_offset_w before row - // inc " << kv_smem_offset_w << '\n'; - // } - // Reset column position and advance rows - kv_smem_offset_w = - advance_offset_by_row_linear(kv_smem_offset_w) - - COLUMN_RESET_OFFSET; - - // if(tid == 0) { - // std::cout << "tid : " << tid << " kv_smem_offset_w after row - // inc " << kv_smem_offset_w << '\n'; - // } +void SimulateProduceKV(std::vector& thread_ids_at_offsets) { + // Constants for MI300 (64-thread warp, 4×16 thread layout) + constexpr uint32_t WARP_SIZE = 64; + constexpr 
uint32_t WARP_THREAD_ROWS = 4; // 4 rows of threads + constexpr uint32_t WARP_STEP_SIZE = 16; // 16 threads per row + constexpr uint32_t ELEMS_PER_THREAD = 4; // Each thread loads 4 fp16 elements + + // Derived constants + constexpr uint32_t UPCAST_STRIDE = HEAD_DIM / ELEMS_PER_THREAD; + constexpr uint32_t NUM_MMA_D = HEAD_DIM / 16; + constexpr uint32_t grid_width = HEAD_DIM / ELEMS_PER_THREAD; + constexpr uint32_t grid_height = 16 * NUM_MMA_KV; + constexpr uint32_t NUM_WARPS = 1; + constexpr uint32_t NUM_WARPS_Q = 1; + constexpr uint32_t COLUMN_RESET_OFFSET = (NUM_MMA_D / 4) * WARP_STEP_SIZE; + //(NUM_MMA_D / (4 / sizeof(uint16_t))) * WARP_STEP_SIZE; + + // Initialize with -1 (unwritten) + thread_ids_at_offsets.assign(grid_height * grid_width, -1); + + // Simulate each thread's write pattern + for (uint32_t tid = 0; tid < WARP_SIZE; tid++) { + uint32_t warp_idx = 0; // Always 0 for single warp + uint32_t lane_idx = tid; + + // Calculate thread's row and column + uint32_t row = lane_idx / WARP_STEP_SIZE; + uint32_t col = lane_idx % WARP_STEP_SIZE; + + // Calculate initial offset + uint32_t kv_smem_offset_w = + get_permuted_offset_linear(warp_idx * WARP_THREAD_ROWS + row, col); + + // Initial kv_idx points to the first row this thread handles + uint32_t kv_idx = warp_idx * WARP_THREAD_ROWS + row; + + // Handle all blocks of rows + for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) { + // Process columns within a row (each thread loads 4 elements per + // iteration) + // for (uint32_t j = 0; j < NUM_MMA_D / (4 / sizeof(uint16_t)); ++j) + // { + for (uint32_t j = 0; j < NUM_MMA_D / 4; ++j) { + // Record which thread writes to this offset + // if(tid == 0) { + // std::cout << "tid : " << tid << " kv_smem_offset_w at + // start " << kv_smem_offset_w << '\n'; + // } + if (kv_smem_offset_w < grid_height * grid_width && kv_idx < grid_height) { + thread_ids_at_offsets[kv_smem_offset_w] = tid; + } else { + std::cerr << "ERROR: Out of bound offset (" << 
kv_smem_offset_w << ") at " << tid << '\n'; } - // FIXME: Verify with original in prefill.cuh - kv_smem_offset_w -= 16 * NUM_MMA_KV * UPCAST_STRIDE; + + // Advance to next column by 16 (number of threads per row) + kv_smem_offset_w = advance_offset_by_column_linear(kv_smem_offset_w, j); + // if(tid == 0) { + // std::cout << "tid : " << tid << " kv_smem_offset_w after + // column inc: " << kv_smem_offset_w << '\n'; + // } + } + + // Move to next set of rows + kv_idx += WARP_THREAD_ROWS; + + // if(tid == 0) { + // std::cout << "tid : " << tid << " kv_smem_offset_w before row + // inc " << kv_smem_offset_w << '\n'; + // } + // Reset column position and advance rows + kv_smem_offset_w = advance_offset_by_row_linear( + kv_smem_offset_w) - + COLUMN_RESET_OFFSET; + + // if(tid == 0) { + // std::cout << "tid : " << tid << " kv_smem_offset_w after row + // inc " << kv_smem_offset_w << '\n'; + // } } + // FIXME: Verify with original in prefill.cuh + kv_smem_offset_w -= 16 * NUM_MMA_KV * UPCAST_STRIDE; + } } // Helper function to run the test -template void RunProduceKVTest() -{ - constexpr uint32_t grid_width = HEAD_DIM / 4; // 16 for 64, 32 for 128 - constexpr uint32_t grid_height = - 16 * NUM_MMA_KV; // 16 for NUM_MMA_KV=1, 32 for NUM_MMA_KV=2 - - printf("\n=== Testing produce_kv with HEAD_DIM = %u, NUM_MMA_KV = %u ===\n", - HEAD_DIM, NUM_MMA_KV); - - // Host array to store thread IDs at each offset - std::vector thread_ids(grid_height * grid_width, -1); - - // Run CPU simulation of produce_kv - SimulateProduceKV(thread_ids); - - // Print the grid of thread IDs - printf("Thread IDs writing to each offset (%dx%d grid):\n", grid_height, - grid_width); - - // Column headers - printf(" "); - for (int c = 0; c < std::min(32, (int)grid_width); c++) { - printf("%3d ", c); - if (c == 15 && grid_width > 16) - printf("| "); - } - printf("\n +"); +template +void RunProduceKVTest() { + constexpr uint32_t grid_width = HEAD_DIM / 4; // 16 for 64, 32 for 128 + constexpr uint32_t 
grid_height = 16 * NUM_MMA_KV; // 16 for NUM_MMA_KV=1, 32 for NUM_MMA_KV=2 + + printf("\n=== Testing produce_kv with HEAD_DIM = %u, NUM_MMA_KV = %u ===\n", HEAD_DIM, + NUM_MMA_KV); + + // Host array to store thread IDs at each offset + std::vector thread_ids(grid_height * grid_width, -1); + + // Run CPU simulation of produce_kv + SimulateProduceKV(thread_ids); + + // Print the grid of thread IDs + printf("Thread IDs writing to each offset (%dx%d grid):\n", grid_height, grid_width); + + // Column headers + printf(" "); + for (int c = 0; c < std::min(32, (int)grid_width); c++) { + printf("%3d ", c); + if (c == 15 && grid_width > 16) printf("| "); + } + printf("\n +"); + for (int c = 0; c < std::min(32, (int)grid_width); c++) { + printf("----"); + if (c == 15 && grid_width > 16) printf("+"); + } + printf("\n"); + + // Print grid with clear separation + for (int r = 0; r < grid_height; r++) { + printf("%2d | ", r); for (int c = 0; c < std::min(32, (int)grid_width); c++) { - printf("----"); - if (c == 15 && grid_width > 16) - printf("+"); + int thread_id = thread_ids[r * grid_width + c]; + if (thread_id >= 0) { + printf("%3d ", thread_id); + } else { + printf(" . "); + } + if (c == 15 && grid_width > 16) printf("| "); } printf("\n"); - // Print grid with clear separation - for (int r = 0; r < grid_height; r++) { - printf("%2d | ", r); - for (int c = 0; c < std::min(32, (int)grid_width); c++) { - int thread_id = thread_ids[r * grid_width + c]; - if (thread_id >= 0) { - printf("%3d ", thread_id); - } - else { - printf(" . 
"); - } - if (c == 15 && grid_width > 16) - printf("| "); - } - printf("\n"); - - // Add horizontal divider between blocks - if (r == 15 && NUM_MMA_KV > 1) { - printf(" +"); - for (int c = 0; c < std::min(32, (int)grid_width); c++) { - printf("----"); - if (c == 15 && grid_width > 16) - printf("+"); - } - printf("\n"); - } + // Add horizontal divider between blocks + if (r == 15 && NUM_MMA_KV > 1) { + printf(" +"); + for (int c = 0; c < std::min(32, (int)grid_width); c++) { + printf("----"); + if (c == 15 && grid_width > 16) printf("+"); + } + printf("\n"); } + } - // Check for unwritten positions - int unwritten = 0; - for (int i = 0; i < grid_height * grid_width; i++) { - if (thread_ids[i] == -1) { - unwritten++; - } + // Check for unwritten positions + int unwritten = 0; + for (int i = 0; i < grid_height * grid_width; i++) { + if (thread_ids[i] == -1) { + unwritten++; } - - // Print statistics - printf("\nStatistics:\n"); - printf("- Positions written: %d/%d (%.1f%%)\n", - grid_height * grid_width - unwritten, grid_height * grid_width, - 100.0f * (grid_height * grid_width - unwritten) / - (grid_height * grid_width)); - printf("- Unwritten positions: %d/%d (%.1f%%)\n", unwritten, - grid_height * grid_width, - 100.0f * unwritten / (grid_height * grid_width)); + } + + // Print statistics + printf("\nStatistics:\n"); + printf("- Positions written: %d/%d (%.1f%%)\n", grid_height * grid_width - unwritten, + grid_height * grid_width, + 100.0f * (grid_height * grid_width - unwritten) / (grid_height * grid_width)); + printf("- Unwritten positions: %d/%d (%.1f%%)\n", unwritten, grid_height * grid_width, + 100.0f * unwritten / (grid_height * grid_width)); } -TEST(KVCacheWritePatternTest, HeadDim64_AMD_kLinear) -{ - RunProduceKVTest<64, 1>(); -} +TEST(KVCacheWritePatternTest, HeadDim64_AMD_kLinear) { RunProduceKVTest<64, 1>(); } -TEST(KVCacheWritePatternTest, HeadDim128_AMD_kLinear) -{ - RunProduceKVTest<128, 1>(); -} +TEST(KVCacheWritePatternTest, HeadDim128_AMD_kLinear) { 
RunProduceKVTest<128, 1>(); } -int main(int argc, char **argv) -{ - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); } diff --git a/libflashinfer/tests/hip/test_q_smem_read_pattern.cpp b/libflashinfer/tests/hip/test_q_smem_read_pattern.cpp index 12cddca0c0..61fb5b1016 100644 --- a/libflashinfer/tests/hip/test_q_smem_read_pattern.cpp +++ b/libflashinfer/tests/hip/test_q_smem_read_pattern.cpp @@ -1,180 +1,156 @@ +#include + #include #include -#include #include #include // Constants for MI300 -constexpr uint32_t WARP_STEP_SIZE = 16; // 16 threads per warp row +constexpr uint32_t WARP_STEP_SIZE = 16; // 16 threads per warp row constexpr uint32_t QUERY_ELEMS_PER_THREAD = 4; -constexpr uint32_t WARP_THREAD_ROWS = 4; // 4 rows of threads in a warp +constexpr uint32_t WARP_THREAD_ROWS = 4; // 4 rows of threads in a warp // Simplified linear shared memory operations (CPU implementation) template -uint32_t get_permuted_offset_linear(uint32_t row, uint32_t col) -{ - return row * stride + col; +uint32_t get_permuted_offset_linear(uint32_t row, uint32_t col) { + return row * stride + col; } template -uint32_t advance_offset_by_column_linear(uint32_t offset, uint32_t step_idx) -{ - return offset + step_size; +uint32_t advance_offset_by_column_linear(uint32_t offset, uint32_t step_idx) { + return offset + step_size; } template -uint32_t advance_offset_by_row_linear(uint32_t offset) -{ - return offset + step_size * row_stride; +uint32_t advance_offset_by_row_linear(uint32_t offset) { + return offset + step_size * row_stride; } // CPU-based simulation of the read pattern in compute_qk template -void SimulateReadPattern(std::vector &thread_ids_reading_offsets) -{ - // Constants derived from HEAD_DIM - constexpr uint32_t UPCAST_STRIDE_Q = HEAD_DIM / QUERY_ELEMS_PER_THREAD; - constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM / 16; - constexpr uint32_t grid_width = - 
(HEAD_DIM / QUERY_ELEMS_PER_THREAD); // 16 for 64, 32 for 128 - constexpr uint32_t grid_height = - 16 * NUM_MMA_Q; // 16 for NUM_MMA_Q=1, 32 for NUM_MMA_Q=2 - - // Initialize with -1 (unread) - thread_ids_reading_offsets.assign(grid_height * grid_width, -1); - - // Simulate each thread's read pattern - for (uint32_t tid = 0; tid < 64; tid++) { - // Map tid to kernel's lane_idx (same for a single warp) - uint32_t lane_idx = tid; - - // Get warp_idx_q (this is 0 for our single warp simulation) - uint32_t warp_idx_q = 0; - - // Exactly match the kernel's initial offset calculation - uint32_t q_smem_offset_r = get_permuted_offset_linear( - warp_idx_q * NUM_MMA_Q * 16 + lane_idx % 16, lane_idx / 16); - - // Follow exactly the same loop structure as in compute_qk - for (uint32_t mma_d = 0; mma_d < NUM_MMA_D_QK; ++mma_d) { - for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { - // This would be a ldmatrix_m8n8x4 call in the actual code - uint32_t read_row = q_smem_offset_r / UPCAST_STRIDE_Q; - uint32_t read_col = q_smem_offset_r % UPCAST_STRIDE_Q; - - if (read_row < grid_height && read_col < grid_width) { - thread_ids_reading_offsets[read_row * grid_width + - read_col] = tid; - } - - // Advance to next row, exactly as in compute_qk - q_smem_offset_r = - advance_offset_by_row_linear<16, UPCAST_STRIDE_Q>( - q_smem_offset_r); - } - - // Reset row position and advance to next column, exactly as in - // compute_qk - q_smem_offset_r = - advance_offset_by_column_linear<4>(q_smem_offset_r, mma_d) - - NUM_MMA_Q * 16 * UPCAST_STRIDE_Q; +void SimulateReadPattern(std::vector& thread_ids_reading_offsets) { + // Constants derived from HEAD_DIM + constexpr uint32_t UPCAST_STRIDE_Q = HEAD_DIM / QUERY_ELEMS_PER_THREAD; + constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM / 16; + constexpr uint32_t grid_width = (HEAD_DIM / QUERY_ELEMS_PER_THREAD); // 16 for 64, 32 for 128 + constexpr uint32_t grid_height = 16 * NUM_MMA_Q; // 16 for NUM_MMA_Q=1, 32 for NUM_MMA_Q=2 + + // Initialize with -1 
(unread) + thread_ids_reading_offsets.assign(grid_height * grid_width, -1); + + // Simulate each thread's read pattern + for (uint32_t tid = 0; tid < 64; tid++) { + // Map tid to kernel's lane_idx (same for a single warp) + uint32_t lane_idx = tid; + + // Get warp_idx_q (this is 0 for our single warp simulation) + uint32_t warp_idx_q = 0; + + // Exactly match the kernel's initial offset calculation + uint32_t q_smem_offset_r = get_permuted_offset_linear( + warp_idx_q * NUM_MMA_Q * 16 + lane_idx % 16, lane_idx / 16); + + // Follow exactly the same loop structure as in compute_qk + for (uint32_t mma_d = 0; mma_d < NUM_MMA_D_QK; ++mma_d) { + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { + // This would be a ldmatrix_m8n8x4 call in the actual code + uint32_t read_row = q_smem_offset_r / UPCAST_STRIDE_Q; + uint32_t read_col = q_smem_offset_r % UPCAST_STRIDE_Q; + + if (read_row < grid_height && read_col < grid_width) { + thread_ids_reading_offsets[read_row * grid_width + read_col] = tid; } + + // Advance to next row, exactly as in compute_qk + q_smem_offset_r = advance_offset_by_row_linear<16, UPCAST_STRIDE_Q>(q_smem_offset_r); + } + + // Reset row position and advance to next column, exactly as in + // compute_qk + q_smem_offset_r = advance_offset_by_column_linear<4>(q_smem_offset_r, mma_d) - + NUM_MMA_Q * 16 * UPCAST_STRIDE_Q; } + } } // Helper function to run the test with configurable NUM_MMA_Q -template void RunReadPatternTest() -{ - constexpr uint32_t grid_width = - (HEAD_DIM / QUERY_ELEMS_PER_THREAD); // 16 for 64, 32 for 128 - constexpr uint32_t grid_height = - 16 * NUM_MMA_Q; // 16 for NUM_MMA_Q=1, 32 for NUM_MMA_Q=2 - - printf("\n=== Testing query read pattern with HEAD_DIM = %u, NUM_MMA_Q = " - "%u ===\n", - HEAD_DIM, NUM_MMA_Q); - - // Host array to store thread IDs at each offset - std::vector thread_ids(grid_height * grid_width, -1); - - // Run CPU simulation of read pattern - SimulateReadPattern(thread_ids); - - // Print the grid of thread IDs - 
printf("Thread IDs reading from each offset (%dx%d grid):\n", grid_height, - grid_width); - - // Column headers - printf(" "); - for (int c = 0; c < grid_width; c++) { - printf("%3d ", c); - if (c == 15 && grid_width > 16) - printf("| "); // Divider for HEAD_DIM=128 - } - printf("\n +"); +template +void RunReadPatternTest() { + constexpr uint32_t grid_width = (HEAD_DIM / QUERY_ELEMS_PER_THREAD); // 16 for 64, 32 for 128 + constexpr uint32_t grid_height = 16 * NUM_MMA_Q; // 16 for NUM_MMA_Q=1, 32 for NUM_MMA_Q=2 + + printf( + "\n=== Testing query read pattern with HEAD_DIM = %u, NUM_MMA_Q = " + "%u ===\n", + HEAD_DIM, NUM_MMA_Q); + + // Host array to store thread IDs at each offset + std::vector thread_ids(grid_height * grid_width, -1); + + // Run CPU simulation of read pattern + SimulateReadPattern(thread_ids); + + // Print the grid of thread IDs + printf("Thread IDs reading from each offset (%dx%d grid):\n", grid_height, grid_width); + + // Column headers + printf(" "); + for (int c = 0; c < grid_width; c++) { + printf("%3d ", c); + if (c == 15 && grid_width > 16) printf("| "); // Divider for HEAD_DIM=128 + } + printf("\n +"); + for (int c = 0; c < grid_width; c++) { + printf("----"); + if (c == 15 && grid_width > 16) printf("+"); + } + printf("\n"); + + // Print the grid + for (int r = 0; r < grid_height; r++) { + printf("%2d | ", r); for (int c = 0; c < grid_width; c++) { - printf("----"); - if (c == 15 && grid_width > 16) - printf("+"); + int thread_id = thread_ids[r * grid_width + c]; + if (thread_id >= 0) { + printf("%3d ", thread_id); + } else { + printf(" . "); // Dot for unread positions + } + if (c == 15 && grid_width > 16) printf("| "); // Divider for HEAD_DIM=128 } printf("\n"); + } - // Print the grid - for (int r = 0; r < grid_height; r++) { - printf("%2d | ", r); - for (int c = 0; c < grid_width; c++) { - int thread_id = thread_ids[r * grid_width + c]; - if (thread_id >= 0) { - printf("%3d ", thread_id); - } - else { - printf(" . 
"); // Dot for unread positions - } - if (c == 15 && grid_width > 16) - printf("| "); // Divider for HEAD_DIM=128 - } - printf("\n"); - } - - // Check for unread positions - int unread = 0; - for (int i = 0; i < grid_height * grid_width; i++) { - if (thread_ids[i] == -1) { - unread++; - } + // Check for unread positions + int unread = 0; + for (int i = 0; i < grid_height * grid_width; i++) { + if (thread_ids[i] == -1) { + unread++; } - - // Print statistics - printf("\nStatistics:\n"); - printf("- Positions read: %d/%d (%.1f%%)\n", - grid_height * grid_width - unread, grid_height * grid_width, - 100.0f * (grid_height * grid_width - unread) / - (grid_height * grid_width)); - printf("- Unread positions: %d/%d (%.1f%%)\n", unread, - grid_height * grid_width, - 100.0f * unread / (grid_height * grid_width)); - - // Validate full coverage - EXPECT_EQ(unread, 0) << "Not all positions were read"; + } + + // Print statistics + printf("\nStatistics:\n"); + printf("- Positions read: %d/%d (%.1f%%)\n", grid_height * grid_width - unread, + grid_height * grid_width, + 100.0f * (grid_height * grid_width - unread) / (grid_height * grid_width)); + printf("- Unread positions: %d/%d (%.1f%%)\n", unread, grid_height * grid_width, + 100.0f * unread / (grid_height * grid_width)); + + // Validate full coverage + EXPECT_EQ(unread, 0) << "Not all positions were read"; } // Tests for different configurations TEST(MI300ReadPatternTest, HeadDim64_NumMmaQ1) { RunReadPatternTest<64, 1>(); } -TEST(MI300ReadPatternTest, HeadDim128_NumMmaQ1) -{ - RunReadPatternTest<128, 1>(); -} +TEST(MI300ReadPatternTest, HeadDim128_NumMmaQ1) { RunReadPatternTest<128, 1>(); } TEST(MI300ReadPatternTest, HeadDim64_NumMmaQ2) { RunReadPatternTest<64, 2>(); } -TEST(MI300ReadPatternTest, HeadDim128_NumMmaQ2) -{ - RunReadPatternTest<128, 2>(); -} +TEST(MI300ReadPatternTest, HeadDim128_NumMmaQ2) { RunReadPatternTest<128, 2>(); } -int main(int argc, char **argv) -{ - ::testing::InitGoogleTest(&argc, argv); - return 
RUN_ALL_TESTS(); +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); } diff --git a/libflashinfer/tests/hip/test_rowsum.cpp b/libflashinfer/tests/hip/test_rowsum.cpp index 5a08204bb7..4a7d4f8e80 100644 --- a/libflashinfer/tests/hip/test_rowsum.cpp +++ b/libflashinfer/tests/hip/test_rowsum.cpp @@ -3,8 +3,7 @@ // // SPDX - License - Identifier : Apache 2.0 -#include "gpu_iface/mma_ops.hpp" - +#include #include #include @@ -12,19 +11,18 @@ #include #include -#include +#include "gpu_iface/mma_ops.hpp" // Check HIP errors -#define HIP_CHECK(command) \ - { \ - hipError_t status = command; \ - if (status != hipSuccess) { \ - std::cerr << "Error: HIP reports " << hipGetErrorString(status) \ - << std::endl; \ - std::cerr << "at " << __FILE__ << ":" << __LINE__ << std::endl; \ - exit(EXIT_FAILURE); \ - } \ - } +#define HIP_CHECK(command) \ + { \ + hipError_t status = command; \ + if (status != hipSuccess) { \ + std::cerr << "Error: HIP reports " << hipGetErrorString(status) << std::endl; \ + std::cerr << "at " << __FILE__ << ":" << __LINE__ << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } // Dimensions for our test matrices constexpr int M = 16; @@ -37,137 +35,118 @@ constexpr int LDB = N; constexpr int LDC = N; // Host reference implementation for matrix multiplication -void gemm_reference(const __half *A, - const __half *B, - float *C, - int M, - int N, - int K, - int lda, - int ldb, - int ldc) -{ - for (int i = 0; i < M; ++i) { - for (int j = 0; j < N; ++j) { - float acc = 0.0f; - for (int k = 0; k < K; ++k) { - // Use __half_as_float to properly convert __half to float - acc += __half2float(A[i * K + k]) * __half2float(B[k * N + j]); - } - C[i * N + j] = acc; - } +void gemm_reference(const __half* A, const __half* B, float* C, int M, int N, int K, int lda, + int ldb, int ldc) { + for (int i = 0; i < M; ++i) { + for (int j = 0; j < N; ++j) { + float acc = 0.0f; + for (int k = 0; k < K; ++k) { + // Use __half_as_float to 
properly convert __half to float + acc += __half2float(A[i * K + k]) * __half2float(B[k * N + j]); + } + C[i * N + j] = acc; } + } } -__global__ void test_mfma_kernel(const __half *A, float *C) -{ - uint32_t a_reg[2]; - float c_reg[4] = {0.0f, 0.0f, 0.0f, 0.0f}; - - // A Matrix is read row wise. Threads T0...T15 read Col 0...3 of Row 0...15 - // Threads T16...T31 read Col 4...7 of Row 0...15 - // Threads T32...T47 read Col 8...11 of Row 0...15 - // Threads T48...T63 read Col 12...15 of Row 0...15 - int lane_row = threadIdx.x % 16; - int col_group = threadIdx.x / 16; // 0..3 - int a_idx = col_group * 4 + lane_row * LDA; - - flashinfer::gpu_iface::mma::load_fragment<__half>(a_reg, &A[a_idx]); - flashinfer::gpu_iface::mma::m16k16_rowsum_f16f16f32<__half>( - c_reg, reinterpret_cast<__half *>(a_reg)); - for (int i = 0; i < 4; ++i) { - const int d_idx = - threadIdx.x % 16 + i * LDC + (threadIdx.x / 16) * 4 * LDC; - C[d_idx] = c_reg[i]; - } +__global__ void test_mfma_kernel(const __half* A, float* C) { + uint32_t a_reg[2]; + float c_reg[4] = {0.0f, 0.0f, 0.0f, 0.0f}; + + // A Matrix is read row wise. 
Threads T0...T15 read Col 0...3 of Row 0...15 + // Threads T16...T31 read Col 4...7 of Row 0...15 + // Threads T32...T47 read Col 8...11 of Row 0...15 + // Threads T48...T63 read Col 12...15 of Row 0...15 + int lane_row = threadIdx.x % 16; + int col_group = threadIdx.x / 16; // 0..3 + int a_idx = col_group * 4 + lane_row * LDA; + + flashinfer::gpu_iface::mma::load_fragment<__half>(a_reg, &A[a_idx]); + flashinfer::gpu_iface::mma::m16k16_rowsum_f16f16f32<__half>(c_reg, + reinterpret_cast<__half*>(a_reg)); + for (int i = 0; i < 4; ++i) { + const int d_idx = threadIdx.x % 16 + i * LDC + (threadIdx.x / 16) * 4 * LDC; + C[d_idx] = c_reg[i]; + } } // Test class -class MfmaRowSumTest : public ::testing::Test -{ -protected: - std::vector<__half> A_host; - std::vector<__half> B_host; - std::vector C_host; - std::vector C_ref; - - __half *A_dev = nullptr; - float *C_dev = nullptr; - - void SetUp() override - { - // Initialize host data - A_host.resize(M * K); - B_host.resize(K * N); - C_host.resize(M * N, 0.0f); - C_ref.resize(M * N, 0.0f); - - // Fill with deterministic values - std::mt19937 gen(42); - std::uniform_real_distribution dist(-1.0f, 1.0f); - - for (int i = 0; i < M * K; ++i) { - A_host[i] = __float2half(dist(gen)); - } - - for (int i = 0; i < K * N; ++i) { - B_host[i] = __float2half(1.0f); - } - - // Calculate reference result - gemm_reference(A_host.data(), B_host.data(), C_ref.data(), M, N, K, LDA, - LDB, LDC); - - // Allocate device memory - HIP_CHECK(hipMalloc(&A_dev, M * K * sizeof(__half))); - HIP_CHECK(hipMalloc(&C_dev, M * N * sizeof(float))); - - // Copy input data to device - HIP_CHECK(hipMemcpy(A_dev, A_host.data(), M * K * sizeof(__half), - hipMemcpyHostToDevice)); - HIP_CHECK(hipMemset(C_dev, 0, M * N * sizeof(float))); +class MfmaRowSumTest : public ::testing::Test { + protected: + std::vector<__half> A_host; + std::vector<__half> B_host; + std::vector C_host; + std::vector C_ref; + + __half* A_dev = nullptr; + float* C_dev = nullptr; + + void 
SetUp() override { + // Initialize host data + A_host.resize(M * K); + B_host.resize(K * N); + C_host.resize(M * N, 0.0f); + C_ref.resize(M * N, 0.0f); + + // Fill with deterministic values + std::mt19937 gen(42); + std::uniform_real_distribution dist(-1.0f, 1.0f); + + for (int i = 0; i < M * K; ++i) { + A_host[i] = __float2half(dist(gen)); } - void TearDown() override - { - // Free device memory - HIP_CHECK(hipFree(A_dev)); - HIP_CHECK(hipFree(C_dev)); + for (int i = 0; i < K * N; ++i) { + B_host[i] = __float2half(1.0f); } + + // Calculate reference result + gemm_reference(A_host.data(), B_host.data(), C_ref.data(), M, N, K, LDA, LDB, LDC); + + // Allocate device memory + HIP_CHECK(hipMalloc(&A_dev, M * K * sizeof(__half))); + HIP_CHECK(hipMalloc(&C_dev, M * N * sizeof(float))); + + // Copy input data to device + HIP_CHECK(hipMemcpy(A_dev, A_host.data(), M * K * sizeof(__half), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemset(C_dev, 0, M * N * sizeof(float))); + } + + void TearDown() override { + // Free device memory + HIP_CHECK(hipFree(A_dev)); + HIP_CHECK(hipFree(C_dev)); + } }; // Test that verifies mfma_fp32_16x16x16fp16 calculates correct results -TEST_F(MfmaRowSumTest, CorrectResults) -{ - // Launch kernel with one block of 64 threads (one wavefront) - dim3 gridDim(1); - dim3 blockDim(64); - test_mfma_kernel<<>>(A_dev, C_dev); - - // Copy results back to host - HIP_CHECK(hipMemcpy(C_host.data(), C_dev, M * N * sizeof(float), - hipMemcpyDeviceToHost)); - - // Verify results with small tolerance for floating point differences - const float tolerance = 1e-3f; - bool all_pass = true; - for (int i = 0; i < M * N; ++i) { - float diff = std::abs(C_host[i] - C_ref[i]); - if (diff > tolerance) { - std::cout << "Mismatch at index " << i << ": " - << "Actual=" << C_host[i] << ", Expected=" << C_ref[i] - << ", Diff=" << diff << std::endl; - all_pass = false; - } +TEST_F(MfmaRowSumTest, CorrectResults) { + // Launch kernel with one block of 64 threads (one wavefront) + 
dim3 gridDim(1); + dim3 blockDim(64); + test_mfma_kernel<<>>(A_dev, C_dev); + + // Copy results back to host + HIP_CHECK(hipMemcpy(C_host.data(), C_dev, M * N * sizeof(float), hipMemcpyDeviceToHost)); + + // Verify results with small tolerance for floating point differences + const float tolerance = 1e-3f; + bool all_pass = true; + for (int i = 0; i < M * N; ++i) { + float diff = std::abs(C_host[i] - C_ref[i]); + if (diff > tolerance) { + std::cout << "Mismatch at index " << i << ": " + << "Actual=" << C_host[i] << ", Expected=" << C_ref[i] << ", Diff=" << diff + << std::endl; + all_pass = false; } + } - EXPECT_TRUE(all_pass) - << "Matrix multiplication results don't match reference implementation"; + EXPECT_TRUE(all_pass) << "Matrix multiplication results don't match reference implementation"; } // Main function that runs all tests -int main(int argc, char **argv) -{ - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); } diff --git a/libflashinfer/tests/hip/test_single_decode.cpp b/libflashinfer/tests/hip/test_single_decode.cpp index 3a3557e948..2a89a5efdc 100644 --- a/libflashinfer/tests/hip/test_single_decode.cpp +++ b/libflashinfer/tests/hip/test_single_decode.cpp @@ -3,13 +3,7 @@ // // SPDX - License - Identifier : Apache 2.0 -#include "flashinfer/attention/generic/decode.cuh" -#include "flashinfer/attention/generic/default_decode_params.cuh" -#include "flashinfer/attention/generic/variants.cuh" - -#include "../../utils/cpu_reference_hip.h" -#include "../../utils/utils_hip.h" - +#include #include #include #include @@ -18,223 +12,179 @@ #include #include -#include +#include "../../utils/cpu_reference_hip.h" +#include "../../utils/utils_hip.h" +#include "flashinfer/attention/generic/decode.cuh" +#include "flashinfer/attention/generic/default_decode_params.cuh" +#include "flashinfer/attention/generic/variants.cuh" using namespace flashinfer; 
-namespace test::ops -{ +namespace test::ops { template -hipError_t SingleDecodeWithKVCache( - DTypeQ *q, - DTypeKV *k, - DTypeKV *v, - DTypeO *o, - DTypeO *tmp, - uint32_t num_qo_heads, - uint32_t num_kv_heads, - uint32_t seq_len, - uint32_t head_dim, - QKVLayout kv_layout = QKVLayout::kNHD, - PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, - std::optional maybe_sm_scale = std::nullopt, - float rope_scale = 1.f, - float rope_theta = 1e4, - hipStream_t stream = nullptr) -{ - float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(head_dim))); - if (num_qo_heads % num_kv_heads != 0) { - std::ostringstream err_msg; - err_msg << "num_qo_heads " << num_qo_heads - << " is not a multiple of num_kv_heads " << num_kv_heads; - FLASHINFER_ERROR(err_msg.str()); - } - - DISPATCH_head_dim( - head_dim, HEAD_DIM, - {DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, { - using Params = SingleDecodeParams; - using AttentionVariant = DefaultAttention< - /*use_custom_mask=*/false, /*use_sliding_window=*/false, - /*use_logits_soft_cap=*/false, /*use_alibi=*/false>; - Params params(q, k, v, o, /*alibi_slopes=*/nullptr, seq_len, - num_qo_heads, num_kv_heads, kv_layout, head_dim, - /*window_left=*/-1, /*logits_soft_cap=*/0.f, sm_scale, - rope_scale, rope_theta); - - SingleDecodeWithKVCacheDispatched(params, tmp, - stream); - })}); - return hipSuccess; +hipError_t SingleDecodeWithKVCache(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeO* o, DTypeO* tmp, + uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t seq_len, + uint32_t head_dim, QKVLayout kv_layout = QKVLayout::kNHD, + PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, + std::optional maybe_sm_scale = std::nullopt, + float rope_scale = 1.f, float rope_theta = 1e4, + hipStream_t stream = nullptr) { + float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(head_dim))); + if (num_qo_heads % num_kv_heads != 0) { + std::ostringstream err_msg; + err_msg << "num_qo_heads " << num_qo_heads << " 
is not a multiple of num_kv_heads " + << num_kv_heads; + FLASHINFER_ERROR(err_msg.str()); + } + + DISPATCH_head_dim( + head_dim, HEAD_DIM, {DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, { + using Params = SingleDecodeParams; + using AttentionVariant = DefaultAttention< + /*use_custom_mask=*/false, /*use_sliding_window=*/false, + /*use_logits_soft_cap=*/false, /*use_alibi=*/false>; + Params params(q, k, v, o, /*alibi_slopes=*/nullptr, seq_len, num_qo_heads, num_kv_heads, + kv_layout, head_dim, + /*window_left=*/-1, /*logits_soft_cap=*/0.f, sm_scale, rope_scale, + rope_theta); + + SingleDecodeWithKVCacheDispatched( + params, tmp, stream); + })}); + return hipSuccess; } -} // namespace test::ops +} // namespace test::ops template -std::vector getCPUReference(const std::vector &Q_host, - const std::vector &K_host, - const std::vector &V_host, - size_t num_qo_heads, - size_t num_kv_heads, - size_t seq_len, - size_t head_dim, - QKVLayout kv_layout, - PosEncodingMode pos_encoding_mode) -{ - - return cpu_reference::single_mha( - Q_host, K_host, V_host, 1, seq_len, num_qo_heads, num_kv_heads, - head_dim, false, kv_layout, pos_encoding_mode); +std::vector getCPUReference(const std::vector& Q_host, + const std::vector& K_host, + const std::vector& V_host, size_t num_qo_heads, + size_t num_kv_heads, size_t seq_len, size_t head_dim, + QKVLayout kv_layout, PosEncodingMode pos_encoding_mode) { + return cpu_reference::single_mha(Q_host, K_host, V_host, 1, seq_len, + num_qo_heads, num_kv_heads, head_dim, + false, kv_layout, pos_encoding_mode); } template -std::pair -nan_detection_and_accuracy(const std::vector &cpu_results, - const std::vector &gpu_results, - uint64_t num_qo_heads, - uint64_t head_dim) -{ - - size_t num_result_errors_atol_1e_3_rtol_1e_3 = 0; - bool nan_detected = false; - - for (size_t i = 0; i < num_qo_heads * head_dim; ++i) { - float cpu_result = - fi::con::explicit_casting(cpu_results[i]); - float gpu_result = - 
fi::con::explicit_casting(gpu_results[i]); - - if (isnan(gpu_result)) { - nan_detected = true; - } - num_result_errors_atol_1e_3_rtol_1e_3 += - (!utils::isclose(gpu_result, cpu_result, 1e-2, 1e-2)); +std::pair nan_detection_and_accuracy(const std::vector& cpu_results, + const std::vector& gpu_results, + uint64_t num_qo_heads, uint64_t head_dim) { + size_t num_result_errors_atol_1e_3_rtol_1e_3 = 0; + bool nan_detected = false; + + for (size_t i = 0; i < num_qo_heads * head_dim; ++i) { + float cpu_result = fi::con::explicit_casting(cpu_results[i]); + float gpu_result = fi::con::explicit_casting(gpu_results[i]); + + if (isnan(gpu_result)) { + nan_detected = true; } + num_result_errors_atol_1e_3_rtol_1e_3 += (!utils::isclose(gpu_result, cpu_result, 1e-2, 1e-2)); + } - float result_accuracy = 1. - float(num_result_errors_atol_1e_3_rtol_1e_3) / - float(num_qo_heads * head_dim); + float result_accuracy = + 1. - float(num_result_errors_atol_1e_3_rtol_1e_3) / float(num_qo_heads * head_dim); - return {result_accuracy, nan_detected}; + return {result_accuracy, nan_detected}; } template -void _TestDecodingKernelCorrectness(size_t num_qo_heads, - size_t num_kv_heads, - size_t seq_len, - size_t head_dim, - QKVLayout kv_layout, - PosEncodingMode pos_encoding_mode) -{ - - std::vector Q_host(num_qo_heads * head_dim); - std::vector K_host(seq_len * num_kv_heads * head_dim); - std::vector V_host(seq_len * num_kv_heads * head_dim); - std::vector O_host(num_qo_heads * head_dim); - - utils::vec_normal_(Q_host); - utils::vec_normal_(K_host); - utils::vec_normal_(V_host); - utils::vec_zero_(O_host); - - DTypeQO *Q; - DTypeKV *K; - DTypeKV *V; - DTypeQO *O; - DTypeQO *tmp; - - hipMalloc(&Q, num_qo_heads * head_dim * sizeof(DTypeQO)); - hipMalloc(&K, seq_len * num_kv_heads * head_dim * sizeof(DTypeKV)); - hipMalloc(&V, seq_len * num_kv_heads * head_dim * sizeof(DTypeKV)); - hipMalloc(&O, num_qo_heads * head_dim * sizeof(DTypeQO)); - hipMalloc(&tmp, num_qo_heads * head_dim * 
sizeof(DTypeQO)); - - hipMemcpy(Q, Q_host.data(), num_qo_heads * head_dim * sizeof(DTypeQO), - hipMemcpyHostToDevice); - hipMemcpy(K, K_host.data(), - seq_len * num_kv_heads * head_dim * sizeof(DTypeKV), - hipMemcpyHostToDevice); - hipMemcpy(V, V_host.data(), - seq_len * num_kv_heads * head_dim * sizeof(DTypeKV), - hipMemcpyHostToDevice); - hipMemcpy(O, O_host.data(), num_qo_heads * head_dim * sizeof(DTypeQO), - hipMemcpyHostToDevice); - hipMemcpy(tmp, O_host.data(), num_qo_heads * head_dim * sizeof(DTypeQO), - hipMemcpyHostToDevice); - - std::vector o_ref_host = getCPUReference( - Q_host, K_host, V_host, num_qo_heads, num_kv_heads, seq_len, head_dim, - QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode)); - - hipError_t status = - test::ops::SingleDecodeWithKVCache( - Q, K, V, O, tmp, num_qo_heads, num_kv_heads, seq_len, head_dim, - kv_layout, pos_encoding_mode); - - if (status != hipSuccess) { - std::cout - << "SingleDecodeWithKVCache kernel launch failed, error message: " - << hipGetErrorString(status) << std::endl; - } - - std::vector o_host(num_qo_heads * head_dim); - hipMemcpy(o_host.data(), O, num_qo_heads * head_dim * sizeof(DTypeQO), - hipMemcpyDeviceToHost); - - auto [result_accuracy, nan_detected] = - nan_detection_and_accuracy(o_ref_host, o_host, num_qo_heads, head_dim); - - std::cout << "num_qo_heads=" << num_qo_heads - << ", num_kv_heads=" << num_kv_heads << ", seq_len=" << seq_len - << ", head_dim=" << head_dim - << ", kv_layout=" << QKVLayoutToString(kv_layout) - << ", pos_encoding_mode=" - << PosEncodingModeToString(pos_encoding_mode) - << ", result accuracy (atol=1e-3, rtol=1e-3): " << result_accuracy - << std::endl; - EXPECT_GT(result_accuracy, 0.90) << "Result correctness test failed."; - EXPECT_FALSE(nan_detected) << "NaN detected."; - - hipFree(Q); - hipFree(K); - hipFree(V); - hipFree(O); - hipFree(tmp); +void _TestDecodingKernelCorrectness(size_t num_qo_heads, size_t num_kv_heads, size_t seq_len, + size_t head_dim, QKVLayout kv_layout, 
+ PosEncodingMode pos_encoding_mode) { + std::vector Q_host(num_qo_heads * head_dim); + std::vector K_host(seq_len * num_kv_heads * head_dim); + std::vector V_host(seq_len * num_kv_heads * head_dim); + std::vector O_host(num_qo_heads * head_dim); + + utils::vec_normal_(Q_host); + utils::vec_normal_(K_host); + utils::vec_normal_(V_host); + utils::vec_zero_(O_host); + + DTypeQO* Q; + DTypeKV* K; + DTypeKV* V; + DTypeQO* O; + DTypeQO* tmp; + + hipMalloc(&Q, num_qo_heads * head_dim * sizeof(DTypeQO)); + hipMalloc(&K, seq_len * num_kv_heads * head_dim * sizeof(DTypeKV)); + hipMalloc(&V, seq_len * num_kv_heads * head_dim * sizeof(DTypeKV)); + hipMalloc(&O, num_qo_heads * head_dim * sizeof(DTypeQO)); + hipMalloc(&tmp, num_qo_heads * head_dim * sizeof(DTypeQO)); + + hipMemcpy(Q, Q_host.data(), num_qo_heads * head_dim * sizeof(DTypeQO), hipMemcpyHostToDevice); + hipMemcpy(K, K_host.data(), seq_len * num_kv_heads * head_dim * sizeof(DTypeKV), + hipMemcpyHostToDevice); + hipMemcpy(V, V_host.data(), seq_len * num_kv_heads * head_dim * sizeof(DTypeKV), + hipMemcpyHostToDevice); + hipMemcpy(O, O_host.data(), num_qo_heads * head_dim * sizeof(DTypeQO), hipMemcpyHostToDevice); + hipMemcpy(tmp, O_host.data(), num_qo_heads * head_dim * sizeof(DTypeQO), hipMemcpyHostToDevice); + + std::vector o_ref_host = getCPUReference( + Q_host, K_host, V_host, num_qo_heads, num_kv_heads, seq_len, head_dim, QKVLayout(kv_layout), + PosEncodingMode(pos_encoding_mode)); + + hipError_t status = test::ops::SingleDecodeWithKVCache( + Q, K, V, O, tmp, num_qo_heads, num_kv_heads, seq_len, head_dim, kv_layout, pos_encoding_mode); + + if (status != hipSuccess) { + std::cout << "SingleDecodeWithKVCache kernel launch failed, error message: " + << hipGetErrorString(status) << std::endl; + } + + std::vector o_host(num_qo_heads * head_dim); + hipMemcpy(o_host.data(), O, num_qo_heads * head_dim * sizeof(DTypeQO), hipMemcpyDeviceToHost); + + auto [result_accuracy, nan_detected] = + 
nan_detection_and_accuracy(o_ref_host, o_host, num_qo_heads, head_dim); + + std::cout << "num_qo_heads=" << num_qo_heads << ", num_kv_heads=" << num_kv_heads + << ", seq_len=" << seq_len << ", head_dim=" << head_dim + << ", kv_layout=" << QKVLayoutToString(kv_layout) + << ", pos_encoding_mode=" << PosEncodingModeToString(pos_encoding_mode) + << ", result accuracy (atol=1e-3, rtol=1e-3): " << result_accuracy << std::endl; + EXPECT_GT(result_accuracy, 0.90) << "Result correctness test failed."; + EXPECT_FALSE(nan_detected) << "NaN detected."; + + hipFree(Q); + hipFree(K); + hipFree(V); + hipFree(O); + hipFree(tmp); } template -void TestSingleDecodeKernelCorrectness() -{ - for (size_t num_qo_heads : {32}) { - for (size_t num_kv_heads : {4, 8, 32}) { - for (size_t seq_len : {1, 3, 9, 27, 81, 129, 257, 512, 1024, 2048, - 4096, 8192, 16384, 32768}) - { - for (size_t head_dim : {64, 128, 256}) { - for (unsigned int kv_layout : {0U, 1U}) { - for (unsigned int pos_encoding_mode : {0U, 1U}) { - if (std::is_same::value) { - pos_encoding_mode = 0U; - } - _TestDecodingKernelCorrectness( - num_qo_heads, num_kv_heads, seq_len, head_dim, - QKVLayout(kv_layout), - PosEncodingMode(pos_encoding_mode)); - } - } - } +void TestSingleDecodeKernelCorrectness() { + for (size_t num_qo_heads : {32}) { + for (size_t num_kv_heads : {4, 8, 32}) { + for (size_t seq_len : + {1, 3, 9, 27, 81, 129, 257, 512, 1024, 2048, 4096, 8192, 16384, 32768}) { + for (size_t head_dim : {64, 128, 256}) { + for (unsigned int kv_layout : {0U, 1U}) { + for (unsigned int pos_encoding_mode : {0U, 1U}) { + if (std::is_same::value) { + pos_encoding_mode = 0U; + } + _TestDecodingKernelCorrectness(num_qo_heads, num_kv_heads, seq_len, + head_dim, QKVLayout(kv_layout), + PosEncodingMode(pos_encoding_mode)); } + } } + } } + } } -TEST(FlashInferCorrectnessTest, SingleDecodeKernelCorrectnessTestFP16) -{ - TestSingleDecodeKernelCorrectness<__half, __half>(); +TEST(FlashInferCorrectnessTest, 
SingleDecodeKernelCorrectnessTestFP16) { + TestSingleDecodeKernelCorrectness<__half, __half>(); } -TEST(FlashInferCorrectnessTest, SingleDecodeKernelCorrectnessTestBF16) -{ - TestSingleDecodeKernelCorrectness<__hip_bfloat16, __hip_bfloat16>(); +TEST(FlashInferCorrectnessTest, SingleDecodeKernelCorrectnessTestBF16) { + TestSingleDecodeKernelCorrectness<__hip_bfloat16, __hip_bfloat16>(); } //***************************************************************************** @@ -249,8 +199,7 @@ TEST(FlashInferCorrectnessTest, SingleDecodeKernelCorrectnessTestBF16) // } //***************************************************************************** -int main(int argc, char **argv) -{ - testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); } diff --git a/libflashinfer/tests/hip/test_single_prefill.cpp b/libflashinfer/tests/hip/test_single_prefill.cpp index 5bb86ce1ac..ce747f98ab 100644 --- a/libflashinfer/tests/hip/test_single_prefill.cpp +++ b/libflashinfer/tests/hip/test_single_prefill.cpp @@ -3,16 +3,16 @@ // // SPDX - License - Identifier : Apache 2.0 +#include + +#include + #include "../../utils/cpu_reference_hip.h" #include "../../utils/flashinfer_prefill_ops.hip.h" #include "../../utils/utils_hip.h" #include "flashinfer/attention/generic/prefill.cuh" #include "gpu_iface/gpu_runtime_compat.hpp" -#include - -#include - #define HIP_ENABLE_WARP_SYNC_BUILTINS 1 using namespace flashinfer; @@ -178,110 +178,92 @@ void _TestComputeQKCorrectness(size_t qo_len, #endif template -void _TestSinglePrefillKernelCorrectness(size_t qo_len, - size_t kv_len, - size_t num_qo_heads, - size_t num_kv_heads, - size_t head_dim, - bool causal, - QKVLayout kv_layout, - PosEncodingMode pos_encoding_mode, - bool use_fp16_qk_reduction, - uint32_t debug_thread_id, - uint32_t debug_warp_id, - float rtol = 1e-3, - float atol = 1e-3) -{ - std::vector q(qo_len * num_qo_heads * head_dim); - 
std::vector k(kv_len * num_kv_heads * head_dim); - std::vector v(kv_len * num_kv_heads * head_dim); - std::vector o(qo_len * num_qo_heads * head_dim); - - utils::vec_normal_(q); - utils::vec_normal_(k); - utils::vec_normal_(v); - // utils::vec_lexicographic_(q); - // utils::vec_lexicographic_(k); - // utils::vec_fill_(v, __float2half(1.0f)); - utils::vec_zero_(o); - - DTypeQ *q_d; - FI_GPU_CALL(hipMalloc(&q_d, q.size() * sizeof(DTypeQ))); - FI_GPU_CALL(hipMemcpy(q_d, q.data(), q.size() * sizeof(DTypeQ), - hipMemcpyHostToDevice)); - - DTypeKV *k_d; - FI_GPU_CALL(hipMalloc(&k_d, k.size() * sizeof(DTypeKV))); - FI_GPU_CALL(hipMemcpy(k_d, k.data(), k.size() * sizeof(DTypeKV), - hipMemcpyHostToDevice)); - - DTypeKV *v_d; - FI_GPU_CALL(hipMalloc(&v_d, v.size() * sizeof(DTypeKV))); - FI_GPU_CALL(hipMemcpy(v_d, v.data(), v.size() * sizeof(DTypeKV), - hipMemcpyHostToDevice)); - - DTypeO *o_d; - FI_GPU_CALL(hipMalloc(&o_d, o.size() * sizeof(DTypeO))); - FI_GPU_CALL(hipMemcpy(o_d, o.data(), o.size() * sizeof(DTypeO), - hipMemcpyHostToDevice)); - - DTypeO *tmp_d; - FI_GPU_CALL(hipMalloc(&tmp_d, 16 * 1024 * 1024 * sizeof(DTypeO))); - - hipError_t status = - flashinfer::SinglePrefillWithKVCache( - q_d, k_d, v_d, o_d, tmp_d, - /*lse=*/nullptr, num_qo_heads, num_kv_heads, qo_len, kv_len, - head_dim, causal, kv_layout, pos_encoding_mode, - use_fp16_qk_reduction, debug_thread_id, debug_warp_id); - - EXPECT_EQ(status, hipSuccess) - << "SinglePrefillWithKVCache kernel launch failed, error message: " - << hipGetErrorString(status); - - std::vector o_h(o.size()); - FI_GPU_CALL(hipMemcpy(o_h.data(), o_d, o_h.size() * sizeof(DTypeO), - hipMemcpyDeviceToHost)); - - // Print the first 10 elements of the output vector for debugging - // std::cout << "Output vector (first 10 elements):"; - // std::cout << "[" << std::endl; - // for (int i = 0; i < 10; ++i) { - // std::cout << fi::con::explicit_casting(o_h[i]) << " - // "; - // } - // std::cout << "]" << std::endl; - - bool isEmpty = 
o_h.empty(); - EXPECT_EQ(isEmpty, false) << "Output vector is empty"; - - std::vector att_out; - std::vector o_ref = - cpu_reference::single_mha( - q, k, v, qo_len, kv_len, num_qo_heads, num_kv_heads, head_dim, - causal, kv_layout, pos_encoding_mode); - size_t num_results_error_atol = 0; - bool nan_detected = false; - - for (size_t i = 0; i < o_ref.size(); ++i) { - float o_h_val = fi::con::explicit_casting(o_h[i]); - float o_ref_val = fi::con::explicit_casting(o_ref[i]); - - if (isnan(o_h_val)) { - nan_detected = true; - } - - num_results_error_atol += - (!utils::isclose(o_ref_val, o_h_val, rtol, atol)); - // if (!utils::isclose(o_ref_val, o_h_val, rtol, atol)) { - // std::cout << "i=" << i << ", o_ref[i]=" << o_ref_val - // << ", o_h[i]=" << o_h_val << std::endl; - // } +void _TestSinglePrefillKernelCorrectness(size_t qo_len, size_t kv_len, size_t num_qo_heads, + size_t num_kv_heads, size_t head_dim, bool causal, + QKVLayout kv_layout, PosEncodingMode pos_encoding_mode, + bool use_fp16_qk_reduction, uint32_t debug_thread_id, + uint32_t debug_warp_id, float rtol = 1e-3, + float atol = 1e-3) { + std::vector q(qo_len * num_qo_heads * head_dim); + std::vector k(kv_len * num_kv_heads * head_dim); + std::vector v(kv_len * num_kv_heads * head_dim); + std::vector o(qo_len * num_qo_heads * head_dim); + + utils::vec_normal_(q); + utils::vec_normal_(k); + utils::vec_normal_(v); + // utils::vec_lexicographic_(q); + // utils::vec_lexicographic_(k); + // utils::vec_fill_(v, __float2half(1.0f)); + utils::vec_zero_(o); + + DTypeQ* q_d; + FI_GPU_CALL(hipMalloc(&q_d, q.size() * sizeof(DTypeQ))); + FI_GPU_CALL(hipMemcpy(q_d, q.data(), q.size() * sizeof(DTypeQ), hipMemcpyHostToDevice)); + + DTypeKV* k_d; + FI_GPU_CALL(hipMalloc(&k_d, k.size() * sizeof(DTypeKV))); + FI_GPU_CALL(hipMemcpy(k_d, k.data(), k.size() * sizeof(DTypeKV), hipMemcpyHostToDevice)); + + DTypeKV* v_d; + FI_GPU_CALL(hipMalloc(&v_d, v.size() * sizeof(DTypeKV))); + FI_GPU_CALL(hipMemcpy(v_d, v.data(), v.size() * 
sizeof(DTypeKV), hipMemcpyHostToDevice)); + + DTypeO* o_d; + FI_GPU_CALL(hipMalloc(&o_d, o.size() * sizeof(DTypeO))); + FI_GPU_CALL(hipMemcpy(o_d, o.data(), o.size() * sizeof(DTypeO), hipMemcpyHostToDevice)); + + DTypeO* tmp_d; + FI_GPU_CALL(hipMalloc(&tmp_d, 16 * 1024 * 1024 * sizeof(DTypeO))); + + hipError_t status = flashinfer::SinglePrefillWithKVCache( + q_d, k_d, v_d, o_d, tmp_d, + /*lse=*/nullptr, num_qo_heads, num_kv_heads, qo_len, kv_len, head_dim, causal, kv_layout, + pos_encoding_mode, use_fp16_qk_reduction, debug_thread_id, debug_warp_id); + + EXPECT_EQ(status, hipSuccess) << "SinglePrefillWithKVCache kernel launch failed, error message: " + << hipGetErrorString(status); + + std::vector o_h(o.size()); + FI_GPU_CALL(hipMemcpy(o_h.data(), o_d, o_h.size() * sizeof(DTypeO), hipMemcpyDeviceToHost)); + + // Print the first 10 elements of the output vector for debugging + // std::cout << "Output vector (first 10 elements):"; + // std::cout << "[" << std::endl; + // for (int i = 0; i < 10; ++i) { + // std::cout << fi::con::explicit_casting(o_h[i]) << " + // "; + // } + // std::cout << "]" << std::endl; + + bool isEmpty = o_h.empty(); + EXPECT_EQ(isEmpty, false) << "Output vector is empty"; + + std::vector att_out; + std::vector o_ref = cpu_reference::single_mha( + q, k, v, qo_len, kv_len, num_qo_heads, num_kv_heads, head_dim, causal, kv_layout, + pos_encoding_mode); + size_t num_results_error_atol = 0; + bool nan_detected = false; + + for (size_t i = 0; i < o_ref.size(); ++i) { + float o_h_val = fi::con::explicit_casting(o_h[i]); + float o_ref_val = fi::con::explicit_casting(o_ref[i]); + + if (isnan(o_h_val)) { + nan_detected = true; } - // std::cout<<"Printing att_out vector:\n"; - // for(auto i: att_out) { - // std::cout << i << "\n"; + + num_results_error_atol += (!utils::isclose(o_ref_val, o_h_val, rtol, atol)); + // if (!utils::isclose(o_ref_val, o_h_val, rtol, atol)) { + // std::cout << "i=" << i << ", o_ref[i]=" << o_ref_val + // << ", o_h[i]=" << o_h_val 
<< std::endl; // } + } + // std::cout<<"Printing att_out vector:\n"; + // for(auto i: att_out) { + // std::cout << i << "\n"; + // } #if 0 float result_accuracy = 1. - float(num_results_error_atol) / float(o_ref.size()); @@ -297,11 +279,11 @@ void _TestSinglePrefillKernelCorrectness(size_t qo_len, EXPECT_GT(result_accuracy, 0.90) << "Result correctness test failed."; EXPECT_FALSE(nan_detected) << "Nan detected in the result."; #endif - FI_GPU_CALL(hipFree(q_d)); - FI_GPU_CALL(hipFree(k_d)); - FI_GPU_CALL(hipFree(v_d)); - FI_GPU_CALL(hipFree(o_d)); - FI_GPU_CALL(hipFree(tmp_d)); + FI_GPU_CALL(hipFree(q_d)); + FI_GPU_CALL(hipFree(k_d)); + FI_GPU_CALL(hipFree(v_d)); + FI_GPU_CALL(hipFree(o_d)); + FI_GPU_CALL(hipFree(tmp_d)); } // template @@ -554,47 +536,39 @@ void _TestSinglePrefillKernelCorrectness(size_t qo_len, // } // #endif -int main(int argc, char **argv) -{ - // ::testing::InitGoogleTest(&argc, argv); - // return RUN_ALL_TESTS(); - using DTypeIn = __half; - using DTypeO = __half; - uint32_t debug_thread_id = 0; - uint32_t debug_warp_id = 0; - bool use_fp16_qk_reduction = false; - size_t qo_len = 128; - size_t kv_len = 128; - size_t num_heads = 1; - size_t head_dim = 64; - bool causal = false; - size_t pos_encoding_mode = 0; // 1 == kRopeLLama - size_t kv_layout = 0; - - for (int i = 1; i < argc; i++) { - std::string arg = argv[i]; - - if (arg == "--thread" && i + 1 < argc) { - debug_thread_id = std::stoi(argv[++i]); - std::cout << "Debug thread ID set to: " << debug_thread_id - << std::endl; - } - else if (arg == "--warp" && i + 1 < argc) { - debug_warp_id = std::stoi(argv[++i]); - std::cout << "Debug warp ID set to: " << debug_warp_id << std::endl; - } - else if (arg == "--qo_len" && i + 1 < argc) { - qo_len = std::stoi(argv[++i]); - } - else if (arg == "--kv_len" && i + 1 < argc) { - kv_len = std::stoi(argv[++i]); - } - else if (arg == "--heads" && i + 1 < argc) { - num_heads = std::stoi(argv[++i]); - } - else if (arg == "--help") { - std::cout - << "Usage: 
" << argv[0] << " [options]\n" +int main(int argc, char** argv) { + // ::testing::InitGoogleTest(&argc, argv); + // return RUN_ALL_TESTS(); + using DTypeIn = __half; + using DTypeO = __half; + uint32_t debug_thread_id = 0; + uint32_t debug_warp_id = 0; + bool use_fp16_qk_reduction = false; + size_t qo_len = 128; + size_t kv_len = 128; + size_t num_heads = 1; + size_t head_dim = 64; + bool causal = false; + size_t pos_encoding_mode = 0; // 1 == kRopeLLama + size_t kv_layout = 0; + + for (int i = 1; i < argc; i++) { + std::string arg = argv[i]; + + if (arg == "--thread" && i + 1 < argc) { + debug_thread_id = std::stoi(argv[++i]); + std::cout << "Debug thread ID set to: " << debug_thread_id << std::endl; + } else if (arg == "--warp" && i + 1 < argc) { + debug_warp_id = std::stoi(argv[++i]); + std::cout << "Debug warp ID set to: " << debug_warp_id << std::endl; + } else if (arg == "--qo_len" && i + 1 < argc) { + qo_len = std::stoi(argv[++i]); + } else if (arg == "--kv_len" && i + 1 < argc) { + kv_len = std::stoi(argv[++i]); + } else if (arg == "--heads" && i + 1 < argc) { + num_heads = std::stoi(argv[++i]); + } else if (arg == "--help") { + std::cout << "Usage: " << argv[0] << " [options]\n" << "Options:\n" << " --thread Debug thread ID (0-255 for 4 warps)\n" << " --warp Debug warp ID (0-3 for 4 warps)\n" @@ -602,14 +576,13 @@ int main(int argc, char **argv) << " --kv_len Key/Value length (default: 128)\n" << " --heads Number of heads (default: 1)\n" << " --help Show this help message\n"; - return 0; - } + return 0; } + } - _TestSinglePrefillKernelCorrectness( - qo_len, kv_len, num_heads, num_heads, head_dim, causal, - QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode), - use_fp16_qk_reduction, debug_thread_id, debug_warp_id); + _TestSinglePrefillKernelCorrectness( + qo_len, kv_len, num_heads, num_heads, head_dim, causal, QKVLayout(kv_layout), + PosEncodingMode(pos_encoding_mode), use_fp16_qk_reduction, debug_thread_id, debug_warp_id); } // int main(int argc, 
char **argv) diff --git a/libflashinfer/tests/hip/test_transpose_4x4_half_registers.cpp b/libflashinfer/tests/hip/test_transpose_4x4_half_registers.cpp index 86813310b7..d0b87df639 100644 --- a/libflashinfer/tests/hip/test_transpose_4x4_half_registers.cpp +++ b/libflashinfer/tests/hip/test_transpose_4x4_half_registers.cpp @@ -1,306 +1,277 @@ // test_transpose_4x4_half_registers.cpp -#include "gpu_iface/backend/hip/mma_hip.h" #include + #include #include -#define FI_GPU_CALL(call) \ - do { \ - gpuError_t err = (call); \ - if (err != gpuSuccess) { \ - std::ostringstream err_msg; \ - err_msg << "GPU error: " << gpuGetErrorString(err) << " at " \ - << __FILE__ << ":" << __LINE__; \ - throw std::runtime_error(err_msg.str()); \ - } \ - } while (0) - -__device__ __forceinline__ void debug_print_registers(const char *stage, - uint32_t lane_id, - uint32_t lane_in_group, - uint32_t *regs, - int num_regs, - uint32_t debug_group = 0) -{ - - // Only debug a specific group to avoid excessive output - if (lane_id / 4 != debug_group) - return; - - // Print identification info - printf("STAGE: %s | Thread %d (lane_in_group=%d): ", stage, lane_id, - lane_in_group); - - // Print raw 32-bit values - printf("RAW=["); - for (int i = 0; i < num_regs; i++) { - printf("0x%08x", regs[i]); - if (i < num_regs - 1) - printf(", "); - } - printf("] | "); - - // Print unpacked 16-bit values - printf("UNPACKED=["); - for (int i = 0; i < num_regs; i++) { - uint16_t hi = (regs[i] >> 16) & 0xFFFF; - uint16_t lo = regs[i] & 0xFFFF; - printf("%d,%d", hi, lo); - if (i < num_regs - 1) - printf(", "); - } - printf("]\n"); -} +#include "gpu_iface/backend/hip/mma_hip.h" -__device__ __forceinline__ void transpose_4x4_half_registers_opt(uint32_t *R) -{ - // Calculate lane within 4-thread group - uint32_t lane_id = threadIdx.x % 64; - uint32_t lane_in_group = lane_id % 4; - - // === ROUND 1: Exchange with neighbor (XOR with 1) === - // T0↔T1, T2↔T3 partial exchange - uint32_t reg_idx = (lane_in_group >> 1) & 
0x1; - uint32_t exchanged_val = __shfl_xor(R[reg_idx], 0x1); - uint32_t shift = (lane_in_group & 1) * 16; - uint32_t keep_mask = 0xFFFF0000 >> shift; - int right_shift_amount = 16 * (1 - (lane_in_group & 1)); - int left_shift_amount = 16 * (lane_in_group & 1); - R[reg_idx] = (R[reg_idx] & keep_mask) | - ((exchanged_val >> right_shift_amount) << left_shift_amount); - - // === ROUND 2: Exchange with one hop (XOR with 2) === - // T0↔T2, T1↔T3 exchange R[0] and R[1] - // Swap entire registers based on thread position - uint32_t is_top = 1 - reg_idx; - uint32_t temp0 = __shfl_xor(R[0], 0x2); - uint32_t temp1 = __shfl_xor(R[1], 0x2); - - // Compute both possibilities and select - R[0] = R[0] * is_top + temp1 * reg_idx; - R[1] = temp0 * is_top + R[1] * reg_idx; - - // === ROUND 3: Exchange with neighbor again (XOR with 1) === - // T0↔T1, T2↔T3 exchange remaining parts - - reg_idx = 1 - reg_idx; - exchanged_val = __shfl_xor(R[reg_idx], 0x1); - R[reg_idx] = (R[reg_idx] & keep_mask) | - ((exchanged_val >> right_shift_amount) << left_shift_amount); +#define FI_GPU_CALL(call) \ + do { \ + gpuError_t err = (call); \ + if (err != gpuSuccess) { \ + std::ostringstream err_msg; \ + err_msg << "GPU error: " << gpuGetErrorString(err) << " at " << __FILE__ << ":" << __LINE__; \ + throw std::runtime_error(err_msg.str()); \ + } \ + } while (0) + +__device__ __forceinline__ void debug_print_registers(const char* stage, uint32_t lane_id, + uint32_t lane_in_group, uint32_t* regs, + int num_regs, uint32_t debug_group = 0) { + // Only debug a specific group to avoid excessive output + if (lane_id / 4 != debug_group) return; + + // Print identification info + printf("STAGE: %s | Thread %d (lane_in_group=%d): ", stage, lane_id, lane_in_group); + + // Print raw 32-bit values + printf("RAW=["); + for (int i = 0; i < num_regs; i++) { + printf("0x%08x", regs[i]); + if (i < num_regs - 1) printf(", "); + } + printf("] | "); + + // Print unpacked 16-bit values + printf("UNPACKED=["); + for (int i = 
0; i < num_regs; i++) { + uint16_t hi = (regs[i] >> 16) & 0xFFFF; + uint16_t lo = regs[i] & 0xFFFF; + printf("%d,%d", hi, lo); + if (i < num_regs - 1) printf(", "); + } + printf("]\n"); } -__device__ __forceinline__ void transpose_4x4_half_registers_naive(uint32_t *R) -{ - // Calculate lane within 4-thread group - uint32_t lane_id = threadIdx.x % 64; - uint32_t lane_in_group = lane_id % 4; - - if (lane_id == 0) { - debug_print_registers("Initial", lane_id, lane_in_group, R, 2, 0); - } - - // === ROUND 1: Exchange with neighbor (XOR with 1) === - // T0↔T1, T2↔T3 partial exchange - - // Update based on thread position - if (lane_in_group < 2) { - uint32_t r0_exchanged = __shfl_xor(R[0], 0x1); - // Top half (T0, T1) update R[0] - if (lane_in_group & 1) { // T1 - R[0] = (R[0] & 0x0000FFFF) | (r0_exchanged << 16); - } - else { // T0 - R[0] = (R[0] & 0xFFFF0000) | (r0_exchanged >> 16); - } - } - else { - uint32_t r1_exchanged = __shfl_xor(R[1], 0x1); - // Bottom half (T2, T3) update R[1] - if (lane_in_group & 1) { // T1 - R[1] = (R[1] & 0x0000FFFF) | (r1_exchanged << 16); - } - else { // T0 - R[1] = (R[1] & 0xFFFF0000) | (r1_exchanged >> 16); - } - } +__device__ __forceinline__ void transpose_4x4_half_registers_opt(uint32_t* R) { + // Calculate lane within 4-thread group + uint32_t lane_id = threadIdx.x % 64; + uint32_t lane_in_group = lane_id % 4; + + // === ROUND 1: Exchange with neighbor (XOR with 1) === + // T0↔T1, T2↔T3 partial exchange + uint32_t reg_idx = (lane_in_group >> 1) & 0x1; + uint32_t exchanged_val = __shfl_xor(R[reg_idx], 0x1); + uint32_t shift = (lane_in_group & 1) * 16; + uint32_t keep_mask = 0xFFFF0000 >> shift; + int right_shift_amount = 16 * (1 - (lane_in_group & 1)); + int left_shift_amount = 16 * (lane_in_group & 1); + R[reg_idx] = + (R[reg_idx] & keep_mask) | ((exchanged_val >> right_shift_amount) << left_shift_amount); + + // === ROUND 2: Exchange with one hop (XOR with 2) === + // T0↔T2, T1↔T3 exchange R[0] and R[1] + // Swap entire registers 
based on thread position + uint32_t is_top = 1 - reg_idx; + uint32_t temp0 = __shfl_xor(R[0], 0x2); + uint32_t temp1 = __shfl_xor(R[1], 0x2); + + // Compute both possibilities and select + R[0] = R[0] * is_top + temp1 * reg_idx; + R[1] = temp0 * is_top + R[1] * reg_idx; + + // === ROUND 3: Exchange with neighbor again (XOR with 1) === + // T0↔T1, T2↔T3 exchange remaining parts + + reg_idx = 1 - reg_idx; + exchanged_val = __shfl_xor(R[reg_idx], 0x1); + R[reg_idx] = + (R[reg_idx] & keep_mask) | ((exchanged_val >> right_shift_amount) << left_shift_amount); +} - // Debug after first recombination - if (lane_id == 3) { - debug_print_registers("After Round 1 shuffles", lane_id, lane_in_group, - R, 2, 0); +__device__ __forceinline__ void transpose_4x4_half_registers_naive(uint32_t* R) { + // Calculate lane within 4-thread group + uint32_t lane_id = threadIdx.x % 64; + uint32_t lane_in_group = lane_id % 4; + + if (lane_id == 0) { + debug_print_registers("Initial", lane_id, lane_in_group, R, 2, 0); + } + + // === ROUND 1: Exchange with neighbor (XOR with 1) === + // T0↔T1, T2↔T3 partial exchange + + // Update based on thread position + if (lane_in_group < 2) { + uint32_t r0_exchanged = __shfl_xor(R[0], 0x1); + // Top half (T0, T1) update R[0] + if (lane_in_group & 1) { // T1 + R[0] = (R[0] & 0x0000FFFF) | (r0_exchanged << 16); + } else { // T0 + R[0] = (R[0] & 0xFFFF0000) | (r0_exchanged >> 16); } - - // === ROUND 2: Exchange with one hop (XOR with 2) === - // T0↔T2, T1↔T3 exchange R[0] and R[1] - uint32_t temp0_exchanged = __shfl_xor(R[0], 0x2); - uint32_t temp1_exchanged = __shfl_xor(R[1], 0x2); - - // Swap entire registers based on thread position - if (lane_in_group < 2) { - R[1] = temp0_exchanged; + } else { + uint32_t r1_exchanged = __shfl_xor(R[1], 0x1); + // Bottom half (T2, T3) update R[1] + if (lane_in_group & 1) { // T1 + R[1] = (R[1] & 0x0000FFFF) | (r1_exchanged << 16); + } else { // T0 + R[1] = (R[1] & 0xFFFF0000) | (r1_exchanged >> 16); } - else { - // Bottom 
threads (T2, T3) get R[1] from partner, keep own R[0] - R[0] = temp1_exchanged; + } + + // Debug after first recombination + if (lane_id == 3) { + debug_print_registers("After Round 1 shuffles", lane_id, lane_in_group, R, 2, 0); + } + + // === ROUND 2: Exchange with one hop (XOR with 2) === + // T0↔T2, T1↔T3 exchange R[0] and R[1] + uint32_t temp0_exchanged = __shfl_xor(R[0], 0x2); + uint32_t temp1_exchanged = __shfl_xor(R[1], 0x2); + + // Swap entire registers based on thread position + if (lane_in_group < 2) { + R[1] = temp0_exchanged; + } else { + // Bottom threads (T2, T3) get R[1] from partner, keep own R[0] + R[0] = temp1_exchanged; + } + + if (lane_id == 0) { + debug_print_registers("After Round 2 shuffles", lane_id, lane_in_group, R, 2, 0); + } + + // === ROUND 3: Exchange with neighbor again (XOR with 1) === + // T0↔T1, T2↔T3 exchange remaining parts + + if (lane_in_group < 2) { + uint32_t r1_exchanged = __shfl_xor(R[1], 0x1); + // Top half (T0, T1) update R[0] + if (lane_in_group & 1) { // T1 + R[1] = (R[1] & 0x0000FFFF) | (r1_exchanged << 16); + } else { // T0 + R[1] = (R[1] & 0xFFFF0000) | (r1_exchanged >> 16); } - - if (lane_id == 0) { - debug_print_registers("After Round 2 shuffles", lane_id, lane_in_group, - R, 2, 0); + } else { + uint32_t r1_exchanged = __shfl_xor(R[0], 0x1); + // Bottom half (T2, T3) update R[1] + if (lane_in_group & 1) { // T1 + R[0] = (R[0] & 0x0000FFFF) | (r1_exchanged << 16); + } else { // T0 + R[0] = (R[0] & 0xFFFF0000) | (r1_exchanged >> 16); } + } - // === ROUND 3: Exchange with neighbor again (XOR with 1) === - // T0↔T1, T2↔T3 exchange remaining parts - - if (lane_in_group < 2) { - uint32_t r1_exchanged = __shfl_xor(R[1], 0x1); - // Top half (T0, T1) update R[0] - if (lane_in_group & 1) { // T1 - R[1] = (R[1] & 0x0000FFFF) | (r1_exchanged << 16); - } - else { // T0 - R[1] = (R[1] & 0xFFFF0000) | (r1_exchanged >> 16); - } - } - else { - uint32_t r1_exchanged = __shfl_xor(R[0], 0x1); - // Bottom half (T2, T3) update R[1] - if 
(lane_in_group & 1) { // T1 - R[0] = (R[0] & 0x0000FFFF) | (r1_exchanged << 16); - } - else { // T0 - R[0] = (R[0] & 0xFFFF0000) | (r1_exchanged >> 16); - } - } - - if (lane_id == 3) { - debug_print_registers("After Round 2 shuffles", lane_id, lane_in_group, - R, 2, 0); - } + if (lane_id == 3) { + debug_print_registers("After Round 2 shuffles", lane_id, lane_in_group, R, 2, 0); + } } // Helper function to convert two uint16_t values to a single uint32_t -__host__ __device__ uint32_t pack_half2(uint16_t a, uint16_t b) -{ - return ((uint32_t)a << 16) | (uint32_t)b; +__host__ __device__ uint32_t pack_half2(uint16_t a, uint16_t b) { + return ((uint32_t)a << 16) | (uint32_t)b; } // Helper function to extract two uint16_t values from a single uint32_t -__host__ __device__ void unpack_half2(uint32_t packed, uint16_t &a, uint16_t &b) -{ - a = (packed >> 16) & 0xFFFF; - b = packed & 0xFFFF; +__host__ __device__ void unpack_half2(uint32_t packed, uint16_t& a, uint16_t& b) { + a = (packed >> 16) & 0xFFFF; + b = packed & 0xFFFF; } // Kernel to test the transpose function -__global__ void test_transpose_kernel(uint16_t *output) -{ - uint32_t thread_id = threadIdx.x + blockIdx.x * blockDim.x; - uint32_t lane_id = thread_id % 64; - - // Calculate the thread's position in the logical 4x4 grid - uint32_t lane_in_group = lane_id % 4; // Position within group - - // Initialize test data - each thread creates a row of the matrix B - // Values are designed for easy verification: lane_in_group * 100 + column - uint16_t row_elements[4]; - for (int i = 0; i < 4; i++) { - row_elements[i] = lane_in_group * 100 + i; // B[lane_in_group][i] - } - - // Pack the 4 half-precision values into 2 registers - uint32_t R[2]; - R[0] = pack_half2(row_elements[0], row_elements[1]); - R[1] = pack_half2(row_elements[2], row_elements[3]); - - // Call the transpose function - flashinfer::gpu_iface::mma_impl::hip::transpose_4x4_half_registers(R); - - // Unpack the transposed results - uint16_t transposed[4]; 
- unpack_half2(R[0], transposed[0], transposed[1]); - unpack_half2(R[1], transposed[2], transposed[3]); - - // Write output - store both original and transposed values for verification - for (int i = 0; i < 4; i++) { - // Original values (row-major layout) - output[thread_id * 8 + i] = row_elements[i]; - // Transposed values (column-major layout) - output[thread_id * 8 + 4 + i] = transposed[i]; - } +__global__ void test_transpose_kernel(uint16_t* output) { + uint32_t thread_id = threadIdx.x + blockIdx.x * blockDim.x; + uint32_t lane_id = thread_id % 64; + + // Calculate the thread's position in the logical 4x4 grid + uint32_t lane_in_group = lane_id % 4; // Position within group + + // Initialize test data - each thread creates a row of the matrix B + // Values are designed for easy verification: lane_in_group * 100 + column + uint16_t row_elements[4]; + for (int i = 0; i < 4; i++) { + row_elements[i] = lane_in_group * 100 + i; // B[lane_in_group][i] + } + + // Pack the 4 half-precision values into 2 registers + uint32_t R[2]; + R[0] = pack_half2(row_elements[0], row_elements[1]); + R[1] = pack_half2(row_elements[2], row_elements[3]); + + // Call the transpose function + flashinfer::gpu_iface::mma_impl::hip::transpose_4x4_half_registers(R); + + // Unpack the transposed results + uint16_t transposed[4]; + unpack_half2(R[0], transposed[0], transposed[1]); + unpack_half2(R[1], transposed[2], transposed[3]); + + // Write output - store both original and transposed values for verification + for (int i = 0; i < 4; i++) { + // Original values (row-major layout) + output[thread_id * 8 + i] = row_elements[i]; + // Transposed values (column-major layout) + output[thread_id * 8 + 4 + i] = transposed[i]; + } } -int main() -{ - // Allocate memory for output (both original and transposed data) - const int num_threads = 64; // One wavefront - const int values_per_thread = - 8; // Each thread stores 4 original + 4 transposed values - const int total_values = num_threads * 
values_per_thread; +int main() { + // Allocate memory for output (both original and transposed data) + const int num_threads = 64; // One wavefront + const int values_per_thread = 8; // Each thread stores 4 original + 4 transposed values + const int total_values = num_threads * values_per_thread; - std::vector h_output(total_values); - uint16_t *d_output; + std::vector h_output(total_values); + uint16_t* d_output; - FI_GPU_CALL(hipMalloc(&d_output, total_values * sizeof(uint16_t))); + FI_GPU_CALL(hipMalloc(&d_output, total_values * sizeof(uint16_t))); - // Launch the kernel - test_transpose_kernel<<<1, num_threads>>>(d_output); + // Launch the kernel + test_transpose_kernel<<<1, num_threads>>>(d_output); - // Copy results back to host - FI_GPU_CALL(hipMemcpy(h_output.data(), d_output, - total_values * sizeof(uint16_t), - hipMemcpyDeviceToHost)); + // Copy results back to host + FI_GPU_CALL( + hipMemcpy(h_output.data(), d_output, total_values * sizeof(uint16_t), hipMemcpyDeviceToHost)); - // Verify the results - bool success = true; - std::cout << "Testing matrix transposition with shuffle operations..." - << std::endl; + // Verify the results + bool success = true; + std::cout << "Testing matrix transposition with shuffle operations..." 
<< std::endl; + + for (int group = 0; group < num_threads / 4; group++) { + std::cout << "\nGroup " << group << " results:" << std::endl; + + for (int lane = 0; lane < 4; lane++) { + int thread_idx = group * 4 + lane; + + // Print original values + std::cout << "Thread " << thread_idx << " original: "; + for (int i = 0; i < 4; i++) { + std::cout << h_output[thread_idx * 8 + i] << " "; + } + std::cout << std::endl; - for (int group = 0; group < num_threads / 4; group++) { - std::cout << "\nGroup " << group << " results:" << std::endl; - - for (int lane = 0; lane < 4; lane++) { - int thread_idx = group * 4 + lane; - - // Print original values - std::cout << "Thread " << thread_idx << " original: "; - for (int i = 0; i < 4; i++) { - std::cout << h_output[thread_idx * 8 + i] << " "; - } - std::cout << std::endl; - - // Print and verify transposed values - std::cout << "Thread " << thread_idx << " transposed: "; - for (int i = 0; i < 4; i++) { - uint16_t actual = h_output[thread_idx * 8 + 4 + i]; - std::cout << actual << " "; - - // Expected after transpose: Thread N gets column N - // Thread 0 should have [0*100+0, 1*100+0, 2*100+0, 3*100+0] - // Thread 1 should have [0*100+1, 1*100+1, 2*100+1, 3*100+1] - uint16_t expected = i * 100 + lane; - - if (actual != expected) { - success = false; - std::cout << "(Expected: " << expected << ") "; - } - } - std::cout << std::endl; + // Print and verify transposed values + std::cout << "Thread " << thread_idx << " transposed: "; + for (int i = 0; i < 4; i++) { + uint16_t actual = h_output[thread_idx * 8 + 4 + i]; + std::cout << actual << " "; + + // Expected after transpose: Thread N gets column N + // Thread 0 should have [0*100+0, 1*100+0, 2*100+0, 3*100+0] + // Thread 1 should have [0*100+1, 1*100+1, 2*100+1, 3*100+1] + uint16_t expected = i * 100 + lane; + + if (actual != expected) { + success = false; + std::cout << "(Expected: " << expected << ") "; } + } + std::cout << std::endl; } + } - if (success) { - std::cout << 
"\nTranspose test PASSED! All values correctly transposed." - << std::endl; - } - else { - std::cout << "\nTranspose test FAILED! Some values were not correctly " - "transposed." - << std::endl; - } + if (success) { + std::cout << "\nTranspose test PASSED! All values correctly transposed." << std::endl; + } else { + std::cout << "\nTranspose test FAILED! Some values were not correctly " + "transposed." + << std::endl; + } - // Clean up - FI_GPU_CALL(hipFree(d_output)); + // Clean up + FI_GPU_CALL(hipFree(d_output)); - return success ? 0 : 1; + return success ? 0 : 1; } diff --git a/libflashinfer/utils/conversion_utils.h b/libflashinfer/utils/conversion_utils.h index 768bde165d..0263722cb3 100644 --- a/libflashinfer/utils/conversion_utils.h +++ b/libflashinfer/utils/conversion_utils.h @@ -8,59 +8,46 @@ #include #include -namespace fi::con -{ +namespace fi::con { template -__host__ __device__ __inline__ DTypeOut explicit_casting(DTypeIn value) -{ - return DTypeOut(value); +__host__ __device__ __inline__ DTypeOut explicit_casting(DTypeIn value) { + return DTypeOut(value); } template <> -__host__ __device__ __inline__ float -explicit_casting<__half, float>(__half value) -{ - return __half2float(value); +__host__ __device__ __inline__ float explicit_casting<__half, float>(__half value) { + return __half2float(value); } template <> -__host__ __device__ __inline__ float -explicit_casting<__hip_bfloat16, float>(__hip_bfloat16 value) -{ - return __bfloat162float(value); +__host__ __device__ __inline__ float explicit_casting<__hip_bfloat16, float>(__hip_bfloat16 value) { + return __bfloat162float(value); } template <> -__host__ __device__ __inline__ __half -explicit_casting(float value) -{ - return __float2half(value); +__host__ __device__ __inline__ __half explicit_casting(float value) { + return __float2half(value); } template <> -__host__ __device__ __inline__ __hip_bfloat16 -explicit_casting<__half, __hip_bfloat16>(__half value) -{ - return 
__float2bfloat16(__half2float(value)); +__host__ __device__ __inline__ __hip_bfloat16 explicit_casting<__half, __hip_bfloat16>( + __half value) { + return __float2bfloat16(__half2float(value)); } template <> -__host__ __device__ __inline__ float explicit_casting(float value) -{ - return value; +__host__ __device__ __inline__ float explicit_casting(float value) { + return value; } template <> -__host__ __device__ __inline__ __half -explicit_casting<__half, __half>(__half value) -{ - return value; +__host__ __device__ __inline__ __half explicit_casting<__half, __half>(__half value) { + return value; } template <> -__host__ __device__ __inline__ __hip_bfloat16 -explicit_casting<__hip_bfloat16, __hip_bfloat16>(__hip_bfloat16 value) -{ - return value; +__host__ __device__ __inline__ __hip_bfloat16 explicit_casting<__hip_bfloat16, __hip_bfloat16>( + __hip_bfloat16 value) { + return value; } -} // namespace fi::con +} // namespace fi::con diff --git a/libflashinfer/utils/cpu_reference.h b/libflashinfer/utils/cpu_reference.h index 54fc24736a..d3d418de4d 100644 --- a/libflashinfer/utils/cpu_reference.h +++ b/libflashinfer/utils/cpu_reference.h @@ -22,220 +22,171 @@ #include "utils.h" -namespace cpu_reference -{ +namespace cpu_reference { using namespace flashinfer; template -inline std::vector rms_norm(const T *input, - const T *weight, - size_t batch_size, - size_t d, - float eps = 1e-5) -{ - std::vector output(batch_size * d); - for (size_t i = 0; i < batch_size; ++i) { - float sum = 0; - for (size_t j = 0; j < d; ++j) { - sum += float(input[i * d + j]) * float(input[i * d + j]); - } - float rms_rcp = 1.f / (std::sqrt(sum / float(d)) + eps); - for (size_t j = 0; j < d; ++j) { - output[i * d + j] = - (float(input[i * d + j]) * rms_rcp) * float(weight[j]); - } +inline std::vector rms_norm(const T* input, const T* weight, size_t batch_size, size_t d, + float eps = 1e-5) { + std::vector output(batch_size * d); + for (size_t i = 0; i < batch_size; ++i) { + float sum = 0; + for 
(size_t j = 0; j < d; ++j) { + sum += float(input[i * d + j]) * float(input[i * d + j]); + } + float rms_rcp = 1.f / (std::sqrt(sum / float(d)) + eps); + for (size_t j = 0; j < d; ++j) { + output[i * d + j] = (float(input[i * d + j]) * rms_rcp) * float(weight[j]); } - return std::move(output); + } + return std::move(output); } template -inline std::vector -exclusive_prefix_sum(const T *input, size_t batch_size, size_t d) -{ - std::vector output(batch_size * d); - for (size_t i = 0; i < batch_size; ++i) { - for (size_t j = 0; j < d; ++j) { - output[i * d + j] = - (j == 0) ? 0 : output[i * d + j - 1] + input[i * d + j - 1]; - } +inline std::vector exclusive_prefix_sum(const T* input, size_t batch_size, size_t d) { + std::vector output(batch_size * d); + for (size_t i = 0; i < batch_size; ++i) { + for (size_t j = 0; j < d; ++j) { + output[i * d + j] = (j == 0) ? 0 : output[i * d + j - 1] + input[i * d + j - 1]; } - return std::move(output); + } + return std::move(output); } template -inline std::vector apply_llama_rope(const T *input, - size_t D, - size_t offset, - float rope_scale, - float rope_theta) -{ - std::vector rst(D); - std::vector permuted_input(D); - for (size_t k = 0; k < D; ++k) { - permuted_input[k] = - (k < D / 2) ? -float(input[k + D / 2]) : float(input[k - D / 2]); - } +inline std::vector apply_llama_rope(const T* input, size_t D, size_t offset, + float rope_scale, float rope_theta) { + std::vector rst(D); + std::vector permuted_input(D); + for (size_t k = 0; k < D; ++k) { + permuted_input[k] = (k < D / 2) ? 
-float(input[k + D / 2]) : float(input[k - D / 2]); + } - for (size_t k = 0; k < D; ++k) { - float inv_freq = - (offset / rope_scale) / - (std::pow(rope_theta, float(2 * (k % (D / 2))) / float(D))); - float cos = std::cos(inv_freq); - float sin = std::sin(inv_freq); - rst[k] = cos * float(input[k]) + sin * permuted_input[k]; - } - return std::move(rst); + for (size_t k = 0; k < D; ++k) { + float inv_freq = + (offset / rope_scale) / (std::pow(rope_theta, float(2 * (k % (D / 2))) / float(D))); + float cos = std::cos(inv_freq); + float sin = std::sin(inv_freq); + rst[k] = cos * float(input[k]) + sin * permuted_input[k]; + } + return std::move(rst); } template -std::vector -single_mha(const std::vector &q, - const std::vector &k, - const std::vector &v, - size_t qo_len, - size_t kv_len, - size_t num_qo_heads, - size_t num_kv_heads, - size_t head_dim, - bool causal = true, - QKVLayout kv_layout = QKVLayout::kHND, - PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, - float rope_scale = 1.f, - float rope_theta = 1e4) -{ - assert(qo_len <= kv_len); - assert(num_qo_heads % num_kv_heads == 0); - float sm_scale = 1.f / std::sqrt(float(head_dim)); - std::vector o(qo_len * num_qo_heads * head_dim); - std::vector att(kv_len); - std::vector q_rotary_local(head_dim); - std::vector k_rotary_local(head_dim); - DISPATCH_head_dim(head_dim, HEAD_DIM, { - tensor_info_t info(qo_len, kv_len, num_qo_heads, num_kv_heads, - kv_layout, HEAD_DIM); - for (size_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) - { - const size_t kv_head_idx = qo_head_idx / info.get_group_size(); - for (size_t q_idx = 0; q_idx < qo_len; ++q_idx) { - float max_val = -5e4; - if (pos_encoding_mode == PosEncodingMode::kRoPELlama) { - q_rotary_local = std::move(cpu_reference::apply_llama_rope( - q.data() + - info.get_q_elem_offset(q_idx, qo_head_idx, 0), - head_dim, q_idx + kv_len - qo_len, rope_scale, - rope_theta)); - } - for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { - att[kv_idx] = 0.; 
- switch (pos_encoding_mode) { - case PosEncodingMode::kNone: - { - for (size_t feat_idx = 0; feat_idx < head_dim; - ++feat_idx) - { - att[kv_idx] += - float(q[info.get_q_elem_offset( - q_idx, qo_head_idx, feat_idx)]) * - float(k[info.get_kv_elem_offset( - kv_idx, kv_head_idx, feat_idx)]) * - sm_scale; - } - break; - } - case PosEncodingMode::kRoPELlama: - { - k_rotary_local = - std::move(cpu_reference::apply_llama_rope( - k.data() + info.get_kv_elem_offset( - kv_idx, kv_head_idx, 0), - head_dim, kv_idx, rope_scale, rope_theta)); - for (size_t feat_idx = 0; feat_idx < head_dim; - ++feat_idx) - { - att[kv_idx] += q_rotary_local[feat_idx] * - k_rotary_local[feat_idx] * sm_scale; - } - break; - } - default: - { - std::ostringstream err_msg; - err_msg << "Unsupported rotary mode."; - FLASHINFER_ERROR(err_msg.str()); - } - } - // apply mask - if (causal && kv_idx > kv_len + q_idx - qo_len) { - att[kv_idx] = -5e4; - } - max_val = std::max(max_val, att[kv_idx]); - } - // exp minus max - float denom = 0; - for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { - att[kv_idx] = std::exp(att[kv_idx] - max_val); - denom += att[kv_idx]; - } +std::vector single_mha(const std::vector& q, const std::vector& k, + const std::vector& v, size_t qo_len, size_t kv_len, + size_t num_qo_heads, size_t num_kv_heads, size_t head_dim, + bool causal = true, QKVLayout kv_layout = QKVLayout::kHND, + PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, + float rope_scale = 1.f, float rope_theta = 1e4) { + assert(qo_len <= kv_len); + assert(num_qo_heads % num_kv_heads == 0); + float sm_scale = 1.f / std::sqrt(float(head_dim)); + std::vector o(qo_len * num_qo_heads * head_dim); + std::vector att(kv_len); + std::vector q_rotary_local(head_dim); + std::vector k_rotary_local(head_dim); + DISPATCH_head_dim(head_dim, HEAD_DIM, { + tensor_info_t info(qo_len, kv_len, num_qo_heads, num_kv_heads, kv_layout, HEAD_DIM); + for (size_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) { + const 
size_t kv_head_idx = qo_head_idx / info.get_group_size(); + for (size_t q_idx = 0; q_idx < qo_len; ++q_idx) { + float max_val = -5e4; + if (pos_encoding_mode == PosEncodingMode::kRoPELlama) { + q_rotary_local = std::move(cpu_reference::apply_llama_rope( + q.data() + info.get_q_elem_offset(q_idx, qo_head_idx, 0), head_dim, + q_idx + kv_len - qo_len, rope_scale, rope_theta)); + } + for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { + att[kv_idx] = 0.; + switch (pos_encoding_mode) { + case PosEncodingMode::kNone: { + for (size_t feat_idx = 0; feat_idx < head_dim; ++feat_idx) { + att[kv_idx] += float(q[info.get_q_elem_offset(q_idx, qo_head_idx, feat_idx)]) * + float(k[info.get_kv_elem_offset(kv_idx, kv_head_idx, feat_idx)]) * + sm_scale; + } + break; + } + case PosEncodingMode::kRoPELlama: { + k_rotary_local = std::move(cpu_reference::apply_llama_rope( + k.data() + info.get_kv_elem_offset(kv_idx, kv_head_idx, 0), head_dim, kv_idx, + rope_scale, rope_theta)); + for (size_t feat_idx = 0; feat_idx < head_dim; ++feat_idx) { + att[kv_idx] += q_rotary_local[feat_idx] * k_rotary_local[feat_idx] * sm_scale; + } + break; + } + default: { + std::ostringstream err_msg; + err_msg << "Unsupported rotary mode."; + FLASHINFER_ERROR(err_msg.str()); + } + } + // apply mask + if (causal && kv_idx > kv_len + q_idx - qo_len) { + att[kv_idx] = -5e4; + } + max_val = std::max(max_val, att[kv_idx]); + } + // exp minus max + float denom = 0; + for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { + att[kv_idx] = std::exp(att[kv_idx] - max_val); + denom += att[kv_idx]; + } - // divide by denom - for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { - att[kv_idx] /= denom; - } + // divide by denom + for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { + att[kv_idx] /= denom; + } - for (size_t feat_idx = 0; feat_idx < head_dim; ++feat_idx) { - float o_float = 0.; - for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { - o_float += - att[kv_idx] * float(v[info.get_kv_elem_offset( - kv_idx, 
kv_head_idx, feat_idx)]); - } - o[info.get_o_elem_offset(q_idx, qo_head_idx, feat_idx)] = - dtype_out(o_float); - } - } + for (size_t feat_idx = 0; feat_idx < head_dim; ++feat_idx) { + float o_float = 0.; + for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { + o_float += + att[kv_idx] * float(v[info.get_kv_elem_offset(kv_idx, kv_head_idx, feat_idx)]); + } + o[info.get_o_elem_offset(q_idx, qo_head_idx, feat_idx)] = dtype_out(o_float); } - }); - return std::move(o); + } + } + }); + return std::move(o); } template -void append_paged_kv_cache(paged_kv_t page_cpu, - const std::vector> &keys, - const std::vector> &values, - const std::vector &append_indptr) -{ - size_t batch_size = page_cpu.batch_size; - size_t num_heads = page_cpu.num_heads; - size_t head_dim = page_cpu.head_dim; - size_t page_size = page_cpu.page_size; - for (size_t i = 0; i < batch_size; ++i) { - const std::vector &ki = keys[i]; - const std::vector &vi = values[i]; - size_t append_seq_len = append_indptr[i + 1] - append_indptr[i]; - size_t num_pages_i = page_cpu.indptr[i + 1] - page_cpu.indptr[i]; - size_t seq_len = - (num_pages_i - 1) * page_size + page_cpu.last_page_len[i]; - assert(append_seq_len <= seq_len); - size_t append_start = seq_len - append_seq_len; +void append_paged_kv_cache(paged_kv_t page_cpu, const std::vector>& keys, + const std::vector>& values, + const std::vector& append_indptr) { + size_t batch_size = page_cpu.batch_size; + size_t num_heads = page_cpu.num_heads; + size_t head_dim = page_cpu.head_dim; + size_t page_size = page_cpu.page_size; + for (size_t i = 0; i < batch_size; ++i) { + const std::vector& ki = keys[i]; + const std::vector& vi = values[i]; + size_t append_seq_len = append_indptr[i + 1] - append_indptr[i]; + size_t num_pages_i = page_cpu.indptr[i + 1] - page_cpu.indptr[i]; + size_t seq_len = (num_pages_i - 1) * page_size + page_cpu.last_page_len[i]; + assert(append_seq_len <= seq_len); + size_t append_start = seq_len - append_seq_len; - for (size_t j = 0; j < 
append_seq_len; ++j) { - size_t page_seq_idx = j + append_start; - size_t page_idx = - page_cpu.indices[page_cpu.indptr[i] + page_seq_idx / page_size]; - size_t entry_idx = page_seq_idx % page_size; - for (size_t h = 0; h < num_heads; ++h) { - std::copy(ki.begin() + (j * num_heads + h) * head_dim, - ki.begin() + (j * num_heads + h + 1) * head_dim, - page_cpu.k_data + page_cpu.get_elem_offset( - page_idx, h, entry_idx, 0)); - std::copy(vi.begin() + (j * num_heads + h) * head_dim, - vi.begin() + (j * num_heads + h + 1) * head_dim, - page_cpu.v_data + page_cpu.get_elem_offset( - page_idx, h, entry_idx, 0)); - } - } + for (size_t j = 0; j < append_seq_len; ++j) { + size_t page_seq_idx = j + append_start; + size_t page_idx = page_cpu.indices[page_cpu.indptr[i] + page_seq_idx / page_size]; + size_t entry_idx = page_seq_idx % page_size; + for (size_t h = 0; h < num_heads; ++h) { + std::copy(ki.begin() + (j * num_heads + h) * head_dim, + ki.begin() + (j * num_heads + h + 1) * head_dim, + page_cpu.k_data + page_cpu.get_elem_offset(page_idx, h, entry_idx, 0)); + std::copy(vi.begin() + (j * num_heads + h) * head_dim, + vi.begin() + (j * num_heads + h + 1) * head_dim, + page_cpu.v_data + page_cpu.get_elem_offset(page_idx, h, entry_idx, 0)); + } } + } } -} // namespace cpu_reference +} // namespace cpu_reference diff --git a/libflashinfer/utils/cpu_reference_hip.h b/libflashinfer/utils/cpu_reference_hip.h index 94ed05ea76..3b7c4a1a23 100644 --- a/libflashinfer/utils/cpu_reference_hip.h +++ b/libflashinfer/utils/cpu_reference_hip.h @@ -5,399 +5,324 @@ #pragma once -#include "flashinfer/exception.h" +#include +#include + +#include +#include #include "flashinfer/attention/generic/page.cuh" #include "flashinfer/attention/generic/pos_enc.cuh" - +#include "flashinfer/exception.h" #include "utils_hip.h" -#include -#include -#include -#include - -namespace cpu_reference -{ +namespace cpu_reference { using namespace flashinfer; template -inline std::vector rms_norm(const T *input, - 
const T *weight, - size_t batch_size, - size_t d, - float eps = 1e-5) -{ - std::vector output(batch_size * d); - for (size_t i = 0; i < batch_size; ++i) { - float sum = 0; - for (size_t j = 0; j < d; ++j) { - sum += float(input[i * d + j]) * float(input[i * d + j]); - } - float rms_rcp = 1.f / (std::sqrt(sum / float(d)) + eps); - for (size_t j = 0; j < d; ++j) { - output[i * d + j] = - (float(input[i * d + j]) * rms_rcp) * float(weight[j]); - } +inline std::vector rms_norm(const T* input, const T* weight, size_t batch_size, size_t d, + float eps = 1e-5) { + std::vector output(batch_size * d); + for (size_t i = 0; i < batch_size; ++i) { + float sum = 0; + for (size_t j = 0; j < d; ++j) { + sum += float(input[i * d + j]) * float(input[i * d + j]); } - return std::move(output); + float rms_rcp = 1.f / (std::sqrt(sum / float(d)) + eps); + for (size_t j = 0; j < d; ++j) { + output[i * d + j] = (float(input[i * d + j]) * rms_rcp) * float(weight[j]); + } + } + return std::move(output); } template -inline std::vector -exclusive_prefix_sum(const T *input, size_t batch_size, size_t d) -{ - std::vector output(batch_size * d); - for (size_t i = 0; i < batch_size; ++i) { - for (size_t j = 0; j < d; ++j) { - output[i * d + j] = - (j == 0) ? 0 : output[i * d + j - 1] + input[i * d + j - 1]; - } +inline std::vector exclusive_prefix_sum(const T* input, size_t batch_size, size_t d) { + std::vector output(batch_size * d); + for (size_t i = 0; i < batch_size; ++i) { + for (size_t j = 0; j < d; ++j) { + output[i * d + j] = (j == 0) ? 
0 : output[i * d + j - 1] + input[i * d + j - 1]; } - return std::move(output); + } + return std::move(output); } template -inline std::vector apply_llama_rope_debug(const T *input, - size_t D, - size_t offset, - float rope_scale, - float rope_theta) -{ - std::vector rst(D); - std::vector permuted_input(D); - // Print the input parameters - // Only print for first position to avoid flood - if (offset == 134) { // First position in your log - std::cout << "=== CPU ROPE DEBUG ===\n"; - std::cout << "D: " << D << ", offset: " << offset - << ", rope_scale: " << rope_scale - << ", rope_theta: " << rope_theta << std::endl; - - std::cout << "CPU Frequencies vs GPU comparison:\n"; - for (size_t k = 0; k < min(4ul, D); ++k) { - float freq_base = float(2 * (k % (D / 2))) / float(D); - float frequency = - 1.0f / std::pow(rope_theta, freq_base); // This should match GPU - float angle = - (offset / rope_scale) / std::pow(rope_theta, freq_base); - - std::cout << "CPU: feature[" << k << "] freq_base=" << freq_base - << " frequency=" << frequency << " angle=" << angle - << std::endl; - } +inline std::vector apply_llama_rope_debug(const T* input, size_t D, size_t offset, + float rope_scale, float rope_theta) { + std::vector rst(D); + std::vector permuted_input(D); + // Print the input parameters + // Only print for first position to avoid flood + if (offset == 134) { // First position in your log + std::cout << "=== CPU ROPE DEBUG ===\n"; + std::cout << "D: " << D << ", offset: " << offset << ", rope_scale: " << rope_scale + << ", rope_theta: " << rope_theta << std::endl; + + std::cout << "CPU Frequencies vs GPU comparison:\n"; + for (size_t k = 0; k < min(4ul, D); ++k) { + float freq_base = float(2 * (k % (D / 2))) / float(D); + float frequency = 1.0f / std::pow(rope_theta, freq_base); // This should match GPU + float angle = (offset / rope_scale) / std::pow(rope_theta, freq_base); + + std::cout << "CPU: feature[" << k << "] freq_base=" << freq_base << " frequency=" << frequency + 
<< " angle=" << angle << std::endl; } - - for (size_t k = 0; k < D; ++k) { - permuted_input[k] = - (k < D / 2) ? -fi::con::explicit_casting(input[k + D / 2]) - : fi::con::explicit_casting(input[k - D / 2]); - } - - for (size_t k = 0; k < D; ++k) { - float inv_freq = - (offset / rope_scale) / - (std::pow(rope_theta, float(2 * (k % (D / 2))) / float(D))); - float cos = std::cos(inv_freq); - float sin = std::sin(inv_freq); - - if (std::is_same_v) - rst[k] = cos * fi::con::explicit_casting(input[k]) + - sin * permuted_input[k]; - } - return rst; + } + + for (size_t k = 0; k < D; ++k) { + permuted_input[k] = (k < D / 2) ? -fi::con::explicit_casting(input[k + D / 2]) + : fi::con::explicit_casting(input[k - D / 2]); + } + + for (size_t k = 0; k < D; ++k) { + float inv_freq = + (offset / rope_scale) / (std::pow(rope_theta, float(2 * (k % (D / 2))) / float(D))); + float cos = std::cos(inv_freq); + float sin = std::sin(inv_freq); + + if (std::is_same_v) + rst[k] = cos * fi::con::explicit_casting(input[k]) + sin * permuted_input[k]; + } + return rst; } template -inline std::vector apply_llama_rope(const T *input, - size_t D, - size_t offset, - float rope_scale, - float rope_theta) -{ - std::vector rst(D); - std::vector permuted_input(D); - for (size_t k = 0; k < D; ++k) { - - permuted_input[k] = - (k < D / 2) ? 
-fi::con::explicit_casting(input[k + D / 2]) - : fi::con::explicit_casting(input[k - D / 2]); - } - - for (size_t k = 0; k < D; ++k) { - float inv_freq = - (offset / rope_scale) / - (std::pow(rope_theta, float(2 * (k % (D / 2))) / float(D))); - float cos = std::cos(inv_freq); - float sin = std::sin(inv_freq); - - if (std::is_same_v) - rst[k] = cos * fi::con::explicit_casting(input[k]) + - sin * permuted_input[k]; - } - return rst; +inline std::vector apply_llama_rope(const T* input, size_t D, size_t offset, + float rope_scale, float rope_theta) { + std::vector rst(D); + std::vector permuted_input(D); + for (size_t k = 0; k < D; ++k) { + permuted_input[k] = (k < D / 2) ? -fi::con::explicit_casting(input[k + D / 2]) + : fi::con::explicit_casting(input[k - D / 2]); + } + + for (size_t k = 0; k < D; ++k) { + float inv_freq = + (offset / rope_scale) / (std::pow(rope_theta, float(2 * (k % (D / 2))) / float(D))); + float cos = std::cos(inv_freq); + float sin = std::sin(inv_freq); + + if (std::is_same_v) + rst[k] = cos * fi::con::explicit_casting(input[k]) + sin * permuted_input[k]; + } + return rst; } template -std::vector compute_qk(const std::vector &q, - const std::vector &k, - size_t qo_len, - size_t kv_len, - size_t num_qo_heads, - size_t num_kv_heads, - size_t head_dim, - QKVLayout kv_layout = QKVLayout::kHND) -{ - - assert(num_qo_heads % num_kv_heads == 0); - assert(q.size() == qo_len * num_qo_heads * head_dim); - assert(k.size() == kv_len * num_kv_heads * head_dim); - - std::vector qk_scores(qo_len * num_qo_heads * kv_len); - - DISPATCH_head_dim(head_dim, HEAD_DIM, { - tensor_info_t info(qo_len, kv_len, num_qo_heads, num_kv_heads, - kv_layout, HEAD_DIM); - - for (size_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) - { - const size_t kv_head_idx = qo_head_idx / info.get_group_size(); - - for (size_t q_idx = 0; q_idx < qo_len; ++q_idx) { - for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { - float qk_score = 0.0f; - - // Pure Q*K^T - NO scaling 
(matching HIP compute_qk) - for (size_t feat_idx = 0; feat_idx < head_dim; ++feat_idx) { - qk_score += fi::con::explicit_casting( - q[info.get_q_elem_offset( - q_idx, qo_head_idx, feat_idx)]) * - fi::con::explicit_casting( - k[info.get_kv_elem_offset( - kv_idx, kv_head_idx, feat_idx)]); - } - - size_t output_idx = - qo_head_idx * qo_len * kv_len + q_idx * kv_len + kv_idx; - qk_scores[output_idx] = qk_score; - } - } +std::vector compute_qk(const std::vector& q, const std::vector& k, + size_t qo_len, size_t kv_len, size_t num_qo_heads, + size_t num_kv_heads, size_t head_dim, + QKVLayout kv_layout = QKVLayout::kHND) { + assert(num_qo_heads % num_kv_heads == 0); + assert(q.size() == qo_len * num_qo_heads * head_dim); + assert(k.size() == kv_len * num_kv_heads * head_dim); + + std::vector qk_scores(qo_len * num_qo_heads * kv_len); + + DISPATCH_head_dim(head_dim, HEAD_DIM, { + tensor_info_t info(qo_len, kv_len, num_qo_heads, num_kv_heads, kv_layout, HEAD_DIM); + + for (size_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) { + const size_t kv_head_idx = qo_head_idx / info.get_group_size(); + + for (size_t q_idx = 0; q_idx < qo_len; ++q_idx) { + for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { + float qk_score = 0.0f; + + // Pure Q*K^T - NO scaling (matching HIP compute_qk) + for (size_t feat_idx = 0; feat_idx < head_dim; ++feat_idx) { + qk_score += fi::con::explicit_casting( + q[info.get_q_elem_offset(q_idx, qo_head_idx, feat_idx)]) * + fi::con::explicit_casting( + k[info.get_kv_elem_offset(kv_idx, kv_head_idx, feat_idx)]); + } + + size_t output_idx = qo_head_idx * qo_len * kv_len + q_idx * kv_len + kv_idx; + qk_scores[output_idx] = qk_score; } - }); + } + } + }); - return qk_scores; + return qk_scores; } template -std::vector -single_mha(const std::vector &q, - const std::vector &k, - const std::vector &v, - size_t qo_len, - size_t kv_len, - size_t num_qo_heads, - size_t num_kv_heads, - size_t head_dim, - bool causal = true, - QKVLayout kv_layout = 
QKVLayout::kHND, - PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, - float logits_soft_cap = 8.0f, - float rope_scale = 1.f, - float rope_theta = 1e4, - bool use_soft_cap = false) -{ - assert(qo_len <= kv_len); - assert(num_qo_heads % num_kv_heads == 0); - float sm_scale = 1.f / std::sqrt(float(head_dim)); - // float sm_scale = 1.0; - std::vector o(qo_len * num_qo_heads * head_dim); - std::vector att(kv_len); - std::vector q_rotary_local(head_dim); - std::vector k_rotary_local(head_dim); - - float soft_cap_pre_tanh_scale = sm_scale / logits_soft_cap; - - DISPATCH_head_dim(head_dim, HEAD_DIM, { - tensor_info_t info(qo_len, kv_len, num_qo_heads, num_kv_heads, - kv_layout, HEAD_DIM); +std::vector single_mha(const std::vector& q, const std::vector& k, + const std::vector& v, size_t qo_len, size_t kv_len, + size_t num_qo_heads, size_t num_kv_heads, size_t head_dim, + bool causal = true, QKVLayout kv_layout = QKVLayout::kHND, + PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, + float logits_soft_cap = 8.0f, float rope_scale = 1.f, + float rope_theta = 1e4, bool use_soft_cap = false) { + assert(qo_len <= kv_len); + assert(num_qo_heads % num_kv_heads == 0); + float sm_scale = 1.f / std::sqrt(float(head_dim)); + // float sm_scale = 1.0; + std::vector o(qo_len * num_qo_heads * head_dim); + std::vector att(kv_len); + std::vector q_rotary_local(head_dim); + std::vector k_rotary_local(head_dim); + + float soft_cap_pre_tanh_scale = sm_scale / logits_soft_cap; + + DISPATCH_head_dim(head_dim, HEAD_DIM, { + tensor_info_t info(qo_len, kv_len, num_qo_heads, num_kv_heads, kv_layout, HEAD_DIM); #if Debug1 - std::cout << "DEBUG: Original Q (CPU): " << '\n'; - for (auto i = 0ul; i < 128; ++i) { - for (int j = 0; j < 64; ++j) { - std::cout << (float)q[info.get_q_elem_offset(i, 0, j)] << " "; - } - std::cout << std::endl; + std::cout << "DEBUG: Original Q (CPU): " << '\n'; + for (auto i = 0ul; i < 128; ++i) { + for (int j = 0; j < 64; ++j) { + std::cout << 
(float)q[info.get_q_elem_offset(i, 0, j)] << " "; + } + std::cout << std::endl; + } + std::cout << std::endl; + + std::cout << "DEBUG: Original K (CPU): " << '\n'; + for (auto i = 0ul; i < 128; ++i) { + for (int j = 0ul; j < 64; ++j) { + std::cout << (float)k[info.get_kv_elem_offset(i, 0, j)] << " "; + } + std::cout << std::endl; + } + std::cout << std::endl; +#endif + for (size_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) { + const size_t kv_head_idx = qo_head_idx / info.get_group_size(); + for (size_t q_idx = 0; q_idx < qo_len; ++q_idx) { + float max_val = -5e4; + if (pos_encoding_mode == PosEncodingMode::kRoPELlama) { + q_rotary_local = std::move(cpu_reference::apply_llama_rope( + q.data() + info.get_q_elem_offset(q_idx, qo_head_idx, 0), head_dim, + q_idx + kv_len - qo_len, rope_scale, rope_theta)); } - std::cout << std::endl; - - std::cout << "DEBUG: Original K (CPU): " << '\n'; - for (auto i = 0ul; i < 128; ++i) { - for (int j = 0ul; j < 64; ++j) { - std::cout << (float)k[info.get_kv_elem_offset(i, 0, j)] << " "; + for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { + att[kv_idx] = 0.; + switch (pos_encoding_mode) { + case PosEncodingMode::kNone: { + for (size_t feat_idx = 0; feat_idx < head_dim; ++feat_idx) { + att[kv_idx] += fi::con::explicit_casting( + q[info.get_q_elem_offset(q_idx, qo_head_idx, feat_idx)]) * + fi::con::explicit_casting( + k[info.get_kv_elem_offset(kv_idx, kv_head_idx, feat_idx)]) * + sm_scale; + } + break; + } + case PosEncodingMode::kRoPELlama: { + k_rotary_local = std::move(cpu_reference::apply_llama_rope( + k.data() + info.get_kv_elem_offset(kv_idx, kv_head_idx, 0), head_dim, kv_idx, + rope_scale, rope_theta)); + for (size_t feat_idx = 0; feat_idx < head_dim; ++feat_idx) { + att[kv_idx] += q_rotary_local[feat_idx] * k_rotary_local[feat_idx] * sm_scale; + } + break; + } + default: { + std::ostringstream err_msg; + err_msg << "Unsupported rotary mode."; + FLASHINFER_ERROR(err_msg.str()); } - std::cout << std::endl; + } 
+ // apply mask + if (causal && kv_idx > kv_len + q_idx - qo_len) { + att[kv_idx] = -5e4; + } + max_val = std::max(max_val, att[kv_idx]); } - std::cout << std::endl; -#endif - for (size_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) - { - const size_t kv_head_idx = qo_head_idx / info.get_group_size(); - for (size_t q_idx = 0; q_idx < qo_len; ++q_idx) { - float max_val = -5e4; - if (pos_encoding_mode == PosEncodingMode::kRoPELlama) { - q_rotary_local = std::move(cpu_reference::apply_llama_rope( - q.data() + - info.get_q_elem_offset(q_idx, qo_head_idx, 0), - head_dim, q_idx + kv_len - qo_len, rope_scale, - rope_theta)); - } - for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { - att[kv_idx] = 0.; - switch (pos_encoding_mode) { - case PosEncodingMode::kNone: - { - for (size_t feat_idx = 0; feat_idx < head_dim; - ++feat_idx) - { - att[kv_idx] += - fi::con::explicit_casting( - q[info.get_q_elem_offset(q_idx, qo_head_idx, - feat_idx)]) * - fi::con::explicit_casting( - k[info.get_kv_elem_offset( - kv_idx, kv_head_idx, feat_idx)]) * - sm_scale; - } - break; - } - case PosEncodingMode::kRoPELlama: - { - k_rotary_local = - std::move(cpu_reference::apply_llama_rope( - k.data() + info.get_kv_elem_offset( - kv_idx, kv_head_idx, 0), - head_dim, kv_idx, rope_scale, rope_theta)); - for (size_t feat_idx = 0; feat_idx < head_dim; - ++feat_idx) - { - att[kv_idx] += q_rotary_local[feat_idx] * - k_rotary_local[feat_idx] * sm_scale; - } - break; - } - default: - { - std::ostringstream err_msg; - err_msg << "Unsupported rotary mode."; - FLASHINFER_ERROR(err_msg.str()); - } - } - // apply mask - if (causal && kv_idx > kv_len + q_idx - qo_len) { - att[kv_idx] = -5e4; - } - max_val = std::max(max_val, att[kv_idx]); - } #if Debug1 - if (qo_head_idx == 0) { - // for qo_len = 128, each warp on the GPU will store 128/4, - // that is, 32 attention scores. For CDNA3, these 32 scores - // are spread across 4 threads. 
- for (auto i = 0ul; i < 128; ++i) { - std::cout << att[i] / sm_scale << " "; - } - std::cout << std::endl; - } + if (qo_head_idx == 0) { + // for qo_len = 128, each warp on the GPU will store 128/4, + // that is, 32 attention scores. For CDNA3, these 32 scores + // are spread across 4 threads. + for (auto i = 0ul; i < 128; ++i) { + std::cout << att[i] / sm_scale << " "; + } + std::cout << std::endl; + } #endif - // exp minus max - float denom = 0; - for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { - att[kv_idx] = std::exp(att[kv_idx] - max_val); - denom += att[kv_idx]; - } + // exp minus max + float denom = 0; + for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { + att[kv_idx] = std::exp(att[kv_idx] - max_val); + denom += att[kv_idx]; + } #if Debug1 - if (qo_head_idx == 0) { - // for qo_len = 128, each warp on the GPU will store 128/4, - // that is, 32 attention scores. For CDNA3, these 32 scores - // are spread across 4 threads. - for (auto i = 0ul; i < 128; ++i) { - std::cout << att[i] << " "; - } - std::cout << std::endl; - } + if (qo_head_idx == 0) { + // for qo_len = 128, each warp on the GPU will store 128/4, + // that is, 32 attention scores. For CDNA3, these 32 scores + // are spread across 4 threads. 
+ for (auto i = 0ul; i < 128; ++i) { + std::cout << att[i] << " "; + } + std::cout << std::endl; + } #endif #if Debug1 - if (qo_head_idx == 0) { - for (auto i = 0ul; i < 128; ++i) { - std::cout << denom << " "; - } - std::cout << std::endl; - } + if (qo_head_idx == 0) { + for (auto i = 0ul; i < 128; ++i) { + std::cout << denom << " "; + } + std::cout << std::endl; + } #endif - // divide by denom - for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { - att[kv_idx] /= denom; - } - - for (size_t feat_idx = 0; feat_idx < head_dim; ++feat_idx) { - float o_float = 0.; - for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { - o_float += att[kv_idx] * - fi::con::explicit_casting( - v[info.get_kv_elem_offset( - kv_idx, kv_head_idx, feat_idx)]); - } - o[info.get_o_elem_offset(q_idx, qo_head_idx, feat_idx)] = - fi::con::explicit_casting(o_float); - } - } + // divide by denom + for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { + att[kv_idx] /= denom; } - }); - return std::move(o); + + for (size_t feat_idx = 0; feat_idx < head_dim; ++feat_idx) { + float o_float = 0.; + for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { + o_float += att[kv_idx] * fi::con::explicit_casting( + v[info.get_kv_elem_offset(kv_idx, kv_head_idx, feat_idx)]); + } + o[info.get_o_elem_offset(q_idx, qo_head_idx, feat_idx)] = + fi::con::explicit_casting(o_float); + } + } + } + }); + return std::move(o); } template -void append_paged_kv_cache(paged_kv_t page_cpu, - const std::vector> &keys, - const std::vector> &values, - const std::vector &append_indptr) -{ - size_t batch_size = page_cpu.batch_size; - size_t num_heads = page_cpu.num_heads; - size_t head_dim = page_cpu.head_dim; - size_t page_size = page_cpu.page_size; - for (size_t i = 0; i < batch_size; ++i) { - const std::vector &ki = keys[i]; - const std::vector &vi = values[i]; - size_t append_seq_len = append_indptr[i + 1] - append_indptr[i]; - size_t num_pages_i = page_cpu.indptr[i + 1] - page_cpu.indptr[i]; - size_t seq_len = - (num_pages_i - 1) * 
page_size + page_cpu.last_page_len[i]; - assert(append_seq_len <= seq_len); - size_t append_start = seq_len - append_seq_len; - - for (size_t j = 0; j < append_seq_len; ++j) { - size_t page_seq_idx = j + append_start; - size_t page_idx = - page_cpu.indices[page_cpu.indptr[i] + page_seq_idx / page_size]; - size_t entry_idx = page_seq_idx % page_size; - for (size_t h = 0; h < num_heads; ++h) { - std::copy(ki.begin() + (j * num_heads + h) * head_dim, - ki.begin() + (j * num_heads + h + 1) * head_dim, - page_cpu.k_data + page_cpu.get_elem_offset( - page_idx, h, entry_idx, 0)); - std::copy(vi.begin() + (j * num_heads + h) * head_dim, - vi.begin() + (j * num_heads + h + 1) * head_dim, - page_cpu.v_data + page_cpu.get_elem_offset( - page_idx, h, entry_idx, 0)); - } - } +void append_paged_kv_cache(paged_kv_t page_cpu, const std::vector>& keys, + const std::vector>& values, + const std::vector& append_indptr) { + size_t batch_size = page_cpu.batch_size; + size_t num_heads = page_cpu.num_heads; + size_t head_dim = page_cpu.head_dim; + size_t page_size = page_cpu.page_size; + for (size_t i = 0; i < batch_size; ++i) { + const std::vector& ki = keys[i]; + const std::vector& vi = values[i]; + size_t append_seq_len = append_indptr[i + 1] - append_indptr[i]; + size_t num_pages_i = page_cpu.indptr[i + 1] - page_cpu.indptr[i]; + size_t seq_len = (num_pages_i - 1) * page_size + page_cpu.last_page_len[i]; + assert(append_seq_len <= seq_len); + size_t append_start = seq_len - append_seq_len; + + for (size_t j = 0; j < append_seq_len; ++j) { + size_t page_seq_idx = j + append_start; + size_t page_idx = page_cpu.indices[page_cpu.indptr[i] + page_seq_idx / page_size]; + size_t entry_idx = page_seq_idx % page_size; + for (size_t h = 0; h < num_heads; ++h) { + std::copy(ki.begin() + (j * num_heads + h) * head_dim, + ki.begin() + (j * num_heads + h + 1) * head_dim, + page_cpu.k_data + page_cpu.get_elem_offset(page_idx, h, entry_idx, 0)); + std::copy(vi.begin() + (j * num_heads + h) * 
head_dim, + vi.begin() + (j * num_heads + h + 1) * head_dim, + page_cpu.v_data + page_cpu.get_elem_offset(page_idx, h, entry_idx, 0)); + } } + } } -} // namespace cpu_reference +} // namespace cpu_reference diff --git a/libflashinfer/utils/flashinfer_batch_decode_test_ops.hip.h b/libflashinfer/utils/flashinfer_batch_decode_test_ops.hip.h index eabaa163eb..dfb9023165 100644 --- a/libflashinfer/utils/flashinfer_batch_decode_test_ops.hip.h +++ b/libflashinfer/utils/flashinfer_batch_decode_test_ops.hip.h @@ -3,143 +3,114 @@ // // SPDX - License - Identifier : Apache 2.0 +#include + #include "flashinfer/attention/generic/default_decode_params.cuh" #include "flashinfer/attention/generic/scheduler.cuh" #include "flashinfer/attention/generic/variants.cuh" - #include "gpu_iface/enums.hpp" #include "gpu_iface/layout.cuh" - #include "utils_hip.h" -#include - -namespace flashinfer -{ -class BatchDecodeHandler -{ -public: - template - hipError_t PlanDispatched(void *float_buffer, - size_t float_workspace_size_in_bytes, - void *int_buffer, - size_t int_workspace_size_in_bytes, - IdType *indptr_h, - IdType *last_page_len_h, - uint32_t batch_size, - uint32_t num_qo_heads, - uint32_t page_size) - { - int_buffer_ = int_buffer; - float_buffer_ = float_buffer; - using Params = BatchDecodeParams; - using AttentionVariant = DefaultAttention< - /*use_custom_mask=*/false, /*use_sliding_window=*/false, - /*use_logits_soft_cap=*/false, /*use_alibi=*/false>; - - auto work_estimation_func = - BatchDecodeWithPagedKVCacheWorkEstimationDispatched< - GROUP_SIZE, HEAD_DIM, POS_ENCODING_MODE, AttentionVariant, - Params>; - return DecodePlan( - float_buffer, float_workspace_size_in_bytes, int_buffer, - page_locked_buffer_, int_workspace_size_in_bytes, plan_info_, - indptr_h, batch_size, num_qo_heads, page_size, cuda_graph_enabled_, - stream_, work_estimation_func); - } - - void UpdatePageLockedBufferSize(size_t int_workspace_size_in_bytes) - { - hipFreeHost(page_locked_buffer_); - 
hipMallocHost(&page_locked_buffer_, int_workspace_size_in_bytes); +namespace flashinfer { +class BatchDecodeHandler { + public: + template + hipError_t PlanDispatched(void* float_buffer, size_t float_workspace_size_in_bytes, + void* int_buffer, size_t int_workspace_size_in_bytes, IdType* indptr_h, + IdType* last_page_len_h, uint32_t batch_size, uint32_t num_qo_heads, + uint32_t page_size) { + int_buffer_ = int_buffer; + float_buffer_ = float_buffer; + using Params = BatchDecodeParams; + using AttentionVariant = DefaultAttention< + /*use_custom_mask=*/false, /*use_sliding_window=*/false, + /*use_logits_soft_cap=*/false, /*use_alibi=*/false>; + + auto work_estimation_func = + BatchDecodeWithPagedKVCacheWorkEstimationDispatched; + return DecodePlan( + float_buffer, float_workspace_size_in_bytes, int_buffer, page_locked_buffer_, + int_workspace_size_in_bytes, plan_info_, indptr_h, batch_size, num_qo_heads, page_size, + cuda_graph_enabled_, stream_, work_estimation_func); + } + + void UpdatePageLockedBufferSize(size_t int_workspace_size_in_bytes) { + hipFreeHost(page_locked_buffer_); + hipMallocHost(&page_locked_buffer_, int_workspace_size_in_bytes); + } + + hipStream_t GetCUDAStream() const { return stream_; } + + void SetCUDAStream(hipStream_t stream) { stream_ = stream; } + + /*! 
+ * \brief Constructor of BatchDecodeHandler + * \param enable_cuda_graph A boolean indicates whether to enable CUDA graph + * \param batch_size If enable_cuda_graph is true, we must specify a fixed + * batch_size + */ + BatchDecodeHandler(bool enable_cuda_graph = false, uint32_t batch_size = 0) + : cuda_graph_enabled_(enable_cuda_graph), stream_(nullptr) { + hipMallocHost(&page_locked_buffer_, 8 * 1024 * 1024); + } + ~BatchDecodeHandler() { hipFreeHost(page_locked_buffer_); } + + bool IsCUDAGraphEnabled() const { return cuda_graph_enabled_; } + + DecodePlanInfo GetPlanInfo() const { return plan_info_; } + + template + IdType* GetRequestIndices() { + return GetPtrFromBaseOffset(int_buffer_, plan_info_.request_indices_offset); + } + + template + IdType* GetKVTileIndices() { + return GetPtrFromBaseOffset(int_buffer_, plan_info_.kv_tile_indices_offset); + } + + template + IdType* GetOIndptr() { + return GetPtrFromBaseOffset(int_buffer_, plan_info_.o_indptr_offset); + } + + template + IdType* GetKVChunkSizePtr() { + return GetPtrFromBaseOffset(int_buffer_, plan_info_.kv_chunk_size_ptr_offset); + } + + template + DTypeO* GetTmpV() { + if (plan_info_.split_kv) { + return GetPtrFromBaseOffset(float_buffer_, plan_info_.v_offset); } + return nullptr; + } - hipStream_t GetCUDAStream() const { return stream_; } - - void SetCUDAStream(hipStream_t stream) { stream_ = stream; } - - /*! 
- * \brief Constructor of BatchDecodeHandler - * \param enable_cuda_graph A boolean indicates whether to enable CUDA graph - * \param batch_size If enable_cuda_graph is true, we must specify a fixed - * batch_size - */ - BatchDecodeHandler(bool enable_cuda_graph = false, uint32_t batch_size = 0) - : cuda_graph_enabled_(enable_cuda_graph), stream_(nullptr) - { - hipMallocHost(&page_locked_buffer_, 8 * 1024 * 1024); - } - ~BatchDecodeHandler() { hipFreeHost(page_locked_buffer_); } - - bool IsCUDAGraphEnabled() const { return cuda_graph_enabled_; } - - DecodePlanInfo GetPlanInfo() const { return plan_info_; } - - template IdType *GetRequestIndices() - { - return GetPtrFromBaseOffset(int_buffer_, - plan_info_.request_indices_offset); + float* GetTmpS() { + if (plan_info_.split_kv) { + return GetPtrFromBaseOffset(float_buffer_, plan_info_.s_offset); } + return nullptr; + } - template IdType *GetKVTileIndices() - { - return GetPtrFromBaseOffset(int_buffer_, - plan_info_.kv_tile_indices_offset); + bool* GetBlockValidMask() { + if (plan_info_.split_kv && plan_info_.enable_cuda_graph) { + return GetPtrFromBaseOffset(int_buffer_, plan_info_.block_valid_mask_offset); } - - template IdType *GetOIndptr() - { - return GetPtrFromBaseOffset(int_buffer_, - plan_info_.o_indptr_offset); - } - - template IdType *GetKVChunkSizePtr() - { - return GetPtrFromBaseOffset( - int_buffer_, plan_info_.kv_chunk_size_ptr_offset); - } - - template DTypeO *GetTmpV() - { - if (plan_info_.split_kv) { - return GetPtrFromBaseOffset(float_buffer_, - plan_info_.v_offset); - } - return nullptr; - } - - float *GetTmpS() - { - if (plan_info_.split_kv) { - return GetPtrFromBaseOffset(float_buffer_, - plan_info_.s_offset); - } - return nullptr; - } - - bool *GetBlockValidMask() - { - if (plan_info_.split_kv && plan_info_.enable_cuda_graph) { - return GetPtrFromBaseOffset( - int_buffer_, plan_info_.block_valid_mask_offset); - } - return nullptr; - } - -protected: - void *page_locked_buffer_; - void 
*int_buffer_; - void *float_buffer_; - DecodePlanInfo plan_info_; - bool cuda_graph_enabled_; - hipStream_t stream_; + return nullptr; + } + + protected: + void* page_locked_buffer_; + void* int_buffer_; + void* float_buffer_; + DecodePlanInfo plan_info_; + bool cuda_graph_enabled_; + hipStream_t stream_; }; /*! @@ -162,91 +133,71 @@ class BatchDecodeHandler */ template hipError_t BatchDecodeWithPagedKVCacheWrapper( - BatchDecodeHandler *handler, - DTypeQ *q, - IdType *q_rope_offset, - paged_kv_t paged_kv, - DTypeO *o, - float *lse, - uint32_t num_qo_heads, + BatchDecodeHandler* handler, DTypeQ* q, IdType* q_rope_offset, + paged_kv_t paged_kv, DTypeO* o, float* lse, uint32_t num_qo_heads, PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, - std::optional maybe_sm_scale = std::nullopt, - float rope_scale = 1.f, - float rope_theta = 1e4, - hipStream_t stream = nullptr) -{ - float sm_scale = - maybe_sm_scale.value_or(1.f / std::sqrt(float(paged_kv.head_dim))); - const uint32_t num_kv_heads = paged_kv.num_heads; - if (num_qo_heads % num_kv_heads != 0) { - std::ostringstream err_msg; - err_msg << "num_qo_heads " << num_qo_heads - << " is not a multiple of num_kv_heads " << num_kv_heads; - FLASHINFER_ERROR(err_msg.str()); - } - - DISPATCH_head_dim( - paged_kv.head_dim, HEAD_DIM, - {DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, { - using Params = BatchDecodeParams; - using AttentionVariant = DefaultAttention< - /*use_custom_mask=*/false, /*use_sliding_window=*/false, - /*use_logits_soft_cap=*/false, /*use_alibi=*/false>; - Params params(q, q_rope_offset, paged_kv, o, lse, - /*alibi_slopes=*/nullptr, num_qo_heads, - /*q_stride_n*/ num_qo_heads * HEAD_DIM, - /*q_stride_h*/ HEAD_DIM, - /*window_left=*/-1, /*logits_soft_cap=*/0.f, sm_scale, - rope_scale, rope_theta); - params.request_indices = handler->GetRequestIndices(); - params.kv_tile_indices = handler->GetKVTileIndices(); - params.o_indptr = handler->GetOIndptr(); - params.kv_chunk_size_ptr = 
handler->GetKVChunkSizePtr(); - params.block_valid_mask = handler->GetBlockValidMask(); - params.padded_batch_size = handler->GetPlanInfo().padded_batch_size; - - return BatchDecodeWithPagedKVCacheDispatched< - HEAD_DIM, POS_ENCODING_MODE, AttentionVariant>( - params, handler->GetTmpV(), handler->GetTmpS(), stream); - })}); - return hipSuccess; + std::optional maybe_sm_scale = std::nullopt, float rope_scale = 1.f, + float rope_theta = 1e4, hipStream_t stream = nullptr) { + float sm_scale = maybe_sm_scale.value_or(1.f / std::sqrt(float(paged_kv.head_dim))); + const uint32_t num_kv_heads = paged_kv.num_heads; + if (num_qo_heads % num_kv_heads != 0) { + std::ostringstream err_msg; + err_msg << "num_qo_heads " << num_qo_heads << " is not a multiple of num_kv_heads " + << num_kv_heads; + FLASHINFER_ERROR(err_msg.str()); + } + + DISPATCH_head_dim( + paged_kv.head_dim, HEAD_DIM, + {DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, { + using Params = BatchDecodeParams; + using AttentionVariant = DefaultAttention< + /*use_custom_mask=*/false, /*use_sliding_window=*/false, + /*use_logits_soft_cap=*/false, /*use_alibi=*/false>; + Params params(q, q_rope_offset, paged_kv, o, lse, + /*alibi_slopes=*/nullptr, num_qo_heads, + /*q_stride_n*/ num_qo_heads * HEAD_DIM, + /*q_stride_h*/ HEAD_DIM, + /*window_left=*/-1, /*logits_soft_cap=*/0.f, sm_scale, rope_scale, + rope_theta); + params.request_indices = handler->GetRequestIndices(); + params.kv_tile_indices = handler->GetKVTileIndices(); + params.o_indptr = handler->GetOIndptr(); + params.kv_chunk_size_ptr = handler->GetKVChunkSizePtr(); + params.block_valid_mask = handler->GetBlockValidMask(); + params.padded_batch_size = handler->GetPlanInfo().padded_batch_size; + + return BatchDecodeWithPagedKVCacheDispatched( + params, handler->GetTmpV(), handler->GetTmpS(), stream); + })}); + return hipSuccess; } template -hipError_t BatchDecodeHandlerPlan(BatchDecodeHandler *handler, - void *float_buffer, - size_t 
float_workspace_size_in_bytes, - void *int_buffer, - size_t int_workspace_size_in_bytes, - IdType *indptr_h, - IdType *last_page_len_h, - uint32_t batch_size, - uint32_t num_qo_heads, - uint32_t num_kv_heads, - uint32_t head_dim, - uint32_t page_size, - PosEncodingMode pos_encoding_mode) -{ - if (num_qo_heads % num_kv_heads != 0) { - std::ostringstream err_msg; - err_msg << "num_qo_heads " << num_qo_heads - << " should be divisible by num_kv_heads " << num_kv_heads; - FLASHINFER_ERROR(err_msg.str()); - } - DISPATCH_head_dim(head_dim, HEAD_DIM, { - DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, { - DISPATCH_GQA_GROUP_SIZE(num_qo_heads / num_kv_heads, GROUP_SIZE, { - return handler - ->PlanDispatched( - float_buffer, float_workspace_size_in_bytes, int_buffer, - int_workspace_size_in_bytes, indptr_h, last_page_len_h, - batch_size, num_qo_heads, page_size); - }); - }); +hipError_t BatchDecodeHandlerPlan(BatchDecodeHandler* handler, void* float_buffer, + size_t float_workspace_size_in_bytes, void* int_buffer, + size_t int_workspace_size_in_bytes, IdType* indptr_h, + IdType* last_page_len_h, uint32_t batch_size, + uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t head_dim, + uint32_t page_size, PosEncodingMode pos_encoding_mode) { + if (num_qo_heads % num_kv_heads != 0) { + std::ostringstream err_msg; + err_msg << "num_qo_heads " << num_qo_heads << " should be divisible by num_kv_heads " + << num_kv_heads; + FLASHINFER_ERROR(err_msg.str()); + } + DISPATCH_head_dim(head_dim, HEAD_DIM, { + DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, { + DISPATCH_GQA_GROUP_SIZE(num_qo_heads / num_kv_heads, GROUP_SIZE, { + return handler->PlanDispatched( + float_buffer, float_workspace_size_in_bytes, int_buffer, int_workspace_size_in_bytes, + indptr_h, last_page_len_h, batch_size, num_qo_heads, page_size); + }); }); + }); - return hipSuccess; + return hipSuccess; } -} // namespace flashinfer +} // namespace flashinfer diff --git 
a/libflashinfer/utils/flashinfer_prefill_ops.hip.h b/libflashinfer/utils/flashinfer_prefill_ops.hip.h index 5c282c4b8a..971907203d 100644 --- a/libflashinfer/utils/flashinfer_prefill_ops.hip.h +++ b/libflashinfer/utils/flashinfer_prefill_ops.hip.h @@ -13,74 +13,50 @@ #include "flashinfer/attention/generic/exception.h" #include "flashinfer/attention/generic/prefill.cuh" // #include "flashinfer/attention/generic/prefill_tester.cuh" +#include + #include "flashinfer/attention/generic/scheduler.cuh" #include "flashinfer/attention/generic/variants.cuh" - #include "gpu_iface/enums.hpp" #include "gpu_iface/layout.cuh" -#include -namespace flashinfer -{ +namespace flashinfer { -template -hipError_t SinglePrefillWithKVCacheDispatched(Params params, - typename Params::DTypeO *tmp, +hipError_t SinglePrefillWithKVCacheDispatched(Params params, typename Params::DTypeO* tmp, hipStream_t stream); template hipError_t SinglePrefillWithKVCacheCustomMask( - DTypeIn *q, - DTypeIn *k, - DTypeIn *v, - uint8_t *custom_mask, - DTypeO *o, - DTypeO *tmp, - float *lse, - uint32_t num_qo_heads, - uint32_t num_kv_heads, - uint32_t qo_len, - uint32_t kv_len, - uint32_t head_dim, - QKVLayout kv_layout = QKVLayout::kNHD, - PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, - bool use_fp16_qk_reduction = false, - std::optional maybe_sm_scale = std::nullopt, - float rope_scale = 1.f, - float rope_theta = 1e4, - hipStream_t stream = nullptr) -{ - const float sm_scale = 1.f; - auto [qo_stride_n, qo_stride_h, kv_stride_n, kv_stride_h] = get_qkv_strides( - kv_layout, kv_len, num_qo_heads, num_kv_heads, head_dim); - DISPATCH_use_fp16_qk_reduction( - static_cast(use_fp16_qk_reduction), USE_FP16_QK_REDUCTION, - {DISPATCH_head_dim( - head_dim, HEAD_DIM, - {DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, { - using Params = SinglePrefillParams; - using AttentionVariant = DefaultAttention< - /*use_custom_mask=*/true, /*use_sliding_window=*/false, - /*use_logits_soft_cap=*/false, 
/*use_alibi=*/false>; - Params params(q, k, v, custom_mask, o, lse, - /*alibi_slopes=*/nullptr, num_qo_heads, - num_kv_heads, qo_len, kv_len, qo_stride_n, - qo_stride_h, kv_stride_n, kv_stride_h, head_dim, - /*window_left=*/-1, - /*logits_soft_cap=*/0.f, sm_scale, rope_scale, - rope_theta); - return SinglePrefillWithKVCacheDispatched< - HEAD_DIM, HEAD_DIM, POS_ENCODING_MODE, - USE_FP16_QK_REDUCTION, MaskMode::kCustom, AttentionVariant>( - params, tmp, stream); - })})}); - return hipSuccess; + DTypeIn* q, DTypeIn* k, DTypeIn* v, uint8_t* custom_mask, DTypeO* o, DTypeO* tmp, float* lse, + uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len, + uint32_t head_dim, QKVLayout kv_layout = QKVLayout::kNHD, + PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, bool use_fp16_qk_reduction = false, + std::optional maybe_sm_scale = std::nullopt, float rope_scale = 1.f, + float rope_theta = 1e4, hipStream_t stream = nullptr) { + const float sm_scale = 1.f; + auto [qo_stride_n, qo_stride_h, kv_stride_n, kv_stride_h] = + get_qkv_strides(kv_layout, kv_len, num_qo_heads, num_kv_heads, head_dim); + DISPATCH_use_fp16_qk_reduction( + static_cast(use_fp16_qk_reduction), USE_FP16_QK_REDUCTION, + {DISPATCH_head_dim( + head_dim, HEAD_DIM, {DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, { + using Params = SinglePrefillParams; + using AttentionVariant = DefaultAttention< + /*use_custom_mask=*/true, /*use_sliding_window=*/false, + /*use_logits_soft_cap=*/false, /*use_alibi=*/false>; + Params params(q, k, v, custom_mask, o, lse, + /*alibi_slopes=*/nullptr, num_qo_heads, num_kv_heads, qo_len, kv_len, + qo_stride_n, qo_stride_h, kv_stride_n, kv_stride_h, head_dim, + /*window_left=*/-1, + /*logits_soft_cap=*/0.f, sm_scale, rope_scale, rope_theta); + return SinglePrefillWithKVCacheDispatched(params, tmp, stream); + })})}); + return hipSuccess; } /*! 
@@ -108,62 +84,43 @@ hipError_t SinglePrefillWithKVCacheCustomMask( * \return status Indicates whether hip calls are successful */ template -hipError_t SinglePrefillWithKVCache( - DTypeQ *q, - DTypeKV *k, - DTypeKV *v, - DTypeO *o, - DTypeO *tmp, - float *lse, - uint32_t num_qo_heads, - uint32_t num_kv_heads, - uint32_t qo_len, - uint32_t kv_len, - uint32_t head_dim, - bool causal = true, - QKVLayout kv_layout = QKVLayout::kNHD, - PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, - bool use_fp16_qk_reduction = false, - uint32_t debug_thread_id = 0, - uint32_t debug_warp_id = 0, - std::optional maybe_sm_scale = std::nullopt, - float rope_scale = 1.f, - float rope_theta = 1e4, - hipStream_t stream = nullptr) -{ - const float sm_scale = 1.f; - const MaskMode mask_mode = causal ? MaskMode::kCausal : MaskMode::kNone; - auto [qo_stride_n, qo_stride_h, kv_stride_n, kv_stride_h] = get_qkv_strides( - kv_layout, kv_len, num_qo_heads, num_kv_heads, head_dim); - DISPATCH_use_fp16_qk_reduction( - static_cast(use_fp16_qk_reduction), USE_FP16_QK_REDUCTION, - {DISPATCH_mask_mode( - mask_mode, MASK_MODE, - {DISPATCH_head_dim( - head_dim, HEAD_DIM, - {DISPATCH_pos_encoding_mode( - pos_encoding_mode, POS_ENCODING_MODE, { - using Params = - SinglePrefillParams; - using AttentionVariant = DefaultAttention< - /*use_custom_mask=*/(MASK_MODE == - MaskMode::kCustom), - /*use_sliding_window=*/false, - /*use_logits_soft_cap=*/true, /*use_alibi=*/false>; - Params params( - q, k, v, /*custom_mask=*/nullptr, o, lse, - /*alibi_slopes=*/nullptr, num_qo_heads, - num_kv_heads, qo_len, kv_len, qo_stride_n, - qo_stride_h, kv_stride_n, kv_stride_h, head_dim, - /*window_left=*/-1, - /*logits_soft_cap=*/8.f, sm_scale, rope_scale, - rope_theta, debug_thread_id, debug_warp_id); - return SinglePrefillWithKVCacheDispatched< - HEAD_DIM, HEAD_DIM, POS_ENCODING_MODE, - USE_FP16_QK_REDUCTION, MASK_MODE, AttentionVariant, - Params>(params, tmp, stream); - })})})}); - return hipSuccess; +hipError_t 
SinglePrefillWithKVCache(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeO* o, DTypeO* tmp, + float* lse, uint32_t num_qo_heads, uint32_t num_kv_heads, + uint32_t qo_len, uint32_t kv_len, uint32_t head_dim, + bool causal = true, QKVLayout kv_layout = QKVLayout::kNHD, + PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, + bool use_fp16_qk_reduction = false, + uint32_t debug_thread_id = 0, uint32_t debug_warp_id = 0, + std::optional maybe_sm_scale = std::nullopt, + float rope_scale = 1.f, float rope_theta = 1e4, + hipStream_t stream = nullptr) { + const float sm_scale = 1.f; + const MaskMode mask_mode = causal ? MaskMode::kCausal : MaskMode::kNone; + auto [qo_stride_n, qo_stride_h, kv_stride_n, kv_stride_h] = + get_qkv_strides(kv_layout, kv_len, num_qo_heads, num_kv_heads, head_dim); + DISPATCH_use_fp16_qk_reduction( + static_cast(use_fp16_qk_reduction), USE_FP16_QK_REDUCTION, + {DISPATCH_mask_mode( + mask_mode, MASK_MODE, + {DISPATCH_head_dim(head_dim, HEAD_DIM, + {DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, { + using Params = SinglePrefillParams; + using AttentionVariant = DefaultAttention< + /*use_custom_mask=*/(MASK_MODE == MaskMode::kCustom), + /*use_sliding_window=*/false, + /*use_logits_soft_cap=*/true, /*use_alibi=*/false>; + Params params(q, k, v, /*custom_mask=*/nullptr, o, lse, + /*alibi_slopes=*/nullptr, num_qo_heads, num_kv_heads, + qo_len, kv_len, qo_stride_n, qo_stride_h, kv_stride_n, + kv_stride_h, head_dim, + /*window_left=*/-1, + /*logits_soft_cap=*/8.f, sm_scale, rope_scale, + rope_theta, debug_thread_id, debug_warp_id); + return SinglePrefillWithKVCacheDispatched< + HEAD_DIM, HEAD_DIM, POS_ENCODING_MODE, USE_FP16_QK_REDUCTION, + MASK_MODE, AttentionVariant, Params>(params, tmp, stream); + })})})}); + return hipSuccess; } // template @@ -228,4 +185,4 @@ hipError_t SinglePrefillWithKVCache( // return hipSuccess; // } -} // namespace flashinfer +} // namespace flashinfer diff --git a/libflashinfer/utils/utils.h 
b/libflashinfer/utils/utils.h index e7c3f3a7af..637c3df4bc 100644 --- a/libflashinfer/utils/utils.h +++ b/libflashinfer/utils/utils.h @@ -32,227 +32,178 @@ #include "flashinfer/exception.h" #include "generated/dispatch.inc" -#define _DISPATCH_SWITCH(var_name, cond, ...) \ - switch (cond) { \ - __VA_ARGS__ \ - default: \ - std::ostringstream oss; \ - oss << __PRETTY_FUNCTION__ << " failed to dispatch " var_name " " \ - << int(cond); \ - FLASHINFER_ERROR(oss.str()); \ - } +#define _DISPATCH_SWITCH(var_name, cond, ...) \ + switch (cond) { \ + __VA_ARGS__ \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch " var_name " " << int(cond); \ + FLASHINFER_ERROR(oss.str()); \ + } -#define _DISPATCH_CASE(case_expr, case_var, ...) \ - case case_expr: \ - { \ - constexpr auto case_var = case_expr; \ - __VA_ARGS__ \ - break; \ - } +#define _DISPATCH_CASE(case_expr, case_var, ...) \ + case case_expr: { \ + constexpr auto case_var = case_expr; \ + __VA_ARGS__ \ + break; \ + } -#define DISPATCH_group_size(expr, const_expr, ...) \ - _DISPATCH_SWITCH("group_size", expr, \ - _DISPATCH_CASES_group_size(const_expr, __VA_ARGS__)) +#define DISPATCH_group_size(expr, const_expr, ...) \ + _DISPATCH_SWITCH("group_size", expr, _DISPATCH_CASES_group_size(const_expr, __VA_ARGS__)) -#define DISPATCH_head_dim(expr, const_expr, ...) \ - _DISPATCH_SWITCH("head_dim", expr, \ - _DISPATCH_CASES_head_dim(const_expr, __VA_ARGS__)) +#define DISPATCH_head_dim(expr, const_expr, ...) \ + _DISPATCH_SWITCH("head_dim", expr, _DISPATCH_CASES_head_dim(const_expr, __VA_ARGS__)) -#define DISPATCH_pos_encoding_mode(expr, const_expr, ...) \ - _DISPATCH_SWITCH( \ - "positional encoding mode", expr, \ - _DISPATCH_CASES_pos_encoding_mode(const_expr, __VA_ARGS__)) +#define DISPATCH_pos_encoding_mode(expr, const_expr, ...) 
\ + _DISPATCH_SWITCH("positional encoding mode", expr, \ + _DISPATCH_CASES_pos_encoding_mode(const_expr, __VA_ARGS__)) -#define DISPATCH_use_fp16_qk_reduction(expr, const_expr, ...) \ - _DISPATCH_SWITCH( \ - "use_fp16_qk_reduction", expr, \ - _DISPATCH_CASES_use_fp16_qk_reduction(const_expr, __VA_ARGS__)) +#define DISPATCH_use_fp16_qk_reduction(expr, const_expr, ...) \ + _DISPATCH_SWITCH("use_fp16_qk_reduction", expr, \ + _DISPATCH_CASES_use_fp16_qk_reduction(const_expr, __VA_ARGS__)) -#define DISPATCH_mask_mode(expr, const_expr, ...) \ - _DISPATCH_SWITCH("mask_mode", expr, \ - _DISPATCH_CASES_mask_mode(const_expr, __VA_ARGS__)) +#define DISPATCH_mask_mode(expr, const_expr, ...) \ + _DISPATCH_SWITCH("mask_mode", expr, _DISPATCH_CASES_mask_mode(const_expr, __VA_ARGS__)) -namespace utils -{ +namespace utils { template -void vec_normal_(std::vector &vec, float mean = 0.f, float std = 1.f) -{ - std::random_device rd{}; - std::mt19937 gen{rd()}; - std::normal_distribution d{mean, std}; - for (size_t i = 0; i < vec.size(); ++i) { - vec[i] = T(d(gen)); - } +void vec_normal_(std::vector& vec, float mean = 0.f, float std = 1.f) { + std::random_device rd{}; + std::mt19937 gen{rd()}; + std::normal_distribution d{mean, std}; + for (size_t i = 0; i < vec.size(); ++i) { + vec[i] = T(d(gen)); + } } template -void vec_uniform_(std::vector &vec, float a = 0.f, float b = 1.f) -{ - std::random_device rd{}; - std::mt19937 gen{rd()}; - std::uniform_real_distribution d{a, b}; - for (size_t i = 0; i < vec.size(); ++i) { - vec[i] = T(d(gen)); - } +void vec_uniform_(std::vector& vec, float a = 0.f, float b = 1.f) { + std::random_device rd{}; + std::mt19937 gen{rd()}; + std::uniform_real_distribution d{a, b}; + for (size_t i = 0; i < vec.size(); ++i) { + vec[i] = T(d(gen)); + } } -template void vec_zero_(std::vector &vec) -{ - std::fill(vec.begin(), vec.end(), T(0)); +template +void vec_zero_(std::vector& vec) { + std::fill(vec.begin(), vec.end(), T(0)); } -template void 
vec_fill_(std::vector &vec, T val) -{ - std::fill(vec.begin(), vec.end(), val); +template +void vec_fill_(std::vector& vec, T val) { + std::fill(vec.begin(), vec.end(), val); } -template void vec_randint_(std::vector &vec, int low, int high) -{ - std::random_device rd{}; - std::mt19937 gen{rd()}; - std::uniform_int_distribution d{low, high}; - for (size_t i = 0; i < vec.size(); ++i) { - vec[i] = T(d(gen)); - } +template +void vec_randint_(std::vector& vec, int low, int high) { + std::random_device rd{}; + std::mt19937 gen{rd()}; + std::uniform_int_distribution d{low, high}; + for (size_t i = 0; i < vec.size(); ++i) { + vec[i] = T(d(gen)); + } } -template size_t vec_bytes(const T &vec) -{ - return vec.size() * sizeof(typename T::value_type); +template +size_t vec_bytes(const T& vec) { + return vec.size() * sizeof(typename T::value_type); } template -bool isclose(T a, T b, float rtol = 1e-5, float atol = 1e-8) -{ - return fabs(a - b) <= (atol + rtol * fabs(b)); +bool isclose(T a, T b, float rtol = 1e-5, float atol = 1e-8) { + return fabs(a - b) <= (atol + rtol * fabs(b)); } template std::tuple>, std::vector>> -create_shared_prefix_testcase_data(size_t batch_size, - size_t shared_prefix_length, - size_t unique_kv_length, - size_t qo_append_length, - size_t num_qo_heads, - size_t num_kv_heads, - size_t head_dim, - size_t page_size) -{ - uint32_t num_pages = - ((shared_prefix_length + unique_kv_length * batch_size) / page_size); - std::vector shared_k_h(shared_prefix_length * num_kv_heads * head_dim); - std::vector shared_v_h(shared_prefix_length * num_kv_heads * head_dim); - std::vector q_h((batch_size * qo_append_length) * num_qo_heads * - head_dim); - - utils::vec_normal_(shared_k_h); - utils::vec_normal_(shared_v_h); - utils::vec_normal_(q_h); - - std::vector qo_indptr{0}; - std::vector kv_indptr_combined_h{0}; - std::vector kv_indptr_unique_h{0}; - std::vector kv_last_page_len_combined_h; - std::vector kv_last_page_len_unique_h; - 
+create_shared_prefix_testcase_data(size_t batch_size, size_t shared_prefix_length, + size_t unique_kv_length, size_t qo_append_length, + size_t num_qo_heads, size_t num_kv_heads, size_t head_dim, + size_t page_size) { + uint32_t num_pages = ((shared_prefix_length + unique_kv_length * batch_size) / page_size); + std::vector shared_k_h(shared_prefix_length * num_kv_heads * head_dim); + std::vector shared_v_h(shared_prefix_length * num_kv_heads * head_dim); + std::vector q_h((batch_size * qo_append_length) * num_qo_heads * head_dim); + + utils::vec_normal_(shared_k_h); + utils::vec_normal_(shared_v_h); + utils::vec_normal_(q_h); + + std::vector qo_indptr{0}; + std::vector kv_indptr_combined_h{0}; + std::vector kv_indptr_unique_h{0}; + std::vector kv_last_page_len_combined_h; + std::vector kv_last_page_len_unique_h; + + for (uint32_t request_id = 0; request_id < batch_size; ++request_id) { + qo_indptr.push_back(qo_indptr.back() + qo_append_length); + kv_indptr_combined_h.push_back(kv_indptr_combined_h.back() + + (shared_prefix_length + unique_kv_length) / page_size); + kv_indptr_unique_h.push_back(kv_indptr_unique_h.back() + unique_kv_length / page_size); + kv_last_page_len_combined_h.push_back(page_size); + kv_last_page_len_unique_h.push_back(page_size); + } + + std::vector kv_indices_combined_h(kv_indptr_combined_h.back()); + std::vector kv_indices_unique_h(kv_indptr_unique_h.back()); + + std::vector k_data_h(num_pages * num_kv_heads * page_size * head_dim); + std::vector v_data_h(num_pages * num_kv_heads * page_size * head_dim); + uint32_t page_id = 0; + + for (; page_id < (shared_prefix_length / page_size); page_id++) { + for (uint32_t entry_idx = 0; entry_idx < page_size; entry_idx++) { + for (uint32_t head_idx = 0; head_idx < num_kv_heads; head_idx++) { + std::copy(shared_k_h.begin() + + ((page_id * page_size + entry_idx) * num_kv_heads + head_idx) * head_dim, + shared_k_h.begin() + + ((page_id * page_size + entry_idx) * num_kv_heads + head_idx + 1) * head_dim, 
+ k_data_h.begin() + + ((page_id * num_kv_heads + head_idx) * page_size + entry_idx) * head_dim); + std::copy(shared_v_h.begin() + + ((page_id * page_size + entry_idx) * num_kv_heads + head_idx) * head_dim, + shared_v_h.begin() + + ((page_id * page_size + entry_idx) * num_kv_heads + head_idx + 1) * head_dim, + v_data_h.begin() + + ((page_id * num_kv_heads + head_idx) * page_size + entry_idx) * head_dim); + } + } for (uint32_t request_id = 0; request_id < batch_size; ++request_id) { - qo_indptr.push_back(qo_indptr.back() + qo_append_length); - kv_indptr_combined_h.push_back( - kv_indptr_combined_h.back() + - (shared_prefix_length + unique_kv_length) / page_size); - kv_indptr_unique_h.push_back(kv_indptr_unique_h.back() + - unique_kv_length / page_size); - kv_last_page_len_combined_h.push_back(page_size); - kv_last_page_len_unique_h.push_back(page_size); + kv_indices_combined_h[request_id * ((shared_prefix_length + unique_kv_length) / page_size) + + page_id] = page_id; } - - std::vector kv_indices_combined_h(kv_indptr_combined_h.back()); - std::vector kv_indices_unique_h(kv_indptr_unique_h.back()); - - std::vector k_data_h(num_pages * num_kv_heads * page_size * head_dim); - std::vector v_data_h(num_pages * num_kv_heads * page_size * head_dim); - uint32_t page_id = 0; - - for (; page_id < (shared_prefix_length / page_size); page_id++) { - for (uint32_t entry_idx = 0; entry_idx < page_size; entry_idx++) { - for (uint32_t head_idx = 0; head_idx < num_kv_heads; head_idx++) { - std::copy( - shared_k_h.begin() + - ((page_id * page_size + entry_idx) * num_kv_heads + - head_idx) * - head_dim, - shared_k_h.begin() + - ((page_id * page_size + entry_idx) * num_kv_heads + - head_idx + 1) * - head_dim, + } + + for (uint32_t request_id = 0; request_id < batch_size; ++request_id) { + for (uint32_t page_iter = 0; page_iter < (unique_kv_length / page_size); + ++page_iter, ++page_id) { + for (uint32_t entry_idx = 0; entry_idx < page_size; entry_idx++) { + for (uint32_t head_idx = 0; 
head_idx < num_kv_heads; head_idx++) { + std::vector k(head_dim), v(head_dim); + utils::vec_normal_(k); + utils::vec_normal_(v); + std::copy(k.begin(), k.end(), k_data_h.begin() + - ((page_id * num_kv_heads + head_idx) * page_size + - entry_idx) * - head_dim); - std::copy( - shared_v_h.begin() + - ((page_id * page_size + entry_idx) * num_kv_heads + - head_idx) * - head_dim, - shared_v_h.begin() + - ((page_id * page_size + entry_idx) * num_kv_heads + - head_idx + 1) * - head_dim, + ((page_id * num_kv_heads + head_idx) * page_size + entry_idx) * head_dim); + std::copy(v.begin(), v.end(), v_data_h.begin() + - ((page_id * num_kv_heads + head_idx) * page_size + - entry_idx) * - head_dim); - } - } - for (uint32_t request_id = 0; request_id < batch_size; ++request_id) { - kv_indices_combined_h[request_id * ((shared_prefix_length + - unique_kv_length) / - page_size) + - page_id] = page_id; - } - } - - for (uint32_t request_id = 0; request_id < batch_size; ++request_id) { - for (uint32_t page_iter = 0; page_iter < (unique_kv_length / page_size); - ++page_iter, ++page_id) - { - for (uint32_t entry_idx = 0; entry_idx < page_size; entry_idx++) { - for (uint32_t head_idx = 0; head_idx < num_kv_heads; head_idx++) - { - std::vector k(head_dim), v(head_dim); - utils::vec_normal_(k); - utils::vec_normal_(v); - std::copy( - k.begin(), k.end(), - k_data_h.begin() + - ((page_id * num_kv_heads + head_idx) * page_size + - entry_idx) * - head_dim); - std::copy( - v.begin(), v.end(), - v_data_h.begin() + - ((page_id * num_kv_heads + head_idx) * page_size + - entry_idx) * - head_dim); - } - } - kv_indices_combined_h - [request_id * - ((shared_prefix_length + unique_kv_length) / page_size) + - (shared_prefix_length / page_size) + page_iter] = page_id; - kv_indices_unique_h[request_id * (unique_kv_length / page_size) + - page_iter] = page_id; + ((page_id * num_kv_heads + head_idx) * page_size + entry_idx) * head_dim); } + } + kv_indices_combined_h[request_id * ((shared_prefix_length + 
unique_kv_length) / page_size) + + (shared_prefix_length / page_size) + page_iter] = page_id; + kv_indices_unique_h[request_id * (unique_kv_length / page_size) + page_iter] = page_id; } - return std::make_tuple>, - std::vector>>( - {std::move(q_h), std::move(shared_k_h), std::move(shared_v_h), - std::move(k_data_h), std::move(v_data_h)}, - {std::move(qo_indptr), std::move(kv_indices_combined_h), - std::move(kv_indices_unique_h), std::move(kv_indptr_combined_h), - std::move(kv_indptr_unique_h), std::move(kv_last_page_len_combined_h), - std::move(kv_last_page_len_unique_h)}); + } + return std::make_tuple>, std::vector>>( + {std::move(q_h), std::move(shared_k_h), std::move(shared_v_h), std::move(k_data_h), + std::move(v_data_h)}, + {std::move(qo_indptr), std::move(kv_indices_combined_h), std::move(kv_indices_unique_h), + std::move(kv_indptr_combined_h), std::move(kv_indptr_unique_h), + std::move(kv_last_page_len_combined_h), std::move(kv_last_page_len_unique_h)}); } -} // namespace utils +} // namespace utils diff --git a/libflashinfer/utils/utils_hip.h b/libflashinfer/utils/utils_hip.h index 8184bb59b4..c858ef82ae 100644 --- a/libflashinfer/utils/utils_hip.h +++ b/libflashinfer/utils/utils_hip.h @@ -5,8 +5,6 @@ #pragma once -#include "gpu_iface/conversion_utils.h" - #include #include #include @@ -16,266 +14,216 @@ #include #include "dispatch.inc" +#include "gpu_iface/conversion_utils.h" -#define _DISPATCH_SWITCH(var_name, cond, ...) \ - switch (cond) { \ - __VA_ARGS__ \ - default: \ - std::ostringstream oss; \ - oss << __PRETTY_FUNCTION__ << " failed to dispatch " var_name " " \ - << int(cond); \ - FLASHINFER_ERROR(oss.str()); \ - } - -#define _DISPATCH_CASE(case_expr, case_var, ...) \ - case case_expr: \ - { \ - constexpr auto case_var = case_expr; \ - __VA_ARGS__ \ - break; \ - } - -#define DISPATCH_group_size(expr, const_expr, ...) 
\ - _DISPATCH_SWITCH("group_size", expr, \ - _DISPATCH_CASES_group_size(const_expr, __VA_ARGS__)) - -#define DISPATCH_head_dim(expr, const_expr, ...) \ - _DISPATCH_SWITCH("head_dim", expr, \ - _DISPATCH_CASES_head_dim(const_expr, __VA_ARGS__)) - -#define DISPATCH_pos_encoding_mode(expr, const_expr, ...) \ - _DISPATCH_SWITCH( \ - "positional encoding mode", expr, \ - _DISPATCH_CASES_pos_encoding_mode(const_expr, __VA_ARGS__)) - -#define DISPATCH_use_fp16_qk_reduction(expr, const_expr, ...) \ - _DISPATCH_SWITCH( \ - "use_fp16_qk_reduction", expr, \ - _DISPATCH_CASES_use_fp16_qk_reduction(const_expr, __VA_ARGS__)) - -#define DISPATCH_mask_mode(expr, const_expr, ...) \ - _DISPATCH_SWITCH("mask_mode", expr, \ - _DISPATCH_CASES_mask_mode(const_expr, __VA_ARGS__)) - -namespace utils -{ - -enum Predicate -{ - Linear, - Ones, - Zeros, +#define _DISPATCH_SWITCH(var_name, cond, ...) \ + switch (cond) { \ + __VA_ARGS__ \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch " var_name " " << int(cond); \ + FLASHINFER_ERROR(oss.str()); \ + } + +#define _DISPATCH_CASE(case_expr, case_var, ...) \ + case case_expr: { \ + constexpr auto case_var = case_expr; \ + __VA_ARGS__ \ + break; \ + } + +#define DISPATCH_group_size(expr, const_expr, ...) \ + _DISPATCH_SWITCH("group_size", expr, _DISPATCH_CASES_group_size(const_expr, __VA_ARGS__)) + +#define DISPATCH_head_dim(expr, const_expr, ...) \ + _DISPATCH_SWITCH("head_dim", expr, _DISPATCH_CASES_head_dim(const_expr, __VA_ARGS__)) + +#define DISPATCH_pos_encoding_mode(expr, const_expr, ...) \ + _DISPATCH_SWITCH("positional encoding mode", expr, \ + _DISPATCH_CASES_pos_encoding_mode(const_expr, __VA_ARGS__)) + +#define DISPATCH_use_fp16_qk_reduction(expr, const_expr, ...) \ + _DISPATCH_SWITCH("use_fp16_qk_reduction", expr, \ + _DISPATCH_CASES_use_fp16_qk_reduction(const_expr, __VA_ARGS__)) + +#define DISPATCH_mask_mode(expr, const_expr, ...) 
\ + _DISPATCH_SWITCH("mask_mode", expr, _DISPATCH_CASES_mask_mode(const_expr, __VA_ARGS__)) + +namespace utils { + +enum Predicate { + Linear, + Ones, + Zeros, }; -template void generate_data(std::vector &vec) -{ - if constexpr (Pred == Predicate::Linear) { - assert(vec.size() <= 0); - for (int i = 0; i < vec.size(); i++) { - vec[i] = fi::con::explicit_casting(static_cast(i)); - } +template +void generate_data(std::vector& vec) { + if constexpr (Pred == Predicate::Linear) { + assert(vec.size() <= 0); + for (int i = 0; i < vec.size(); i++) { + vec[i] = fi::con::explicit_casting(static_cast(i)); } + } - else if constexpr (Pred == Predicate::Ones) { - vec_fill_(vec, fi::con::explicit_casting(1.0f)); - } + else if constexpr (Pred == Predicate::Ones) { + vec_fill_(vec, fi::con::explicit_casting(1.0f)); + } - else if constexpr (Pred == Predicate::Zeros) { - vec_zero_(vec); - } + else if constexpr (Pred == Predicate::Zeros) { + vec_zero_(vec); + } } -template void vec_lexicographic_(std::vector &vec) -{ - for (int i = 0; i < vec.size(); i++) { - vec[i] = fi::con::explicit_casting(static_cast(i)); - } +template +void vec_lexicographic_(std::vector& vec) { + for (int i = 0; i < vec.size(); i++) { + vec[i] = fi::con::explicit_casting(static_cast(i)); + } } template -void vec_normal_(std::vector &vec, float mean = 0.f, float std = 1.f) -{ - std::random_device rd{}; - std::mt19937 gen{1234}; - std::normal_distribution d{mean, std}; - for (size_t i = 0; i < vec.size(); ++i) { - float value = static_cast(d(gen)); - vec[i] = fi::con::explicit_casting(value); - } +void vec_normal_(std::vector& vec, float mean = 0.f, float std = 1.f) { + std::random_device rd{}; + std::mt19937 gen{1234}; + std::normal_distribution d{mean, std}; + for (size_t i = 0; i < vec.size(); ++i) { + float value = static_cast(d(gen)); + vec[i] = fi::con::explicit_casting(value); + } } template -void vec_uniform_(std::vector &vec, float a = 0.f, float b = 1.f) -{ - std::random_device rd{}; - std::mt19937 
gen{1234}; - std::uniform_real_distribution d{a, b}; - for (size_t i = 0; i < vec.size(); ++i) { - float value = static_cast(d(gen)); - vec[i] = fi::con::explicit_casting(value); - } +void vec_uniform_(std::vector& vec, float a = 0.f, float b = 1.f) { + std::random_device rd{}; + std::mt19937 gen{1234}; + std::uniform_real_distribution d{a, b}; + for (size_t i = 0; i < vec.size(); ++i) { + float value = static_cast(d(gen)); + vec[i] = fi::con::explicit_casting(value); + } } -template void vec_zero_(std::vector &vec) -{ - std::fill(vec.begin(), vec.end(), - fi::con::explicit_casting(0.0f)); +template +void vec_zero_(std::vector& vec) { + std::fill(vec.begin(), vec.end(), fi::con::explicit_casting(0.0f)); } -template void vec_fill_(std::vector &vec, T val) -{ - std::fill(vec.begin(), vec.end(), val); +template +void vec_fill_(std::vector& vec, T val) { + std::fill(vec.begin(), vec.end(), val); } -template void vec_randint_(std::vector &vec, int low, int high) -{ - std::random_device rd{}; - std::mt19937 gen{1234}; - std::uniform_int_distribution d{low, high}; - for (size_t i = 0; i < vec.size(); ++i) { - float value = static_cast(d(gen)); - vec[i] = fi::con::explicit_casting(value); - } +template +void vec_randint_(std::vector& vec, int low, int high) { + std::random_device rd{}; + std::mt19937 gen{1234}; + std::uniform_int_distribution d{low, high}; + for (size_t i = 0; i < vec.size(); ++i) { + float value = static_cast(d(gen)); + vec[i] = fi::con::explicit_casting(value); + } } -template size_t vec_bytes(const T &vec) -{ - return vec.size() * sizeof(typename T::value_type); +template +size_t vec_bytes(const T& vec) { + return vec.size() * sizeof(typename T::value_type); } template -bool isclose(T a, T b, float rtol = 1e-5, float atol = 1e-8) -{ - float a_ = fi::con::explicit_casting(a); - float b_ = fi::con::explicit_casting(b); - return fabs(a_ - b_) <= (atol + rtol * fabs(b_)); +bool isclose(T a, T b, float rtol = 1e-5, float atol = 1e-8) { + float a_ = 
fi::con::explicit_casting(a); + float b_ = fi::con::explicit_casting(b); + return fabs(a_ - b_) <= (atol + rtol * fabs(b_)); } template std::tuple>, std::vector>> -create_shared_prefix_testcase_data(size_t batch_size, - size_t shared_prefix_length, - size_t unique_kv_length, - size_t qo_append_length, - size_t num_qo_heads, - size_t num_kv_heads, - size_t head_dim, - size_t page_size) -{ - uint32_t num_pages = - ((shared_prefix_length + unique_kv_length * batch_size) / page_size); - std::vector shared_k_h(shared_prefix_length * num_kv_heads * head_dim); - std::vector shared_v_h(shared_prefix_length * num_kv_heads * head_dim); - std::vector q_h((batch_size * qo_append_length) * num_qo_heads * - head_dim); - - utils::vec_normal_(shared_k_h); - utils::vec_normal_(shared_v_h); - utils::vec_normal_(q_h); - - std::vector qo_indptr{0}; - std::vector kv_indptr_combined_h{0}; - std::vector kv_indptr_unique_h{0}; - std::vector kv_last_page_len_combined_h; - std::vector kv_last_page_len_unique_h; - +create_shared_prefix_testcase_data(size_t batch_size, size_t shared_prefix_length, + size_t unique_kv_length, size_t qo_append_length, + size_t num_qo_heads, size_t num_kv_heads, size_t head_dim, + size_t page_size) { + uint32_t num_pages = ((shared_prefix_length + unique_kv_length * batch_size) / page_size); + std::vector shared_k_h(shared_prefix_length * num_kv_heads * head_dim); + std::vector shared_v_h(shared_prefix_length * num_kv_heads * head_dim); + std::vector q_h((batch_size * qo_append_length) * num_qo_heads * head_dim); + + utils::vec_normal_(shared_k_h); + utils::vec_normal_(shared_v_h); + utils::vec_normal_(q_h); + + std::vector qo_indptr{0}; + std::vector kv_indptr_combined_h{0}; + std::vector kv_indptr_unique_h{0}; + std::vector kv_last_page_len_combined_h; + std::vector kv_last_page_len_unique_h; + + for (uint32_t request_id = 0; request_id < batch_size; ++request_id) { + qo_indptr.push_back(qo_indptr.back() + qo_append_length); + 
kv_indptr_combined_h.push_back(kv_indptr_combined_h.back() + + (shared_prefix_length + unique_kv_length) / page_size); + kv_indptr_unique_h.push_back(kv_indptr_unique_h.back() + unique_kv_length / page_size); + kv_last_page_len_combined_h.push_back(page_size); + kv_last_page_len_unique_h.push_back(page_size); + } + + std::vector kv_indices_combined_h(kv_indptr_combined_h.back()); + std::vector kv_indices_unique_h(kv_indptr_unique_h.back()); + + std::vector k_data_h(num_pages * num_kv_heads * page_size * head_dim); + std::vector v_data_h(num_pages * num_kv_heads * page_size * head_dim); + uint32_t page_id = 0; + + for (; page_id < (shared_prefix_length / page_size); page_id++) { + for (uint32_t entry_idx = 0; entry_idx < page_size; entry_idx++) { + for (uint32_t head_idx = 0; head_idx < num_kv_heads; head_idx++) { + std::copy(shared_k_h.begin() + + ((page_id * page_size + entry_idx) * num_kv_heads + head_idx) * head_dim, + shared_k_h.begin() + + ((page_id * page_size + entry_idx) * num_kv_heads + head_idx + 1) * head_dim, + k_data_h.begin() + + ((page_id * num_kv_heads + head_idx) * page_size + entry_idx) * head_dim); + std::copy(shared_v_h.begin() + + ((page_id * page_size + entry_idx) * num_kv_heads + head_idx) * head_dim, + shared_v_h.begin() + + ((page_id * page_size + entry_idx) * num_kv_heads + head_idx + 1) * head_dim, + v_data_h.begin() + + ((page_id * num_kv_heads + head_idx) * page_size + entry_idx) * head_dim); + } + } for (uint32_t request_id = 0; request_id < batch_size; ++request_id) { - qo_indptr.push_back(qo_indptr.back() + qo_append_length); - kv_indptr_combined_h.push_back( - kv_indptr_combined_h.back() + - (shared_prefix_length + unique_kv_length) / page_size); - kv_indptr_unique_h.push_back(kv_indptr_unique_h.back() + - unique_kv_length / page_size); - kv_last_page_len_combined_h.push_back(page_size); - kv_last_page_len_unique_h.push_back(page_size); + kv_indices_combined_h[request_id * ((shared_prefix_length + unique_kv_length) / page_size) + + 
page_id] = page_id; } - - std::vector kv_indices_combined_h(kv_indptr_combined_h.back()); - std::vector kv_indices_unique_h(kv_indptr_unique_h.back()); - - std::vector k_data_h(num_pages * num_kv_heads * page_size * head_dim); - std::vector v_data_h(num_pages * num_kv_heads * page_size * head_dim); - uint32_t page_id = 0; - - for (; page_id < (shared_prefix_length / page_size); page_id++) { - for (uint32_t entry_idx = 0; entry_idx < page_size; entry_idx++) { - for (uint32_t head_idx = 0; head_idx < num_kv_heads; head_idx++) { - std::copy( - shared_k_h.begin() + - ((page_id * page_size + entry_idx) * num_kv_heads + - head_idx) * - head_dim, - shared_k_h.begin() + - ((page_id * page_size + entry_idx) * num_kv_heads + - head_idx + 1) * - head_dim, + } + + for (uint32_t request_id = 0; request_id < batch_size; ++request_id) { + for (uint32_t page_iter = 0; page_iter < (unique_kv_length / page_size); + ++page_iter, ++page_id) { + for (uint32_t entry_idx = 0; entry_idx < page_size; entry_idx++) { + for (uint32_t head_idx = 0; head_idx < num_kv_heads; head_idx++) { + std::vector k(head_dim), v(head_dim); + utils::vec_normal_(k); + utils::vec_normal_(v); + std::copy(k.begin(), k.end(), k_data_h.begin() + - ((page_id * num_kv_heads + head_idx) * page_size + - entry_idx) * - head_dim); - std::copy( - shared_v_h.begin() + - ((page_id * page_size + entry_idx) * num_kv_heads + - head_idx) * - head_dim, - shared_v_h.begin() + - ((page_id * page_size + entry_idx) * num_kv_heads + - head_idx + 1) * - head_dim, + ((page_id * num_kv_heads + head_idx) * page_size + entry_idx) * head_dim); + std::copy(v.begin(), v.end(), v_data_h.begin() + - ((page_id * num_kv_heads + head_idx) * page_size + - entry_idx) * - head_dim); - } - } - for (uint32_t request_id = 0; request_id < batch_size; ++request_id) { - kv_indices_combined_h[request_id * ((shared_prefix_length + - unique_kv_length) / - page_size) + - page_id] = page_id; - } - } - - for (uint32_t request_id = 0; request_id < batch_size; 
++request_id) { - for (uint32_t page_iter = 0; page_iter < (unique_kv_length / page_size); - ++page_iter, ++page_id) - { - for (uint32_t entry_idx = 0; entry_idx < page_size; entry_idx++) { - for (uint32_t head_idx = 0; head_idx < num_kv_heads; head_idx++) - { - std::vector k(head_dim), v(head_dim); - utils::vec_normal_(k); - utils::vec_normal_(v); - std::copy( - k.begin(), k.end(), - k_data_h.begin() + - ((page_id * num_kv_heads + head_idx) * page_size + - entry_idx) * - head_dim); - std::copy( - v.begin(), v.end(), - v_data_h.begin() + - ((page_id * num_kv_heads + head_idx) * page_size + - entry_idx) * - head_dim); - } - } - kv_indices_combined_h - [request_id * - ((shared_prefix_length + unique_kv_length) / page_size) + - (shared_prefix_length / page_size) + page_iter] = page_id; - kv_indices_unique_h[request_id * (unique_kv_length / page_size) + - page_iter] = page_id; + ((page_id * num_kv_heads + head_idx) * page_size + entry_idx) * head_dim); } + } + kv_indices_combined_h[request_id * ((shared_prefix_length + unique_kv_length) / page_size) + + (shared_prefix_length / page_size) + page_iter] = page_id; + kv_indices_unique_h[request_id * (unique_kv_length / page_size) + page_iter] = page_id; } - return std::make_tuple>, - std::vector>>( - {std::move(q_h), std::move(shared_k_h), std::move(shared_v_h), - std::move(k_data_h), std::move(v_data_h)}, - {std::move(qo_indptr), std::move(kv_indices_combined_h), - std::move(kv_indices_unique_h), std::move(kv_indptr_combined_h), - std::move(kv_indptr_unique_h), std::move(kv_last_page_len_combined_h), - std::move(kv_last_page_len_unique_h)}); + } + return std::make_tuple>, std::vector>>( + {std::move(q_h), std::move(shared_k_h), std::move(shared_v_h), std::move(k_data_h), + std::move(v_data_h)}, + {std::move(qo_indptr), std::move(kv_indices_combined_h), std::move(kv_indices_unique_h), + std::move(kv_indptr_combined_h), std::move(kv_indptr_unique_h), + std::move(kv_last_page_len_combined_h), 
std::move(kv_last_page_len_unique_h)}); } -} // namespace utils +} // namespace utils From 41f340a63ecea12abc4464f415db4a08391818cb Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Tue, 16 Sep 2025 11:06:05 -0400 Subject: [PATCH 082/109] Reformat flashinfer/csrc --- flashinfer/csrc/pytorch_extension_utils.h | 375 ++++++++++------------ 1 file changed, 174 insertions(+), 201 deletions(-) diff --git a/flashinfer/csrc/pytorch_extension_utils.h b/flashinfer/csrc/pytorch_extension_utils.h index 9169d0fe64..129c5b8210 100644 --- a/flashinfer/csrc/pytorch_extension_utils.h +++ b/flashinfer/csrc/pytorch_extension_utils.h @@ -62,26 +62,24 @@ The import from Python will load the .so consisting of the file in this extension, so that the TORCH_LIBRARY_FRAGMENT static initializers are run. */ -#define FLASHINFER_EXT_MODULE_INIT(name) \ - extern "C" \ - { \ - __attribute__((weak)) PyObject *PyInit_##name(void) \ - { \ - static struct PyModuleDef module_def = { \ - PyModuleDef_HEAD_INIT, \ - #name, /* name of module */ \ - NULL, /* module documentation, may be NULL */ \ - -1, /* size of per-interpreter state of the module, \ - or -1 if the module keeps state in global variables. */ \ - NULL, /* methods */ \ - NULL, /* slots */ \ - NULL, /* traverse */ \ - NULL, /* clear */ \ - NULL, /* free */ \ - }; \ - return PyModule_Create(&module_def); \ - } \ - } +#define FLASHINFER_EXT_MODULE_INIT(name) \ + extern "C" { \ + __attribute__((weak)) PyObject* PyInit_##name(void) { \ + static struct PyModuleDef module_def = { \ + PyModuleDef_HEAD_INIT, \ + #name, /* name of module */ \ + NULL, /* module documentation, may be NULL */ \ + -1, /* size of per-interpreter state of the module, \ + or -1 if the module keeps state in global variables. 
*/ \ + NULL, /* methods */ \ + NULL, /* slots */ \ + NULL, /* traverse */ \ + NULL, /* clear */ \ + NULL, /* free */ \ + }; \ + return PyModule_Create(&module_def); \ + } \ + } FLASHINFER_EXT_MODULE_INIT_EXPAND(TORCH_EXTENSION_NAME) @@ -91,223 +89,198 @@ FLASHINFER_EXT_MODULE_INIT_EXPAND(TORCH_EXTENSION_NAME) #endif #ifdef FLASHINFER_ENABLE_HIP - #ifdef FLASHINFER_ENABLE_F16 - using dtype_half = __half; - #endif - #ifdef FLASHINFER_ENABLE_BF16 - using dtype_bfloat16 = __hip_bfloat16; - #endif - #if defined(FLASHINFER_ENABLE_FP8_E4M3) || defined(FLASHINFER_ENABLE_FP8_E5M2) - using dtype_fp8_e4m3 = __hip_fp8_e4m3_fnuz; - using dtype_fp8_e5m2 = __hip_fp8_e5m2_fnuz; - #endif +#ifdef FLASHINFER_ENABLE_F16 +using dtype_half = __half; +#endif +#ifdef FLASHINFER_ENABLE_BF16 +using dtype_bfloat16 = __hip_bfloat16; +#endif +#if defined(FLASHINFER_ENABLE_FP8_E4M3) || defined(FLASHINFER_ENABLE_FP8_E5M2) +using dtype_fp8_e4m3 = __hip_fp8_e4m3_fnuz; +using dtype_fp8_e5m2 = __hip_fp8_e5m2_fnuz; +#endif #else - #ifdef FLASHINFER_ENABLE_F16 - using dtype_half = nv_half; - #endif - #ifdef FLASHINFER_ENABLE_BF16 - using dtype_bfloat16 = nv_bfloat16; - #endif - #if defined(FLASHINFER_ENABLE_FP8_E4M3) || defined(FLASHINFER_ENABLE_FP8_E5M2) - using dtype_fp8_e4m3 = nv_fp8_e4m3; - using dtype_fp8_e5m2 = nv_fp8_e5m2; - #endif +#ifdef FLASHINFER_ENABLE_F16 +using dtype_half = nv_half; +#endif +#ifdef FLASHINFER_ENABLE_BF16 +using dtype_bfloat16 = nv_bfloat16; +#endif +#if defined(FLASHINFER_ENABLE_FP8_E4M3) || defined(FLASHINFER_ENABLE_FP8_E5M2) +using dtype_fp8_e4m3 = nv_fp8_e4m3; +using dtype_fp8_e5m2 = nv_fp8_e5m2; +#endif #endif #ifdef FLASHINFER_ENABLE_F16 -#define _DISPATCH_CASE_F16(c_type, ...) \ - case at::ScalarType::Half: \ - { \ - using c_type = dtype_half; \ - return __VA_ARGS__(); \ - } +#define _DISPATCH_CASE_F16(c_type, ...) \ + case at::ScalarType::Half: { \ + using c_type = dtype_half; \ + return __VA_ARGS__(); \ + } #else #define _DISPATCH_CASE_F16(c_type, ...) 
#endif #ifdef FLASHINFER_ENABLE_BF16 -#define _DISPATCH_CASE_BF16(c_type, ...) \ - case at::ScalarType::BFloat16: \ - { \ - using c_type = dtype_bfloat16; \ - return __VA_ARGS__(); \ - } +#define _DISPATCH_CASE_BF16(c_type, ...) \ + case at::ScalarType::BFloat16: { \ + using c_type = dtype_bfloat16; \ + return __VA_ARGS__(); \ + } #else #define _DISPATCH_CASE_BF16(c_type, ...) #endif #ifdef FLASHINFER_ENABLE_FP8_E4M3 -#define _DISPATCH_CASE_FP8_E4M3(c_type, ...) \ - case at::ScalarType::Float8_e4m3fn: \ - { \ - using c_type = dtype_fp8_e4m3; \ - return __VA_ARGS__(); \ - } +#define _DISPATCH_CASE_FP8_E4M3(c_type, ...) \ + case at::ScalarType::Float8_e4m3fn: { \ + using c_type = dtype_fp8_e4m3; \ + return __VA_ARGS__(); \ + } #else #define _DISPATCH_CASE_FP8_E4M3(c_type, ...) #endif #ifdef FLASHINFER_ENABLE_FP8_E5M2 -#define _DISPATCH_CASE_FP8_E5M2(c_type, ...) \ - case at::ScalarType::Float8_e5m2: \ - { \ - using c_type = dtype_fp8_e5m2; \ - return __VA_ARGS__(); \ - } +#define _DISPATCH_CASE_FP8_E5M2(c_type, ...) \ + case at::ScalarType::Float8_e5m2: { \ + using c_type = dtype_fp8_e5m2; \ + return __VA_ARGS__(); \ + } #else #define _DISPATCH_CASE_FP8_E5M2(c_type, ...) #endif -#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(pytorch_dtype, c_type, ...) \ - [&]() -> bool { \ - switch (pytorch_dtype) { \ - _DISPATCH_CASE_F16(c_type, __VA_ARGS__) \ - _DISPATCH_CASE_BF16(c_type, __VA_ARGS__) \ - default: \ - std::ostringstream oss; \ - oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " \ - << pytorch_dtype; \ - TORCH_CHECK(false, oss.str()); \ - return false; \ - } \ - }() - -#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(pytorch_dtype, c_type, ...) 
\ - [&]() -> bool { \ - switch (pytorch_dtype) { \ - _DISPATCH_CASE_FP8_E4M3(c_type, __VA_ARGS__) \ - _DISPATCH_CASE_FP8_E5M2(c_type, __VA_ARGS__) \ - default: \ - std::ostringstream oss; \ - oss << __PRETTY_FUNCTION__ << " failed to dispatch fp8 data type " \ - << pytorch_dtype; \ - TORCH_CHECK(false, oss.str()); \ - return false; \ - } \ - }() - -#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE(pytorch_dtype, c_type, ...) \ - [&]() -> bool { \ - switch (pytorch_dtype) { \ - _DISPATCH_CASE_F16(c_type, __VA_ARGS__) \ - _DISPATCH_CASE_BF16(c_type, __VA_ARGS__) \ - _DISPATCH_CASE_FP8_E4M3(c_type, __VA_ARGS__) \ - _DISPATCH_CASE_FP8_E5M2(c_type, __VA_ARGS__) \ - default: \ - std::ostringstream oss; \ - oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " \ - << pytorch_dtype; \ - TORCH_CHECK(false, oss.str()); \ - return false; \ - } \ - }() - -#define _DISPATCH_SWITCH(var_name, cond, ...) \ - [&]() -> bool { \ - switch (cond) { \ - __VA_ARGS__ \ - default: \ - std::ostringstream oss; \ - oss << __PRETTY_FUNCTION__ << " failed to dispatch " var_name " " \ - << int(cond); \ - TORCH_CHECK(false, oss.str()); \ - return false; \ - } \ - }() - -#define _DISPATCH_SWITCH_U16x2(var1_name, var2_name, cond1, cond2, ...) \ - [&]() -> bool { \ - switch (pack_u16(cond1, cond2)) { \ - __VA_ARGS__ \ - default: \ - std::ostringstream oss; \ - oss << __PRETTY_FUNCTION__ \ - << " failed to dispatch (" var1_name ", " var2_name "): (" \ - << int(cond1) << ", " << int(cond2) << ")"; \ - TORCH_CHECK(false, oss.str()); \ - return false; \ - } \ - }() - -#define _DISPATCH_CASE(case_expr, case_var, ...) \ - case case_expr: \ - { \ - constexpr auto case_var = case_expr; \ - return __VA_ARGS__(); \ - } - -#define _DISPATCH_CASE_U16x2(case_expr1, case_expr2, case_var1, case_var2, \ - ...) 
\ - case pack_u16(case_expr1, case_expr2): \ - { \ - constexpr auto case_var1 = case_expr1; \ - constexpr auto case_var2 = case_expr2; \ - return __VA_ARGS__(); \ - } - -#define DISPATCH_BOOL(expr, const_expr, ...) \ - [&]() -> bool { \ - if (expr) { \ - constexpr bool const_expr = true; \ - return __VA_ARGS__(); \ - } \ - else { \ - constexpr bool const_expr = false; \ - return __VA_ARGS__(); \ - } \ - }() - -inline void check_shape(const at::Tensor &a, - const at::Tensor &b, - const char *a_name, - const char *b_name) -{ - TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ", - a.dim(), " vs ", b.dim()); - for (int i = 0; i < a.dim(); ++i) { - TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, - ") != ", b_name, ".size(", i, ")"); - } +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(pytorch_dtype, c_type, ...) \ + [&]() -> bool { \ + switch (pytorch_dtype) { \ + _DISPATCH_CASE_F16(c_type, __VA_ARGS__) \ + _DISPATCH_CASE_BF16(c_type, __VA_ARGS__) \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() + +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP8(pytorch_dtype, c_type, ...) \ + [&]() -> bool { \ + switch (pytorch_dtype) { \ + _DISPATCH_CASE_FP8_E4M3(c_type, __VA_ARGS__) \ + _DISPATCH_CASE_FP8_E5M2(c_type, __VA_ARGS__) \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch fp8 data type " << pytorch_dtype; \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() + +#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE(pytorch_dtype, c_type, ...) 
\ + [&]() -> bool { \ + switch (pytorch_dtype) { \ + _DISPATCH_CASE_F16(c_type, __VA_ARGS__) \ + _DISPATCH_CASE_BF16(c_type, __VA_ARGS__) \ + _DISPATCH_CASE_FP8_E4M3(c_type, __VA_ARGS__) \ + _DISPATCH_CASE_FP8_E5M2(c_type, __VA_ARGS__) \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() + +#define _DISPATCH_SWITCH(var_name, cond, ...) \ + [&]() -> bool { \ + switch (cond) { \ + __VA_ARGS__ \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch " var_name " " << int(cond); \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() + +#define _DISPATCH_SWITCH_U16x2(var1_name, var2_name, cond1, cond2, ...) \ + [&]() -> bool { \ + switch (pack_u16(cond1, cond2)) { \ + __VA_ARGS__ \ + default: \ + std::ostringstream oss; \ + oss << __PRETTY_FUNCTION__ << " failed to dispatch (" var1_name ", " var2_name "): (" \ + << int(cond1) << ", " << int(cond2) << ")"; \ + TORCH_CHECK(false, oss.str()); \ + return false; \ + } \ + }() + +#define _DISPATCH_CASE(case_expr, case_var, ...) \ + case case_expr: { \ + constexpr auto case_var = case_expr; \ + return __VA_ARGS__(); \ + } + +#define _DISPATCH_CASE_U16x2(case_expr1, case_expr2, case_var1, case_var2, ...) \ + case pack_u16(case_expr1, case_expr2): { \ + constexpr auto case_var1 = case_expr1; \ + constexpr auto case_var2 = case_expr2; \ + return __VA_ARGS__(); \ + } + +#define DISPATCH_BOOL(expr, const_expr, ...) \ + [&]() -> bool { \ + if (expr) { \ + constexpr bool const_expr = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr bool const_expr = false; \ + return __VA_ARGS__(); \ + } \ + }() + +inline void check_shape(const at::Tensor& a, const at::Tensor& b, const char* a_name, + const char* b_name) { + TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). 
", a.dim(), " vs ", + b.dim()); + for (int i = 0; i < a.dim(); ++i) { + TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name, ".size(", i, ")"); + } } -inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) -{ - return (uint32_t(a) << 16) | uint32_t(b); +inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) { + return (uint32_t(a) << 16) | uint32_t(b); } #define CHECK_GQA_HEAD_DIVISIBLE(num_qo_heads, num_kv_heads) \ - TORCH_CHECK(num_qo_heads % num_kv_heads == 0, "num_qo_heads(", \ - num_qo_heads, ") must be divisible by num_kv_heads(", \ - num_kv_heads, ")") + TORCH_CHECK(num_qo_heads % num_kv_heads == 0, "num_qo_heads(", num_qo_heads, \ + ") must be divisible by num_kv_heads(", num_kv_heads, ")") #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") -#define CHECK_CONTIGUOUS(x) \ - TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") -#define CHECK_LAST_DIM_CONTIGUOUS(x) \ - TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, \ - #x "must be contiguous at last dimension") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_LAST_DIM_CONTIGUOUS(x) \ + TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x "must be contiguous at last dimension") -#define CHECK_INPUT(x) \ - CHECK_CUDA(x); \ - CHECK_CONTIGUOUS(x) -#define CHECK_LAST_DIM_CONTIGUOUS_INPUT(x) \ - CHECK_CUDA(x); \ - CHECK_LAST_DIM_CONTIGUOUS(x) +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) +#define CHECK_LAST_DIM_CONTIGUOUS_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_LAST_DIM_CONTIGUOUS(x) -#define CHECK_DIM(d, x) \ - TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor") +#define CHECK_DIM(d, x) TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor") #define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b) -#define CHECK_EQ(a, b) \ - TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. 
", a, " vs ", b) +#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b) -#define CHECK_GE(a, b) \ - TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed. ", a, " vs ", b) +#define CHECK_GE(a, b) TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed. ", a, " vs ", b) -inline bool is_float8_tensor(const at::Tensor &tensor) -{ - return tensor.scalar_type() == at::ScalarType::Float8_e4m3fn || - tensor.scalar_type() == at::ScalarType::Float8_e5m2; +inline bool is_float8_tensor(const at::Tensor& tensor) { + return tensor.scalar_type() == at::ScalarType::Float8_e4m3fn || + tensor.scalar_type() == at::ScalarType::Float8_e5m2; } From 28327bb3e804c0767623edb71585fdfae9f96542 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Tue, 16 Sep 2025 11:59:17 -0400 Subject: [PATCH 083/109] Remove leftover file --- libflashinfer/include/gpu_iface/fragment.hpp | 123 ------------------- 1 file changed, 123 deletions(-) delete mode 100644 libflashinfer/include/gpu_iface/fragment.hpp diff --git a/libflashinfer/include/gpu_iface/fragment.hpp b/libflashinfer/include/gpu_iface/fragment.hpp deleted file mode 100644 index 558a503a02..0000000000 --- a/libflashinfer/include/gpu_iface/fragment.hpp +++ /dev/null @@ -1,123 +0,0 @@ -// SPDX-FileCopyrightText: 2025 Advanced Micro Devices, Inc. -// -// SPDX-License-Identifier: Apache-2.0 - -#pragma once - -#include "gpu_iface/mma_types.hpp" -#include "gpu_iface/platform.hpp" - -#ifdef PLATFORM_HIP_DEVICE -#include -#endif - -namespace flashinfer { -namespace gpu_iface { -namespace mma { - -enum class FragmentType { - row_major, // Row-major matrix layout - col_major, // Column-major matrix layout - accumulator // Accumulator (no layout) -}; - -template -struct fragment_t { - using value_type = T; -#ifdef PLATFORM_CUDA_DEVICE - // flashinfer's generic CUDA implementation uses raw arrays for matrix - // fragments and the interface is designed to accomodate use of raw arrays - // for such use cases. 
- static constexpr int elements_per_thread = (frag_type == FragmentType::accumulator) ? 8 - : (sizeof(T) == 1) ? 8 - : 4; - - // Number of 32-bit registers needed - static constexpr int num_regs = (elements_per_thread * sizeof(T) + 3) / 4; - - uint32_t data[num_regs]; - - // Provide array-like access - __device__ __forceinline__ T& operator[](int i) { return reinterpret_cast(data)[i]; } - __device__ __forceinline__ const T& operator[](int i) const { - return reinterpret_cast(data)[i]; - } - - // Get number of elements this thread holds - __device__ __forceinline__ constexpr int size() const { return elements_per_thread; } - - // Get raw pointer for MMA operations - __device__ __forceinline__ uint32_t* raw_ptr() { return data; } - __device__ __forceinline__ const uint32_t* raw_ptr() const { return data; } - -#elif defined(PLATFORM_HIP_DEVICE) - // AMD: Use rocWMMA fragments - using rocwmma_layout = - typename std::conditional::type>::type; - - using rocwmma_matrix_t = typename std::conditional< - frag_type == FragmentType::row_major, rocwmma::matrix_a, - typename std::conditional::type>::type; - - // Select appropriate fragment type based on whether it's accumulator or not - using rocwmma_frag_t = typename std::conditional< - frag_type == FragmentType::accumulator, rocwmma::fragment, - rocwmma::fragment >::type; - - rocwmma_frag_t frag; - - // Provide array-like access that maps to rocWMMA fragment - __device__ __forceinline__ T operator[](int i) const { return frag.x[i]; } - - // For non-const access, we need to provide a setter since we can't return a - // reference - __device__ __forceinline__ void set(int i, T value) { frag.x[i] = value; } - - // Get number of elements this thread holds - __device__ __forceinline__ int size() const { return frag.num_elements; } - - // Get raw pointer for operations that need it - __device__ __forceinline__ rocwmma_frag_t* raw_ptr() { return &frag; } - __device__ __forceinline__ const rocwmma_frag_t* raw_ptr() const { return 
&frag; } -#endif - - // Common interface - update fill method to use setter for HIP - __device__ __forceinline__ void fill(T value) { -#ifdef PLATFORM_CUDA_DEVICE -#pragma unroll - for (int i = 0; i < elements_per_thread; ++i) { - (*this)[i] = value; - } -#elif defined(PLATFORM_HIP_DEVICE) - rocwmma::fill_fragment(frag, value); -#endif - } -}; - -// Convenience typedefs for common fragment types -template -using row_major_fragment_m16n16k16 = fragment_t; - -template -using col_major_fragment_m16n16k16 = fragment_t; - -template -using accumulator_fragment_m16n16k16 = fragment_t; - -// Helper to get compile-time fragment size -template -struct fragment_traits { -#ifdef PLATFORM_CUDA_DEVICE - static constexpr int size = Fragment::elements_per_thread; -#elif defined(PLATFORM_HIP_DEVICE) - // For HIP, we can't make this constexpr, so provide a device function - __device__ static int get_size(const Fragment& f) { return f.size(); } -#endif -}; - -} // namespace mma -} // namespace gpu_iface -} // namespace flashinfer From ef5f6a1f1eb725ff90ad7d32eeacb15cfc0bcec5 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Tue, 16 Sep 2025 12:54:34 -0400 Subject: [PATCH 084/109] rever frag_layout_swizzle.cuh --- .../attention/generic/frag_layout_swizzle.cuh | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/frag_layout_swizzle.cuh b/libflashinfer/include/flashinfer/attention/generic/frag_layout_swizzle.cuh index 0b74f89550..03bc2a3cb1 100644 --- a/libflashinfer/include/flashinfer/attention/generic/frag_layout_swizzle.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/frag_layout_swizzle.cuh @@ -10,20 +10,27 @@ #include "gpu_iface/platform.hpp" +// Define platform-specific full mask for warp/wavefront operations +#if defined(PLATFORM_CUDA_DEVICE) +constexpr uint32_t WARP_FULL_MASK = 0xffffffff; // 32-bit mask for CUDA +#elif defined(PLATFORM_HIP_DEVICE) +constexpr uint64_t WARP_FULL_MASK = 
0xffffffffffffffffULL; // 64-bit mask for HIP +#endif + __device__ __forceinline__ uint32_t frag_layout_swizzle_16b_to_8b(uint32_t x) { - uint32_t tmp = __shfl_xor_sync(0xffffffff, x, 0x1); + uint32_t tmp = __shfl_xor_sync(WARP_FULL_MASK, x, 0x1); x = __byte_perm(x, tmp, ((threadIdx.x & 0x1) == 0) ? 0x5410 : 0x3276); - tmp = __shfl_xor_sync(0xffffffff, x, 0x2); + tmp = __shfl_xor_sync(WARP_FULL_MASK, x, 0x2); x = __byte_perm(x, tmp, ((threadIdx.x & 0x2) == 0) ? 0x5410 : 0x3276); return x; } __device__ __forceinline__ uint32_t frag_layout_swizzle_16b_to_8b_trans(uint32_t x) { - uint32_t tmp = __shfl_xor_sync(0xffffffff, x, 0x4); + uint32_t tmp = __shfl_xor_sync(WARP_FULL_MASK, x, 0x4); x = __byte_perm(x, tmp, ((threadIdx.x & 0x4) == 0) ? 0x6420 : 0x3175); - tmp = __shfl_xor_sync(0xffffffff, x, 0x8); + tmp = __shfl_xor_sync(WARP_FULL_MASK, x, 0x8); x = __byte_perm(x, tmp, ((threadIdx.x & 0x8) == 0) ? 0x5410 : 0x3276); - tmp = __shfl_xor_sync(0xffffffff, x, 0x10); + tmp = __shfl_xor_sync(WARP_FULL_MASK, x, 0x10); x = __byte_perm(x, tmp, ((threadIdx.x & 0x10) == 0) ? 
0x5410 : 0x3276); return x; } From 6ff963a4ac351f4f7bb13a4d66d7bb8db6b6301c Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Tue, 16 Sep 2025 14:37:39 -0400 Subject: [PATCH 085/109] Reformat prefill.cuh --- .../flashinfer/attention/generic/prefill.cuh | 5570 +++++++---------- 1 file changed, 2391 insertions(+), 3179 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index dcc3993e52..d9df683e80 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -16,19 +16,17 @@ #ifdef FP16_QK_REDUCTION_SUPPORTED #include "../../fp16.h" #endif -#include "frag_layout_swizzle.cuh" +#include #include "cascade.cuh" #include "dispatch.cuh" +#include "frag_layout_swizzle.cuh" #include "page.cuh" #include "permuted_smem.cuh" #include "pos_enc.cuh" #include "variants.cuh" -#include - -namespace flashinfer -{ +namespace flashinfer { DEFINE_HAS_MEMBER(maybe_q_rope_offset) DEFINE_HAS_MEMBER(maybe_k_rope_offset) @@ -42,255 +40,198 @@ using mma::MMAMode; constexpr uint32_t WARP_SIZE = gpu_iface::kWarpSize; -constexpr uint32_t get_num_warps_q(const uint32_t cta_tile_q) -{ - if (cta_tile_q > 16) { - return 4; - } - else { - return 1; - } +constexpr uint32_t get_num_warps_q(const uint32_t cta_tile_q) { + if (cta_tile_q > 16) { + return 4; + } else { + return 1; + } } -constexpr uint32_t get_num_warps_kv(const uint32_t cta_tile_kv) -{ - return 4 / get_num_warps_q(cta_tile_kv); +constexpr uint32_t get_num_warps_kv(const uint32_t cta_tile_kv) { + return 4 / get_num_warps_q(cta_tile_kv); } -constexpr uint32_t get_num_mma_q(const uint32_t cta_tile_q) -{ - if (cta_tile_q > 64) { - return 2; - } - else { - return 1; - } +constexpr uint32_t get_num_mma_q(const uint32_t cta_tile_q) { + if (cta_tile_q > 64) { + return 2; + } else { + return 1; + } } -template -struct SharedStorageQKVO -{ - union - { - struct - 
{ - alignas(16) DTypeQ q_smem[CTA_TILE_Q * HEAD_DIM_QK]; - alignas(16) DTypeKV k_smem[CTA_TILE_KV * HEAD_DIM_QK]; +template +struct SharedStorageQKVO { + union { + struct { + alignas(16) DTypeQ q_smem[CTA_TILE_Q * HEAD_DIM_QK]; + alignas(16) DTypeKV k_smem[CTA_TILE_KV * HEAD_DIM_QK]; #if Debug - alignas(16) DTypeKV qk_scratch[CTA_TILE_Q * CTA_TILE_KV]; + alignas(16) DTypeKV qk_scratch[CTA_TILE_Q * CTA_TILE_KV]; #endif - alignas(16) DTypeKV v_smem[CTA_TILE_KV * HEAD_DIM_VO]; - }; - struct - { // NOTE(Zihao): synchronize attention states across warps - alignas(16) std::conditional_t< - NUM_WARPS_KV == 1, - float[1], - float[NUM_WARPS_KV * CTA_TILE_Q * HEAD_DIM_VO]> cta_sync_o_smem; - alignas(16) std::conditional_t< - NUM_WARPS_KV == 1, - float2[1], - float2[NUM_WARPS_KV * CTA_TILE_Q]> cta_sync_md_smem; - }; - alignas(16) DTypeO smem_o[CTA_TILE_Q * HEAD_DIM_VO]; + alignas(16) DTypeKV v_smem[CTA_TILE_KV * HEAD_DIM_VO]; + }; + struct { // NOTE(Zihao): synchronize attention states across warps + alignas( + 16) std::conditional_t cta_sync_o_smem; + alignas(16) std::conditional_t cta_sync_md_smem; }; + alignas(16) DTypeO smem_o[CTA_TILE_Q * HEAD_DIM_VO]; + }; }; -template -struct KernelTraits -{ - static constexpr MaskMode MASK_MODE = MASK_MODE_; - static constexpr uint32_t NUM_MMA_Q = NUM_MMA_Q_; - static constexpr uint32_t NUM_MMA_KV = NUM_MMA_KV_; - static constexpr uint32_t NUM_MMA_D_QK = NUM_MMA_D_QK_; - static constexpr uint32_t NUM_MMA_D_VO = NUM_MMA_D_VO_; - static constexpr uint32_t NUM_WARPS_Q = NUM_WARPS_Q_; - static constexpr uint32_t NUM_WARPS_KV = NUM_WARPS_KV_; - static constexpr uint32_t NUM_WARPS = NUM_WARPS_Q * NUM_WARPS_KV; - static constexpr uint32_t HEAD_DIM_QK = NUM_MMA_D_QK * 16; - static constexpr uint32_t HEAD_DIM_VO = NUM_MMA_D_VO * 16; - static constexpr uint32_t CTA_TILE_Q = CTA_TILE_Q_; - static constexpr uint32_t CTA_TILE_KV = NUM_MMA_KV * NUM_WARPS_KV * 16; - static constexpr PosEncodingMode POS_ENCODING_MODE = POS_ENCODING_MODE_; - - using 
DTypeQ = DTypeQ_; - using DTypeKV = DTypeKV_; - using DTypeO = DTypeO_; - using DTypeQKAccum = DTypeQKAccum_; - using IdType = IdType_; - using AttentionVariant = AttentionVariant_; +struct KernelTraits { + static constexpr MaskMode MASK_MODE = MASK_MODE_; + static constexpr uint32_t NUM_MMA_Q = NUM_MMA_Q_; + static constexpr uint32_t NUM_MMA_KV = NUM_MMA_KV_; + static constexpr uint32_t NUM_MMA_D_QK = NUM_MMA_D_QK_; + static constexpr uint32_t NUM_MMA_D_VO = NUM_MMA_D_VO_; + static constexpr uint32_t NUM_WARPS_Q = NUM_WARPS_Q_; + static constexpr uint32_t NUM_WARPS_KV = NUM_WARPS_KV_; + static constexpr uint32_t NUM_WARPS = NUM_WARPS_Q * NUM_WARPS_KV; + static constexpr uint32_t HEAD_DIM_QK = NUM_MMA_D_QK * 16; + static constexpr uint32_t HEAD_DIM_VO = NUM_MMA_D_VO * 16; + static constexpr uint32_t CTA_TILE_Q = CTA_TILE_Q_; + static constexpr uint32_t CTA_TILE_KV = NUM_MMA_KV * NUM_WARPS_KV * 16; + static constexpr PosEncodingMode POS_ENCODING_MODE = POS_ENCODING_MODE_; + + using DTypeQ = DTypeQ_; + using DTypeKV = DTypeKV_; + using DTypeO = DTypeO_; + using DTypeQKAccum = DTypeQKAccum_; + using IdType = IdType_; + using AttentionVariant = AttentionVariant_; #if defined(PLATFORM_HIP_DEVICE) - static_assert(sizeof(DTypeKV_) != 1, "8-bit types not supported for CDNA3"); - - using SmemBasePtrTy = uint2; - static constexpr uint32_t NUM_THREADS = NUM_WARPS_Q * NUM_WARPS_KV * 64; - static constexpr uint32_t WARP_THREAD_ROWS = 4; - static constexpr uint32_t WARP_THREAD_COLS = 16; - static constexpr uint32_t HALF_ELEMS_PER_THREAD = 4; - static constexpr uint32_t INT32_ELEMS_PER_THREAD = 2; - static constexpr uint32_t VECTOR_BIT_WIDTH = HALF_ELEMS_PER_THREAD * 16; - // FIXME: Update with a proper swizzle pattern. Linear is used primarily - // for intial testing. - static constexpr SwizzleMode SWIZZLE_MODE_Q = SwizzleMode::kLinear; - static constexpr SwizzleMode SWIZZLE_MODE_KV = SwizzleMode::kLinear; - - // Presently we use 16x4 thread layout for all cases. 
- static constexpr uint32_t KV_THR_LAYOUT_ROW = WARP_THREAD_ROWS; - static constexpr uint32_t KV_THR_LAYOUT_COL = WARP_THREAD_COLS; - // FIXME: [The comment is not correct] The constant is defined based on the - // matrix layout of the "D/C" accumulator matrix in a D = A*B+C computation. - // On CDNA3 the D/C matrices are distributed as four 4x16 bands across the - // 64 threads. Each thread owns one element from four different rows. - static constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD = 4; - // Number of threads that collaboratively handle the same set of matrix rows - // in attention score computation and cross-warp synchronization. - // CUDA: 4 threads (each thread handles 2 elements from same row group) - // CDNA3: 16 threads (each thread handles 1 element from same row group) - static constexpr uint32_t THREADS_PER_BMATRIX_ROW_SET = 16; - // controls the indexing stride used in logits-related functions - // (logits_transform, logits_mask, and LSE writing). - static constexpr uint32_t LOGITS_INDEX_STRIDE = 4; + static_assert(sizeof(DTypeKV_) != 1, "8-bit types not supported for CDNA3"); + + using SmemBasePtrTy = uint2; + static constexpr uint32_t NUM_THREADS = NUM_WARPS_Q * NUM_WARPS_KV * 64; + static constexpr uint32_t WARP_THREAD_ROWS = 4; + static constexpr uint32_t WARP_THREAD_COLS = 16; + static constexpr uint32_t HALF_ELEMS_PER_THREAD = 4; + static constexpr uint32_t INT32_ELEMS_PER_THREAD = 2; + static constexpr uint32_t VECTOR_BIT_WIDTH = HALF_ELEMS_PER_THREAD * 16; + // FIXME: Update with a proper swizzle pattern. Linear is used primarily + // for intial testing. + static constexpr SwizzleMode SWIZZLE_MODE_Q = SwizzleMode::kLinear; + static constexpr SwizzleMode SWIZZLE_MODE_KV = SwizzleMode::kLinear; + + // Presently we use 16x4 thread layout for all cases. 
+ static constexpr uint32_t KV_THR_LAYOUT_ROW = WARP_THREAD_ROWS; + static constexpr uint32_t KV_THR_LAYOUT_COL = WARP_THREAD_COLS; + // FIXME: [The comment is not correct] The constant is defined based on the + // matrix layout of the "D/C" accumulator matrix in a D = A*B+C computation. + // On CDNA3 the D/C matrices are distributed as four 4x16 bands across the + // 64 threads. Each thread owns one element from four different rows. + static constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD = 4; + // Number of threads that collaboratively handle the same set of matrix rows + // in attention score computation and cross-warp synchronization. + // CUDA: 4 threads (each thread handles 2 elements from same row group) + // CDNA3: 16 threads (each thread handles 1 element from same row group) + static constexpr uint32_t THREADS_PER_BMATRIX_ROW_SET = 16; + // controls the indexing stride used in logits-related functions + // (logits_transform, logits_mask, and LSE writing). + static constexpr uint32_t LOGITS_INDEX_STRIDE = 4; #else - using SmemBasePtrTy = uint4; - static constexpr uint32_t NUM_THREADS = NUM_WARPS_Q * NUM_WARPS_KV * 32; - constexpr uint32_t WARP_THREAD_ROWS = 4; - constexpr uint32_t WARP_THREAD_COLS = 8; - constexpr uint32_t HALF_ELEMS_PER_THREAD = 8; - constexpr uint32_t INT32_ELEMS_PER_THREAD = 4; - constexpr uint32_t VECTOR_BIT_WIDTH = HALF_ELEMS_PER_THREAD * 16; - - static constexpr SwizzleMode SWIZZLE_MODE_Q = SwizzleMode::k128B; - static constexpr SwizzleMode SWIZZLE_MODE_KV = - (sizeof(DTypeKV_) == 1 && HEAD_DIM_VO == 64) ? SwizzleMode::k64B - : SwizzleMode::k128B; - static constexpr uint32_t KV_THR_LAYOUT_ROW = - SWIZZLE_MODE_KV == SwizzleMode::k128B ? WARP_THREAD_ROWS - : WARP_THREAD_COLS; - static constexpr uint32_t KV_THR_LAYOUT_COL = - SWIZZLE_MODE_KV == SwizzleMode::k128B ? 8 : 4; - - // The constant is defined based on the matrix layout of the "D/C" - // accumulator matrix in a D = A*B+C computation. 
On CUDA for - // m16n8k16 mma ops the D/C matrix is distributed as 4 8x8 block and each - // thread stores eight elements from two different rows. - // Refer: - // https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-i8-f8 - static constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD = 2; - static constexpr uint32_t THREADS_PER_BMATRIX_ROW_SET = 4; - static constexpr uint32_t LOGITS_INDEX_STRIDE = 8; + using SmemBasePtrTy = uint4; + static constexpr uint32_t NUM_THREADS = NUM_WARPS_Q * NUM_WARPS_KV * 32; + constexpr uint32_t WARP_THREAD_ROWS = 4; + constexpr uint32_t WARP_THREAD_COLS = 8; + constexpr uint32_t HALF_ELEMS_PER_THREAD = 8; + constexpr uint32_t INT32_ELEMS_PER_THREAD = 4; + constexpr uint32_t VECTOR_BIT_WIDTH = HALF_ELEMS_PER_THREAD * 16; + + static constexpr SwizzleMode SWIZZLE_MODE_Q = SwizzleMode::k128B; + static constexpr SwizzleMode SWIZZLE_MODE_KV = + (sizeof(DTypeKV_) == 1 && HEAD_DIM_VO == 64) ? SwizzleMode::k64B : SwizzleMode::k128B; + static constexpr uint32_t KV_THR_LAYOUT_ROW = + SWIZZLE_MODE_KV == SwizzleMode::k128B ? WARP_THREAD_ROWS : WARP_THREAD_COLS; + static constexpr uint32_t KV_THR_LAYOUT_COL = SWIZZLE_MODE_KV == SwizzleMode::k128B ? 8 : 4; + + // The constant is defined based on the matrix layout of the "D/C" + // accumulator matrix in a D = A*B+C computation. On CUDA for + // m16n8k16 mma ops the D/C matrix is distributed as 4 8x8 block and each + // thread stores eight elements from two different rows. 
+ // Refer: + // https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-i8-f8 + static constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD = 2; + static constexpr uint32_t THREADS_PER_BMATRIX_ROW_SET = 4; + static constexpr uint32_t LOGITS_INDEX_STRIDE = 8; #endif - static constexpr uint32_t UPCAST_STRIDE_Q = - HEAD_DIM_QK / upcast_size(); - static constexpr uint32_t UPCAST_STRIDE_K = - HEAD_DIM_QK / upcast_size(); - static constexpr uint32_t UPCAST_STRIDE_V = - HEAD_DIM_VO / upcast_size(); - static constexpr uint32_t UPCAST_STRIDE_O = - HEAD_DIM_VO / upcast_size(); - - static constexpr bool IsInvalid() - { - return ((NUM_MMA_D_VO < 4) || - (NUM_MMA_D_VO == 4 && NUM_MMA_KV % 2 == 1) || - (POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && - NUM_MMA_D_VO > 4 && NUM_MMA_D_VO % (2 * NUM_WARPS_Q) != 0) || - (NUM_MMA_Q * (8 * NUM_MMA_D_VO + - 2 * sizeof(DTypeQKAccum) * NUM_MMA_KV) >= - 256) || - (sizeof(DTypeKV) == 1 && NUM_MMA_KV * 2 % NUM_WARPS_Q != 0) || - (sizeof(DTypeKV) == 1 && - POS_ENCODING_MODE == PosEncodingMode::kRoPELlama)); - } - - using SharedStorage = SharedStorageQKVO; + static constexpr uint32_t UPCAST_STRIDE_Q = + HEAD_DIM_QK / upcast_size(); + static constexpr uint32_t UPCAST_STRIDE_K = + HEAD_DIM_QK / upcast_size(); + static constexpr uint32_t UPCAST_STRIDE_V = + HEAD_DIM_VO / upcast_size(); + static constexpr uint32_t UPCAST_STRIDE_O = + HEAD_DIM_VO / upcast_size(); + + static constexpr bool IsInvalid() { + return ((NUM_MMA_D_VO < 4) || (NUM_MMA_D_VO == 4 && NUM_MMA_KV % 2 == 1) || + (POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && NUM_MMA_D_VO > 4 && + NUM_MMA_D_VO % (2 * NUM_WARPS_Q) != 0) || + (NUM_MMA_Q * (8 * NUM_MMA_D_VO + 2 * sizeof(DTypeQKAccum) * NUM_MMA_KV) >= 256) || + (sizeof(DTypeKV) == 1 && NUM_MMA_KV * 2 % NUM_WARPS_Q != 0) || + (sizeof(DTypeKV) == 1 && POS_ENCODING_MODE == PosEncodingMode::kRoPELlama)); + } + + using SharedStorage = SharedStorageQKVO; #ifdef FP16_QK_REDUCTION_SUPPORTED - template 
static constexpr DT getNegInf() - { - if constexpr (std::is_same::value) { - return std::bit_cast( - fp16_ieee_from_fp32_value(-gpu_iface::math::inf)); - } - else { - return static_cast(-gpu_iface::math::inf); - } + template + static constexpr DT getNegInf() { + if constexpr (std::is_same::value) { + return std::bit_cast(fp16_ieee_from_fp32_value(-gpu_iface::math::inf)); + } else { + return static_cast(-gpu_iface::math::inf); } + } - static constexpr DTypeQKAccum MaskFillValue = - AttentionVariant::use_softmax ? getNegInf() - : DTypeQKAccum(0.f); + static constexpr DTypeQKAccum MaskFillValue = + AttentionVariant::use_softmax ? getNegInf() : DTypeQKAccum(0.f); #else - static_assert(!std::is_same::value, - "Set -DFP16_QK_REDUCTION_SUPPORTED and install boost_math " - "then recompile to support fp16 reduction"); - static constexpr DTypeQKAccum MaskFillValue = - AttentionVariant::use_softmax ? DTypeQKAccum(-gpu_iface::math::inf) - : DTypeQKAccum(0.f); + static_assert(!std::is_same::value, + "Set -DFP16_QK_REDUCTION_SUPPORTED and install boost_math " + "then recompile to support fp16 reduction"); + static constexpr DTypeQKAccum MaskFillValue = + AttentionVariant::use_softmax ? 
DTypeQKAccum(-gpu_iface::math::inf) : DTypeQKAccum(0.f); #endif }; -namespace -{ +namespace { template -__device__ __forceinline__ uint32_t -get_warp_idx_q(const uint32_t tid_y = threadIdx.y) -{ - if constexpr (KTraits::NUM_WARPS_Q == 1) { - return 0; - } - else { - return tid_y; - } +__device__ __forceinline__ uint32_t get_warp_idx_q(const uint32_t tid_y = threadIdx.y) { + if constexpr (KTraits::NUM_WARPS_Q == 1) { + return 0; + } else { + return tid_y; + } } template -__device__ __forceinline__ uint32_t -get_warp_idx_kv(const uint32_t tid_z = threadIdx.z) -{ - if constexpr (KTraits::NUM_WARPS_KV == 1) { - return 0; - } - else { - return tid_z; - } +__device__ __forceinline__ uint32_t get_warp_idx_kv(const uint32_t tid_z = threadIdx.z) { + if constexpr (KTraits::NUM_WARPS_KV == 1) { + return 0; + } else { + return tid_z; + } } template -__device__ __forceinline__ uint32_t -get_warp_idx(const uint32_t tid_y = threadIdx.y, - const uint32_t tid_z = threadIdx.z) -{ - return get_warp_idx_kv(tid_z) * KTraits::NUM_WARPS_Q + - get_warp_idx_q(tid_y); +__device__ __forceinline__ uint32_t get_warp_idx(const uint32_t tid_y = threadIdx.y, + const uint32_t tid_z = threadIdx.z) { + return get_warp_idx_kv(tid_z) * KTraits::NUM_WARPS_Q + get_warp_idx_q(tid_y); } /*! @@ -305,215 +246,175 @@ get_warp_idx(const uint32_t tid_y = threadIdx.y, * non tensor-ops flops, will optimize in the future. 
*/ template -__device__ __forceinline__ void -k_frag_apply_llama_rope(T *x_first_half, - T *x_second_half, - const float *rope_freq, - const uint32_t kv_offset) -{ - static_assert(sizeof(T) == 2); -#pragma unroll - for (uint32_t reg_id = 0; reg_id < HALF_ELEMS_PER_THREAD; ++reg_id) { - float cos, sin, tmp; - // 0 1 | 2 3 - // --------- - // 4 5 | 6 7 +__device__ __forceinline__ void k_frag_apply_llama_rope(T* x_first_half, T* x_second_half, + const float* rope_freq, + const uint32_t kv_offset) { + static_assert(sizeof(T) == 2); +#pragma unroll + for (uint32_t reg_id = 0; reg_id < HALF_ELEMS_PER_THREAD; ++reg_id) { + float cos, sin, tmp; + // 0 1 | 2 3 + // --------- + // 4 5 | 6 7 #if defined(PLATFORM_HIP_DEVICE) - uint32_t i = reg_id / 2, j = reg_id % 2; + uint32_t i = reg_id / 2, j = reg_id % 2; #else - uint32_t i = reg_id / 4, j = (reg_id % 4) / 2; + uint32_t i = reg_id / 4, j = (reg_id % 4) / 2; #endif - __sincosf(float(kv_offset + 8 * i) * rope_freq[2 * j + reg_id % 2], - &sin, &cos); - tmp = x_first_half[reg_id]; - x_first_half[reg_id] = (tmp * cos - (float)x_second_half[reg_id] * sin); - x_second_half[reg_id] = - ((float)x_second_half[reg_id] * cos + tmp * sin); - } + __sincosf(float(kv_offset + 8 * i) * rope_freq[2 * j + reg_id % 2], &sin, &cos); + tmp = x_first_half[reg_id]; + x_first_half[reg_id] = (tmp * cos - (float)x_second_half[reg_id] * sin); + x_second_half[reg_id] = ((float)x_second_half[reg_id] * cos + tmp * sin); + } } template -__device__ __forceinline__ void -q_frag_apply_llama_rope(T *x_first_half, - T *x_second_half, - const float *rope_freq, - const uint32_t qo_packed_offset, - const uint_fastdiv group_size) -{ -#pragma unroll - for (uint32_t reg_id = 0; reg_id < HALF_ELEMS_PER_THREAD; ++reg_id) { - float cos, sin, tmp; +__device__ __forceinline__ void q_frag_apply_llama_rope(T* x_first_half, T* x_second_half, + const float* rope_freq, + const uint32_t qo_packed_offset, + const uint_fastdiv group_size) { +#pragma unroll + for (uint32_t 
reg_id = 0; reg_id < HALF_ELEMS_PER_THREAD; ++reg_id) { + float cos, sin, tmp; #if defined(PLATFORM_HIP_DEVICE) - uint32_t freq_idx = reg_id; - uint32_t position = qo_packed_offset; + uint32_t freq_idx = reg_id; + uint32_t position = qo_packed_offset; #else - // 0 1 | 4 5 - // --------- - // 2 3 | 6 7 - uint32_t i = ((reg_id % 4) / 2), j = (reg_id / 4); - uint32_t freq_idx = 2 * j + reg_id % 2; - uint32_t position = qo_packed_offset + 8 * i; + // 0 1 | 4 5 + // --------- + // 2 3 | 6 7 + uint32_t i = ((reg_id % 4) / 2), j = (reg_id / 4); + uint32_t freq_idx = 2 * j + reg_id % 2; + uint32_t position = qo_packed_offset + 8 * i; #endif - __sincosf(float(position / group_size) * rope_freq[freq_idx], &sin, - &cos); - tmp = x_first_half[reg_id]; - x_first_half[reg_id] = (tmp * cos - (float)x_second_half[reg_id] * sin); - x_second_half[reg_id] = - ((float)x_second_half[reg_id] * cos + tmp * sin); - } + __sincosf(float(position / group_size) * rope_freq[freq_idx], &sin, &cos); + tmp = x_first_half[reg_id]; + x_first_half[reg_id] = (tmp * cos - (float)x_second_half[reg_id] * sin); + x_second_half[reg_id] = ((float)x_second_half[reg_id] * cos + tmp * sin); + } } template -__device__ __forceinline__ void -q_frag_apply_llama_rope_with_pos(T *x_first_half, - T *x_second_half, - const float *rope_freq, - const uint32_t qo_packed_offset, - const uint_fastdiv group_size, - const IdType *q_rope_offset) -{ - float pos[2] = { - static_cast(q_rope_offset[qo_packed_offset / group_size]), - static_cast(q_rope_offset[(qo_packed_offset + 8) / group_size])}; -#pragma unroll - for (uint32_t reg_id = 0; reg_id < HALF_ELEMS_PER_THREAD; ++reg_id) { - float cos, sin, tmp; - // 0 1 | 4 5 - // --------- - // 2 3 | 6 7 +__device__ __forceinline__ void q_frag_apply_llama_rope_with_pos(T* x_first_half, T* x_second_half, + const float* rope_freq, + const uint32_t qo_packed_offset, + const uint_fastdiv group_size, + const IdType* q_rope_offset) { + float pos[2] = 
{static_cast(q_rope_offset[qo_packed_offset / group_size]), + static_cast(q_rope_offset[(qo_packed_offset + 8) / group_size])}; +#pragma unroll + for (uint32_t reg_id = 0; reg_id < HALF_ELEMS_PER_THREAD; ++reg_id) { + float cos, sin, tmp; + // 0 1 | 4 5 + // --------- + // 2 3 | 6 7 #if defined(PLATFORM_HIP_DEVICE) - const uint32_t i = reg_id / 2; - const uint32_t j = reg_id % 2; + const uint32_t i = reg_id / 2; + const uint32_t j = reg_id % 2; #else - const uint32_t i = (reg_id % 4) / 2; - const uint32_t j = reg_id / 4; + const uint32_t i = (reg_id % 4) / 2; + const uint32_t j = reg_id / 4; #endif - __sincosf(pos[i] * rope_freq[2 * j + reg_id % 2], &sin, &cos); - tmp = x_first_half[reg_id]; - x_first_half[reg_id] = (tmp * cos - (float)x_second_half[reg_id] * sin); - x_second_half[reg_id] = - ((float)x_second_half[reg_id] * cos + tmp * sin); - } + __sincosf(pos[i] * rope_freq[2 * j + reg_id % 2], &sin, &cos); + tmp = x_first_half[reg_id]; + x_first_half[reg_id] = (tmp * cos - (float)x_second_half[reg_id] * sin); + x_second_half[reg_id] = ((float)x_second_half[reg_id] * cos + tmp * sin); + } } template __device__ __forceinline__ void produce_kv_impl_cuda_( - uint32_t warp_idx, - uint32_t lane_idx, - smem_t smem, - uint32_t *smem_offset, - typename KTraits::DTypeKV **gptr, - const uint32_t stride_n, - const uint32_t kv_idx_base, - const uint32_t kv_len) -{ - using DTypeKV = typename KTraits::DTypeKV; - constexpr uint32_t NUM_WARPS = KTraits::NUM_WARPS; - constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; - constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q; - constexpr uint32_t NUM_MMA_D = - produce_v ? KTraits::NUM_MMA_D_VO : KTraits::NUM_MMA_D_QK; - constexpr uint32_t UPCAST_STRIDE = - produce_v ? 
KTraits::UPCAST_STRIDE_V : KTraits::UPCAST_STRIDE_K; - constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH; - - if constexpr (KTraits::SWIZZLE_MODE_KV == SwizzleMode::k128B) { - uint32_t kv_idx = kv_idx_base + warp_idx * 4 + lane_idx / 8; - // NOTE: NUM_MMA_KV * 4 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 4 / - // num_warps - static_assert(NUM_MMA_KV * 4 % NUM_WARPS_Q == 0); -#pragma unroll - for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) { -#pragma unroll - for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(DTypeKV)); ++j) { - smem.template load_128b_async(*smem_offset, *gptr, - kv_idx < kv_len); - *smem_offset = - smem.template advance_offset_by_column<8>(*smem_offset, j); - *gptr += 8 * upcast_size(); - } - kv_idx += NUM_WARPS * 4; - *smem_offset = smem.template advance_offset_by_row( - *smem_offset) - - sizeof(DTypeKV) * NUM_MMA_D; - *gptr += NUM_WARPS * 4 * stride_n - - sizeof(DTypeKV) * NUM_MMA_D * - upcast_size(); - } - *smem_offset -= KTraits::CTA_TILE_KV * UPCAST_STRIDE; + uint32_t warp_idx, uint32_t lane_idx, + smem_t smem, uint32_t* smem_offset, + typename KTraits::DTypeKV** gptr, const uint32_t stride_n, const uint32_t kv_idx_base, + const uint32_t kv_len) { + using DTypeKV = typename KTraits::DTypeKV; + constexpr uint32_t NUM_WARPS = KTraits::NUM_WARPS; + constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; + constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q; + constexpr uint32_t NUM_MMA_D = produce_v ? KTraits::NUM_MMA_D_VO : KTraits::NUM_MMA_D_QK; + constexpr uint32_t UPCAST_STRIDE = + produce_v ? 
KTraits::UPCAST_STRIDE_V : KTraits::UPCAST_STRIDE_K; + constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH; + + if constexpr (KTraits::SWIZZLE_MODE_KV == SwizzleMode::k128B) { + uint32_t kv_idx = kv_idx_base + warp_idx * 4 + lane_idx / 8; + // NOTE: NUM_MMA_KV * 4 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 4 / num_warps + static_assert(NUM_MMA_KV * 4 % NUM_WARPS_Q == 0); +#pragma unroll + for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) { +#pragma unroll + for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(DTypeKV)); ++j) { + smem.template load_128b_async(*smem_offset, *gptr, kv_idx < kv_len); + *smem_offset = smem.template advance_offset_by_column<8>(*smem_offset, j); + *gptr += 8 * upcast_size(); + } + kv_idx += NUM_WARPS * 4; + *smem_offset = + smem.template advance_offset_by_row(*smem_offset) - + sizeof(DTypeKV) * NUM_MMA_D; + *gptr += NUM_WARPS * 4 * stride_n - + sizeof(DTypeKV) * NUM_MMA_D * upcast_size(); } - else { - uint32_t kv_idx = kv_idx_base + warp_idx * 8 + lane_idx / 4; - // NOTE: NUM_MMA_KV * 2 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 2 / - // num_warps - static_assert(NUM_MMA_KV * 2 % NUM_WARPS_Q == 0); -#pragma unroll - for (uint32_t i = 0; i < NUM_MMA_KV * 2 / NUM_WARPS_Q; ++i) { - smem.template load_128b_async(*smem_offset, *gptr, - kv_idx < kv_len); - *smem_offset = smem.template advance_offset_by_row( - *smem_offset); - kv_idx += NUM_WARPS * 8; - *gptr += NUM_WARPS * 8 * stride_n; - } - *smem_offset -= KTraits::CTA_TILE_KV * UPCAST_STRIDE; + *smem_offset -= KTraits::CTA_TILE_KV * UPCAST_STRIDE; + } else { + uint32_t kv_idx = kv_idx_base + warp_idx * 8 + lane_idx / 4; + // NOTE: NUM_MMA_KV * 2 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 2 / + // num_warps + static_assert(NUM_MMA_KV * 2 % NUM_WARPS_Q == 0); +#pragma unroll + for (uint32_t i = 0; i < NUM_MMA_KV * 2 / NUM_WARPS_Q; ++i) { + smem.template load_128b_async(*smem_offset, *gptr, kv_idx < kv_len); + *smem_offset = + smem.template 
advance_offset_by_row(*smem_offset); + kv_idx += NUM_WARPS * 8; + *gptr += NUM_WARPS * 8 * stride_n; } + *smem_offset -= KTraits::CTA_TILE_KV * UPCAST_STRIDE; + } } template __device__ __forceinline__ void produce_kv_impl_cdna3_( - uint32_t warp_idx, - uint32_t lane_idx, - smem_t smem, - uint32_t *smem_offset, - typename KTraits::DTypeKV **gptr, - const uint32_t stride_n, - const uint32_t kv_idx_base, - const uint32_t kv_len) -{ - static_assert(KTraits::SWIZZLE_MODE_KV == SwizzleMode::kLinear); - using DTypeKV = typename KTraits::DTypeKV; - constexpr uint32_t KV_THR_LAYOUT_COL = KTraits::KV_THR_LAYOUT_COL; // 16 - constexpr uint32_t KV_THR_LAYOUT_ROW = KTraits::KV_THR_LAYOUT_ROW; // 4 - constexpr uint32_t NUM_WARPS = KTraits::NUM_WARPS; - constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; - constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q; - constexpr uint32_t NUM_MMA_D = - produce_v ? KTraits::NUM_MMA_D_VO : KTraits::NUM_MMA_D_QK; - constexpr uint32_t UPCAST_STRIDE = - produce_v ? 
KTraits::UPCAST_STRIDE_V : KTraits::UPCAST_STRIDE_K; - constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH; - constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; - - // NOTE: NUM_MMA_KV*4/NUM_WARPS_Q = NUM_WARPS_KV*NUM_MMA_KV*4/num_warps - static_assert(NUM_MMA_KV * 4 % NUM_WARPS_Q == 0); - - uint32_t kv_idx = kv_idx_base + warp_idx * 4 + lane_idx / KV_THR_LAYOUT_COL; - // NOTE: NUM_MMA_KV * 4 / NUM_WARPS_Q = NUM_WARPS_KV*NUM_MMA_KV*4/num_warps - static_assert(NUM_MMA_KV * 4 % NUM_WARPS_Q == 0); -#pragma unroll - for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) { -#pragma unroll - for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(DTypeKV)); ++j) { - smem.template load_vector_async(*smem_offset, *gptr, - kv_idx < kv_len); - *smem_offset = - smem.template advance_offset_by_column<16>(*smem_offset, j); - *gptr += 16 * upcast_size(); - } - kv_idx += NUM_WARPS * 4; - *smem_offset = - smem.template advance_offset_by_row( - *smem_offset) - - (sizeof(DTypeKV) * NUM_MMA_D * 2); - *gptr += NUM_WARPS * 4 * stride_n - - sizeof(DTypeKV) * NUM_MMA_D * 2 * - upcast_size(); + uint32_t warp_idx, uint32_t lane_idx, + smem_t smem, uint32_t* smem_offset, + typename KTraits::DTypeKV** gptr, const uint32_t stride_n, const uint32_t kv_idx_base, + const uint32_t kv_len) { + static_assert(KTraits::SWIZZLE_MODE_KV == SwizzleMode::kLinear); + using DTypeKV = typename KTraits::DTypeKV; + constexpr uint32_t KV_THR_LAYOUT_COL = KTraits::KV_THR_LAYOUT_COL; // 16 + constexpr uint32_t KV_THR_LAYOUT_ROW = KTraits::KV_THR_LAYOUT_ROW; // 4 + constexpr uint32_t NUM_WARPS = KTraits::NUM_WARPS; + constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; + constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q; + constexpr uint32_t NUM_MMA_D = produce_v ? KTraits::NUM_MMA_D_VO : KTraits::NUM_MMA_D_QK; + constexpr uint32_t UPCAST_STRIDE = + produce_v ? 
KTraits::UPCAST_STRIDE_V : KTraits::UPCAST_STRIDE_K; + constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH; + constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; + + // NOTE: NUM_MMA_KV*4/NUM_WARPS_Q = NUM_WARPS_KV*NUM_MMA_KV*4/num_warps + static_assert(NUM_MMA_KV * 4 % NUM_WARPS_Q == 0); + + uint32_t kv_idx = kv_idx_base + warp_idx * 4 + lane_idx / KV_THR_LAYOUT_COL; + // NOTE: NUM_MMA_KV * 4 / NUM_WARPS_Q = NUM_WARPS_KV*NUM_MMA_KV*4/num_warps + static_assert(NUM_MMA_KV * 4 % NUM_WARPS_Q == 0); +#pragma unroll + for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) { +#pragma unroll + for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(DTypeKV)); ++j) { + smem.template load_vector_async(*smem_offset, *gptr, kv_idx < kv_len); + *smem_offset = smem.template advance_offset_by_column<16>(*smem_offset, j); + *gptr += 16 * upcast_size(); } - *smem_offset -= KTraits::CTA_TILE_KV * UPCAST_STRIDE; + kv_idx += NUM_WARPS * 4; + *smem_offset = smem.template advance_offset_by_row(*smem_offset) - + (sizeof(DTypeKV) * NUM_MMA_D * 2); + *gptr += NUM_WARPS * 4 * stride_n - + sizeof(DTypeKV) * NUM_MMA_D * 2 * upcast_size(); + } + *smem_offset -= KTraits::CTA_TILE_KV * UPCAST_STRIDE; } /*! 
@@ -530,157 +431,127 @@ __device__ __forceinline__ void produce_kv_impl_cdna3_( */ template __device__ __forceinline__ void produce_kv( - smem_t smem, - uint32_t *smem_offset, - typename KTraits::DTypeKV **gptr, - const uint32_t stride_n, - const uint32_t kv_idx_base, - const uint32_t kv_len, - const dim3 tid = threadIdx) -{ - // NOTE: for fp8, this function doesn't work for head_dim = 64 at the moment - const uint32_t warp_idx = get_warp_idx(tid.y, tid.z), - lane_idx = tid.x; + smem_t smem, uint32_t* smem_offset, + typename KTraits::DTypeKV** gptr, const uint32_t stride_n, const uint32_t kv_idx_base, + const uint32_t kv_len, const dim3 tid = threadIdx) { + // NOTE: for fp8, this function doesn't work for head_dim = 64 at the moment + const uint32_t warp_idx = get_warp_idx(tid.y, tid.z), lane_idx = tid.x; #if defined(PLATFORM_HIP_DEVICE) - produce_kv_impl_cdna3_( - warp_idx, lane_idx, smem, smem_offset, gptr, stride_n, kv_idx_base, - kv_len); + produce_kv_impl_cdna3_(warp_idx, lane_idx, smem, smem_offset, gptr, + stride_n, kv_idx_base, kv_len); #elif defined(PLATFORM_CUDA_DEVICE) - produce_kv_impl_cuda_( - warp_idx, lane_idx, smem, smem_offset, gptr, stride_n, kv_idx_base, - kv_len); + produce_kv_impl_cuda_(warp_idx, lane_idx, smem, smem_offset, gptr, + stride_n, kv_idx_base, kv_len); #endif } template __device__ __forceinline__ void page_produce_kv( - smem_t smem, - uint32_t *smem_offset, - const paged_kv_t - &paged_kv, - const uint32_t kv_idx_base, - const size_t *thr_local_kv_offset, - const uint32_t kv_len, - const dim3 tid = threadIdx) -{ - // NOTE: for fp8, this function doesn't work for head_dim = 64 at the moment - using DType = typename KTraits::DTypeKV; - constexpr SharedMemFillMode fill_mode = - produce_v ? 
SharedMemFillMode::kFillZero : SharedMemFillMode::kNoFill; - constexpr uint32_t NUM_WARPS = KTraits::NUM_WARPS; - constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q; - constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; - constexpr uint32_t NUM_MMA_D = - produce_v ? KTraits::NUM_MMA_D_VO : KTraits::NUM_MMA_D_QK; - constexpr uint32_t UPCAST_STRIDE = - produce_v ? KTraits::UPCAST_STRIDE_V : KTraits::UPCAST_STRIDE_K; - constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH; - - const uint32_t warp_idx = get_warp_idx(tid.y, tid.z), - lane_idx = tid.x; - if constexpr (KTraits::SWIZZLE_MODE_KV == SwizzleMode::k128B) { - uint32_t kv_idx = kv_idx_base + warp_idx * 4 + lane_idx / 8; - // NOTE: NUM_MMA_KV * 4/NUM_WARPS_Q=NUM_WARPS_KV*NUM_MMA_KV*4/num_warps - static_assert(NUM_MMA_KV * 4 % NUM_WARPS_Q == 0); -#pragma unroll - for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) { - DType *gptr = produce_v ? paged_kv.v_data + thr_local_kv_offset[i] - : paged_kv.k_data + thr_local_kv_offset[i]; -#pragma unroll - for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(DType)); ++j) { - smem.template load_vector_async(*smem_offset, gptr, - kv_idx < kv_len); - *smem_offset = - smem.template advance_offset_by_column<8>(*smem_offset, j); - gptr += 8 * upcast_size(); - } - kv_idx += NUM_WARPS * 4; - *smem_offset = smem.template advance_offset_by_row( - *smem_offset) - - sizeof(DType) * NUM_MMA_D; - } - *smem_offset -= KTraits::CTA_TILE_KV * UPCAST_STRIDE; + smem_t smem, uint32_t* smem_offset, + const paged_kv_t& paged_kv, + const uint32_t kv_idx_base, const size_t* thr_local_kv_offset, const uint32_t kv_len, + const dim3 tid = threadIdx) { + // NOTE: for fp8, this function doesn't work for head_dim = 64 at the moment + using DType = typename KTraits::DTypeKV; + constexpr SharedMemFillMode fill_mode = + produce_v ? 
SharedMemFillMode::kFillZero : SharedMemFillMode::kNoFill; + constexpr uint32_t NUM_WARPS = KTraits::NUM_WARPS; + constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q; + constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; + constexpr uint32_t NUM_MMA_D = produce_v ? KTraits::NUM_MMA_D_VO : KTraits::NUM_MMA_D_QK; + constexpr uint32_t UPCAST_STRIDE = + produce_v ? KTraits::UPCAST_STRIDE_V : KTraits::UPCAST_STRIDE_K; + constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH; + + const uint32_t warp_idx = get_warp_idx(tid.y, tid.z), lane_idx = tid.x; + if constexpr (KTraits::SWIZZLE_MODE_KV == SwizzleMode::k128B) { + uint32_t kv_idx = kv_idx_base + warp_idx * 4 + lane_idx / 8; + // NOTE: NUM_MMA_KV * 4/NUM_WARPS_Q=NUM_WARPS_KV*NUM_MMA_KV*4/num_warps + static_assert(NUM_MMA_KV * 4 % NUM_WARPS_Q == 0); +#pragma unroll + for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) { + DType* gptr = produce_v ? paged_kv.v_data + thr_local_kv_offset[i] + : paged_kv.k_data + thr_local_kv_offset[i]; +#pragma unroll + for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(DType)); ++j) { + smem.template load_vector_async(*smem_offset, gptr, kv_idx < kv_len); + *smem_offset = smem.template advance_offset_by_column<8>(*smem_offset, j); + gptr += 8 * upcast_size(); + } + kv_idx += NUM_WARPS * 4; + *smem_offset = + smem.template advance_offset_by_row(*smem_offset) - + sizeof(DType) * NUM_MMA_D; } - else { - uint32_t kv_idx = kv_idx_base + warp_idx * 8 + lane_idx / 4; - // NOTE: NUM_MMA_KV * 2 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 2 / - // num_warps - static_assert(NUM_MMA_KV * 2 % NUM_WARPS_Q == 0); -#pragma unroll - for (uint32_t i = 0; i < NUM_MMA_KV * 2 / NUM_WARPS_Q; ++i) { - DType *gptr = produce_v ? 
paged_kv.v_data + thr_local_kv_offset[i] - : paged_kv.k_data + thr_local_kv_offset[i]; - smem.template load_vector_async(*smem_offset, gptr, - kv_idx < kv_len); - kv_idx += NUM_WARPS * 8; - *smem_offset = smem.template advance_offset_by_row( - *smem_offset); - } - *smem_offset -= KTraits::CTA_TILE_KV * UPCAST_STRIDE; + *smem_offset -= KTraits::CTA_TILE_KV * UPCAST_STRIDE; + } else { + uint32_t kv_idx = kv_idx_base + warp_idx * 8 + lane_idx / 4; + // NOTE: NUM_MMA_KV * 2 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 2 / num_warps + static_assert(NUM_MMA_KV * 2 % NUM_WARPS_Q == 0); +#pragma unroll + for (uint32_t i = 0; i < NUM_MMA_KV * 2 / NUM_WARPS_Q; ++i) { + DType* gptr = produce_v ? paged_kv.v_data + thr_local_kv_offset[i] + : paged_kv.k_data + thr_local_kv_offset[i]; + smem.template load_vector_async(*smem_offset, gptr, kv_idx < kv_len); + kv_idx += NUM_WARPS * 8; + *smem_offset = + smem.template advance_offset_by_row(*smem_offset); } + *smem_offset -= KTraits::CTA_TILE_KV * UPCAST_STRIDE; + } } template -__device__ __forceinline__ uint32_t get_feature_index(uint32_t mma_d, - uint32_t lane_idx, - uint32_t j) -{ +__device__ __forceinline__ uint32_t get_feature_index(uint32_t mma_d, uint32_t lane_idx, + uint32_t j) { #if defined(PLATFORM_HIP_DEVICE) - // CDNA3 A-matrix MMA tile to thread mapping for a 64-thread wavefront: - // Each group of 16 threads handles the same four consecutive features for - // different sequences: - // T0-T15: Features [0,1,2,3] for sequences 0-15 respectively - // T16-T31: Features [4,5,6,7] for sequences 0-15 respectively - // T32-T47: Features [8,9,10,11] for sequences 0-15 respectively - // T48-T63: Features [12,13,14,15] for sequences 0-15 respectively - // - uint32_t feature_index = (mma_d * 16 + (lane_idx / 4) + j) % (HEAD_DIM / 2); + // CDNA3 A-matrix MMA tile to thread mapping for a 64-thread wavefront: + // Each group of 16 threads handles the same four consecutive features for + // different sequences: + // T0-T15: Features 
[0,1,2,3] for sequences 0-15 respectively + // T16-T31: Features [4,5,6,7] for sequences 0-15 respectively + // T32-T47: Features [8,9,10,11] for sequences 0-15 respectively + // T48-T63: Features [12,13,14,15] for sequences 0-15 respectively + // + uint32_t feature_index = (mma_d * 16 + (lane_idx / 4) + j) % (HEAD_DIM / 2); #elif defined(PLATFORM_CUDA_DEVICE) - // CUDA A-matrix MMA tile to thread mapping for a 32 thread warp: - // Each group of four consecutive threads map four different features for - // the same sequence. - // T0: {0,1,8,9}, T1: {2,3,10,11}, T2: {4,5,12,13}, T3: {6,7,14,15} - // - // The pattern repeats across 8 rows with each row mapped to a set of four - // consecutive threads. - // row 0 --> T0, T1, T2, T3 - // row 1 --> T4, T5, T6, T7 - // ... - // row 7 --> T28, T29, T30, T31 - // The full data to thread mapping repeats again for the next set of 16 - // rows. Thereby, forming a 16x16 MMA tile dubdivided into four 8x8 - // quadrants. - uint32_t feature_index = - ((mma_d * 16 + (j / 2) * 8 + (lane_idx % 4) * 2 + (j % 2)) % - (HEAD_DIM / 2)); + // CUDA A-matrix MMA tile to thread mapping for a 32 thread warp: + // Each group of four consecutive threads map four different features for + // the same sequence. + // T0: {0,1,8,9}, T1: {2,3,10,11}, T2: {4,5,12,13}, T3: {6,7,14,15} + // + // The pattern repeats across 8 rows with each row mapped to a set of four + // consecutive threads. + // row 0 --> T0, T1, T2, T3 + // row 1 --> T4, T5, T6, T7 + // ... + // row 7 --> T28, T29, T30, T31 + // The full data to thread mapping repeats again for the next set of 16 + // rows. Thereby, forming a 16x16 MMA tile dubdivided into four 8x8 + // quadrants. 
+ uint32_t feature_index = + ((mma_d * 16 + (j / 2) * 8 + (lane_idx % 4) * 2 + (j % 2)) % (HEAD_DIM / 2)); #endif - return feature_index; + return feature_index; } template -__device__ __forceinline__ void -init_rope_freq(float (*rope_freq)[4], - const float rope_rcp_scale, - const float rope_rcp_theta, - const uint32_t tid_x = threadIdx.x) -{ - constexpr uint32_t HEAD_DIM = KTraits::NUM_MMA_D_QK * 16; - const uint32_t lane_idx = tid_x; - -#pragma unroll - for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO / 2; ++mma_d) { -#pragma unroll - for (uint32_t j = 0; j < 4; ++j) { - uint32_t feature_index = - get_feature_index(mma_d, lane_idx, j); - float freq_base = float(2 * feature_index) / float(HEAD_DIM); - rope_freq[mma_d][j] = - rope_rcp_scale * __powf(rope_rcp_theta, freq_base); - } +__device__ __forceinline__ void init_rope_freq(float (*rope_freq)[4], const float rope_rcp_scale, + const float rope_rcp_theta, + const uint32_t tid_x = threadIdx.x) { + constexpr uint32_t HEAD_DIM = KTraits::NUM_MMA_D_QK * 16; + const uint32_t lane_idx = tid_x; + +#pragma unroll + for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO / 2; ++mma_d) { +#pragma unroll + for (uint32_t j = 0; j < 4; ++j) { + uint32_t feature_index = get_feature_index(mma_d, lane_idx, j); + float freq_base = float(2 * feature_index) / float(HEAD_DIM); + rope_freq[mma_d][j] = rope_rcp_scale * __powf(rope_rcp_theta, freq_base); } + } } template @@ -688,1001 +559,768 @@ __device__ __forceinline__ void init_states( typename KTraits::AttentionVariant variant, float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD], typename KTraits::DTypeQKAccum (*m)[KTraits::NUM_ACCUM_ROWS_PER_THREAD], - float (*d)[KTraits::NUM_ACCUM_ROWS_PER_THREAD]) -{ - constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD = - KTraits::NUM_ACCUM_ROWS_PER_THREAD; + float (*d)[KTraits::NUM_ACCUM_ROWS_PER_THREAD]) { + constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD = KTraits::NUM_ACCUM_ROWS_PER_THREAD; #pragma unroll - for (uint32_t 
mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll - for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { + for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { #pragma unroll - for (uint32_t reg_id = 0; reg_id < KTraits::HALF_ELEMS_PER_THREAD; - ++reg_id) - { - o_frag[mma_q][mma_d][reg_id] = 1.f; - } - } + for (uint32_t reg_id = 0; reg_id < KTraits::HALF_ELEMS_PER_THREAD; ++reg_id) { + o_frag[mma_q][mma_d][reg_id] = 1.f; + } } + } - if constexpr (variant.use_softmax) { + if constexpr (variant.use_softmax) { #pragma unroll - for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll - for (uint32_t j = 0; j < NUM_ACCUM_ROWS_PER_THREAD; ++j) { - m[mma_q][j] = - typename KTraits::DTypeQKAccum(-gpu_iface::math::inf); - d[mma_q][j] = 0.f; - } - } + for (uint32_t j = 0; j < NUM_ACCUM_ROWS_PER_THREAD; ++j) { + m[mma_q][j] = typename KTraits::DTypeQKAccum(-gpu_iface::math::inf); + d[mma_q][j] = 0.f; + } } + } } template __device__ __forceinline__ void load_q_global_smem( - uint32_t packed_offset, - const uint32_t qo_upper_bound, - typename KTraits::DTypeQ *q_ptr_base, - const uint32_t q_stride_n, - const uint32_t q_stride_h, - const uint_fastdiv group_size, - smem_t *q_smem, - const dim3 tid = threadIdx) -{ - using DTypeQ = typename KTraits::DTypeQ; - constexpr uint32_t WARP_THREAD_COLS = KTraits::WARP_THREAD_COLS; - constexpr uint32_t WARP_THREAD_ROWS = KTraits::WARP_THREAD_ROWS; - constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; - constexpr uint32_t NUM_MMA_D_QK = KTraits::NUM_MMA_D_QK; - constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q; - constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH; + uint32_t packed_offset, const uint32_t qo_upper_bound, typename KTraits::DTypeQ* q_ptr_base, + const uint32_t q_stride_n, const uint32_t 
q_stride_h, const uint_fastdiv group_size, + smem_t* q_smem, + const dim3 tid = threadIdx) { + using DTypeQ = typename KTraits::DTypeQ; + constexpr uint32_t WARP_THREAD_COLS = KTraits::WARP_THREAD_COLS; + constexpr uint32_t WARP_THREAD_ROWS = KTraits::WARP_THREAD_ROWS; + constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; + constexpr uint32_t NUM_MMA_D_QK = KTraits::NUM_MMA_D_QK; + constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q; + constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH; #if defined(PLATFORM_HIP_DEVICE) - constexpr uint32_t COLUMN_RESET_OFFSET = - (NUM_MMA_D_QK / 4) * WARP_THREAD_COLS; + constexpr uint32_t COLUMN_RESET_OFFSET = (NUM_MMA_D_QK / 4) * WARP_THREAD_COLS; #else - constexpr uint32_t COLUMN_RESET_OFFSET = 2 * KTraits::NUM_MMA_D_QK; + constexpr uint32_t COLUMN_RESET_OFFSET = 2 * KTraits::NUM_MMA_D_QK; #endif - const uint32_t lane_idx = tid.x, - warp_idx_x = get_warp_idx_q(tid.y); - uint32_t row = lane_idx / WARP_THREAD_COLS; - uint32_t col = lane_idx % WARP_THREAD_COLS; + const uint32_t lane_idx = tid.x, warp_idx_x = get_warp_idx_q(tid.y); + uint32_t row = lane_idx / WARP_THREAD_COLS; + uint32_t col = lane_idx % WARP_THREAD_COLS; - if (get_warp_idx_kv(tid.z) == 0) { - uint32_t q_smem_offset_w = - q_smem->template get_permuted_offset( - warp_idx_x * KTraits::NUM_MMA_Q * 16 + row, col); - -#pragma unroll - for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { -#pragma unroll - for (uint32_t j = 0; j < 2 * 2; ++j) { - uint32_t q, r; - group_size.divmod(packed_offset + row + mma_q * 16 + j * 4, q, - r); - const uint32_t q_idx = q; - DTypeQ *q_ptr = q_ptr_base + q * q_stride_n + r * q_stride_h + - col * upcast_size(); -#pragma unroll - for (uint32_t mma_do = 0; mma_do < KTraits::NUM_MMA_D_QK / 4; - ++mma_do) - { - // load q fragment from gmem to smem - q_smem->template load_vector_async< - SharedMemFillMode::kNoFill>(q_smem_offset_w, q_ptr, - q_idx < qo_upper_bound); - q_smem_offset_w = 
q_smem->template advance_offset_by_column< - WARP_THREAD_COLS>(q_smem_offset_w, mma_do); - q_ptr += HALF_ELEMS_PER_THREAD * - upcast_size(); - } - q_smem_offset_w = - q_smem->template advance_offset_by_row( - q_smem_offset_w) - - COLUMN_RESET_OFFSET; - } + if (get_warp_idx_kv(tid.z) == 0) { + uint32_t q_smem_offset_w = q_smem->template get_permuted_offset( + warp_idx_x * KTraits::NUM_MMA_Q * 16 + row, col); + +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t j = 0; j < 2 * 2; ++j) { + uint32_t q, r; + group_size.divmod(packed_offset + row + mma_q * 16 + j * 4, q, r); + const uint32_t q_idx = q; + DTypeQ* q_ptr = q_ptr_base + q * q_stride_n + r * q_stride_h + + col * upcast_size(); +#pragma unroll + for (uint32_t mma_do = 0; mma_do < KTraits::NUM_MMA_D_QK / 4; ++mma_do) { + // load q fragment from gmem to smem + q_smem->template load_vector_async(q_smem_offset_w, q_ptr, + q_idx < qo_upper_bound); + q_smem_offset_w = + q_smem->template advance_offset_by_column(q_smem_offset_w, mma_do); + q_ptr += HALF_ELEMS_PER_THREAD * upcast_size(); } + q_smem_offset_w = q_smem->template advance_offset_by_row( + q_smem_offset_w) - + COLUMN_RESET_OFFSET; + } } + } } template __device__ __forceinline__ void q_smem_inplace_apply_rotary( - const uint32_t q_packed_idx, - const uint32_t qo_len, - const uint32_t kv_len, + const uint32_t q_packed_idx, const uint32_t qo_len, const uint32_t kv_len, const uint_fastdiv group_size, - smem_t *q_smem, - uint32_t *q_smem_offset_r, - float (*rope_freq)[4], - const dim3 tid = threadIdx) -{ - if (get_warp_idx_kv(tid.z) != 0) - return; + smem_t* q_smem, + uint32_t* q_smem_offset_r, float (*rope_freq)[4], const dim3 tid = threadIdx) { + if (get_warp_idx_kv(tid.z) != 0) return; - constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q; - constexpr uint32_t COL_ADVANCE_TO_NEXT = - 16 / KTraits::HALF_ELEMS_PER_THREAD; + constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q; + 
constexpr uint32_t COL_ADVANCE_TO_NEXT = 16 / KTraits::HALF_ELEMS_PER_THREAD; #if defined(PLATFORM_HIP_DEVICE) - constexpr uint32_t COL_ADVANCE_TO_LAST_HALF = KTraits::NUM_MMA_D_QK * 2; + constexpr uint32_t COL_ADVANCE_TO_LAST_HALF = KTraits::NUM_MMA_D_QK * 2; #elif defined(PLATFORM_CUDA_DEVICE) - constexpr uint32_t COL_ADVANCE_TO_LAST_HALF = KTraits::NUM_MMA_D_QK; + constexpr uint32_t COL_ADVANCE_TO_LAST_HALF = KTraits::NUM_MMA_D_QK; #endif - const uint32_t lane_idx = tid.x; - uint32_t q_frag_local[2][KTraits::INT32_ELEMS_PER_THREAD]; - static_assert(KTraits::NUM_MMA_D_QK % 4 == 0, - "NUM_MMA_D_QK must be a multiple of 4"); + const uint32_t lane_idx = tid.x; + uint32_t q_frag_local[2][KTraits::INT32_ELEMS_PER_THREAD]; + static_assert(KTraits::NUM_MMA_D_QK % 4 == 0, "NUM_MMA_D_QK must be a multiple of 4"); #pragma unroll - for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { - uint32_t q_smem_offset_r_first_half = *q_smem_offset_r; + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { + uint32_t q_smem_offset_r_first_half = *q_smem_offset_r; #if defined(PLATFORM_HIP_DEVICE) - const uint32_t seq_id = q_packed_idx + kv_len * group_size - - qo_len * group_size + mma_q * 16 + - lane_idx % 16; + const uint32_t seq_id = + q_packed_idx + kv_len * group_size - qo_len * group_size + mma_q * 16 + lane_idx % 16; #elif defined(PLATFORM_CUDA_DEVICE) - const uint32_t seq_id = q_packed_idx + kv_len * group_size - - qo_len * group_size + mma_q * 16 + lane_idx / 4; + const uint32_t seq_id = + q_packed_idx + kv_len * group_size - qo_len * group_size + mma_q * 16 + lane_idx / 4; #endif #pragma unroll - for (uint32_t mma_di = 0; mma_di < KTraits::NUM_MMA_D_QK / 2; ++mma_di) - { - q_smem->template load_fragment(q_smem_offset_r_first_half, - q_frag_local[0]); - uint32_t q_smem_offset_r_last_half = - q_smem->template advance_offset_by_column< - COL_ADVANCE_TO_LAST_HALF>(q_smem_offset_r_first_half, 0); - q_smem->template load_fragment(q_smem_offset_r_last_half, - 
q_frag_local[1]); - q_frag_apply_llama_rope( - (typename KTraits::DTypeQ *)q_frag_local[0], - (typename KTraits::DTypeQ *)q_frag_local[1], rope_freq[mma_di], - seq_id, group_size); - q_smem->template store_fragment(q_smem_offset_r_last_half, - q_frag_local[1]); - q_smem->template store_fragment(q_smem_offset_r_first_half, - q_frag_local[0]); - q_smem_offset_r_first_half = - q_smem->template advance_offset_by_column( - q_smem_offset_r_first_half, mma_di); - } - *q_smem_offset_r += 16 * UPCAST_STRIDE_Q; + for (uint32_t mma_di = 0; mma_di < KTraits::NUM_MMA_D_QK / 2; ++mma_di) { + q_smem->template load_fragment(q_smem_offset_r_first_half, q_frag_local[0]); + uint32_t q_smem_offset_r_last_half = + q_smem->template advance_offset_by_column( + q_smem_offset_r_first_half, 0); + q_smem->template load_fragment(q_smem_offset_r_last_half, q_frag_local[1]); + q_frag_apply_llama_rope( + (typename KTraits::DTypeQ*)q_frag_local[0], (typename KTraits::DTypeQ*)q_frag_local[1], + rope_freq[mma_di], seq_id, group_size); + q_smem->template store_fragment(q_smem_offset_r_last_half, q_frag_local[1]); + q_smem->template store_fragment(q_smem_offset_r_first_half, q_frag_local[0]); + q_smem_offset_r_first_half = q_smem->template advance_offset_by_column( + q_smem_offset_r_first_half, mma_di); } - *q_smem_offset_r -= KTraits::NUM_MMA_Q * 16 * UPCAST_STRIDE_Q; + *q_smem_offset_r += 16 * UPCAST_STRIDE_Q; + } + *q_smem_offset_r -= KTraits::NUM_MMA_Q * 16 * UPCAST_STRIDE_Q; } template __device__ __forceinline__ void q_smem_inplace_apply_rotary_with_pos( - const uint32_t q_packed_idx_base, - const typename KTraits::IdType *q_rope_offset, - smem_t *q_smem, - const uint_fastdiv group_size, - uint32_t *q_smem_offset_r, - float (*rope_freq)[4], - const dim3 tid = threadIdx) -{ - if (get_warp_idx_kv(tid.z) == 0) { - constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q; - const uint32_t lane_idx = tid.x; - uint32_t q_frag_local[2][KTraits::INT32_ELEMS_PER_THREAD]; - 
static_assert(KTraits::NUM_MMA_D_QK % 4 == 0, - "NUM_MMA_D_QK must be a multiple of 4"); -#pragma unroll - for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { - uint32_t q_smem_offset_r_first_half = *q_smem_offset_r; -#pragma unroll - for (uint32_t mma_di = 0; mma_di < KTraits::NUM_MMA_D_QK / 2; - ++mma_di) - { - q_smem->load_fragment(q_smem_offset_r_first_half, - q_frag_local[0]); - uint32_t q_smem_offset_r_last_half = - q_smem->template advance_offset_by_column< - KTraits::NUM_MMA_D_QK>(q_smem_offset_r_first_half, 0); - q_smem->load_fragment(q_smem_offset_r_last_half, - q_frag_local[1]); - q_frag_apply_llama_rope_with_pos< - typename KTraits::DTypeQ, typename KTraits::IdType, - KTraits::HALF_ELEMS_PER_THREAD>( - (typename KTraits::DTypeQ *)q_frag_local[0], - (typename KTraits::DTypeQ *)q_frag_local[1], - rope_freq[mma_di], - q_packed_idx_base + mma_q * 16 + - lane_idx / KTraits::THREADS_PER_BMATRIX_ROW_SET, - group_size, q_rope_offset); - q_smem->store_fragment(q_smem_offset_r_last_half, - q_frag_local[1]); - q_smem->store_fragment(q_smem_offset_r_first_half, - q_frag_local[0]); - q_smem_offset_r_first_half = - q_smem->template advance_offset_by_column<2>( - q_smem_offset_r_first_half, mma_di); - } - *q_smem_offset_r += 16 * UPCAST_STRIDE_Q; - } - *q_smem_offset_r -= KTraits::NUM_MMA_Q * 16 * UPCAST_STRIDE_Q; + const uint32_t q_packed_idx_base, const typename KTraits::IdType* q_rope_offset, + smem_t* q_smem, + const uint_fastdiv group_size, uint32_t* q_smem_offset_r, float (*rope_freq)[4], + const dim3 tid = threadIdx) { + if (get_warp_idx_kv(tid.z) == 0) { + constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q; + const uint32_t lane_idx = tid.x; + uint32_t q_frag_local[2][KTraits::INT32_ELEMS_PER_THREAD]; + static_assert(KTraits::NUM_MMA_D_QK % 4 == 0, "NUM_MMA_D_QK must be a multiple of 4"); +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { + uint32_t q_smem_offset_r_first_half = *q_smem_offset_r; +#pragma 
unroll + for (uint32_t mma_di = 0; mma_di < KTraits::NUM_MMA_D_QK / 2; ++mma_di) { + q_smem->load_fragment(q_smem_offset_r_first_half, q_frag_local[0]); + uint32_t q_smem_offset_r_last_half = + q_smem->template advance_offset_by_column( + q_smem_offset_r_first_half, 0); + q_smem->load_fragment(q_smem_offset_r_last_half, q_frag_local[1]); + q_frag_apply_llama_rope_with_pos( + (typename KTraits::DTypeQ*)q_frag_local[0], (typename KTraits::DTypeQ*)q_frag_local[1], + rope_freq[mma_di], + q_packed_idx_base + mma_q * 16 + lane_idx / KTraits::THREADS_PER_BMATRIX_ROW_SET, + group_size, q_rope_offset); + q_smem->store_fragment(q_smem_offset_r_last_half, q_frag_local[1]); + q_smem->store_fragment(q_smem_offset_r_first_half, q_frag_local[0]); + q_smem_offset_r_first_half = + q_smem->template advance_offset_by_column<2>(q_smem_offset_r_first_half, mma_di); + } + *q_smem_offset_r += 16 * UPCAST_STRIDE_Q; } + *q_smem_offset_r -= KTraits::NUM_MMA_Q * 16 * UPCAST_STRIDE_Q; + } } template __device__ __forceinline__ void k_smem_inplace_apply_rotary( const uint32_t kv_idx_base, - smem_t *k_smem, - uint32_t *k_smem_offset_r, - float (*rope_freq)[4], - const dim3 tid = threadIdx) -{ - using DTypeKV = typename KTraits::DTypeKV; - static_assert(sizeof(DTypeKV) == 2); - constexpr uint32_t UPCAST_STRIDE_K = KTraits::UPCAST_STRIDE_K; - constexpr uint32_t THREADS_PER_BMATRIX_ROW_SET = - KTraits::THREADS_PER_BMATRIX_ROW_SET; - constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; - uint32_t k_frag_local[2][KTraits::INT32_ELEMS_PER_THREAD]; - const uint32_t lane_idx = tid.x; - if constexpr (KTraits::NUM_MMA_D_QK == 4 && KTraits::NUM_WARPS_Q == 4) { - static_assert(KTraits::NUM_WARPS_KV == 1); - const uint32_t warp_idx = get_warp_idx_q(tid.y); - // horizontal-axis: y - // vertical-axis: z - // | 1-16 | 16-32 | 32-48 | 48-64 | - // | 1-16 | warp_idx=0 | warp_idx=1 | warp_idx=0 | warp_idx=1 | - // | 16-32 | warp_idx=2 | warp_idx=3 | warp_idx=2 | warp_idx=3 | - static_assert( 
- KTraits::NUM_MMA_KV % 2 == 0, - "when NUM_MMA_D_QK == 4, NUM_MMA_KV must be a multiple of 2"); - uint32_t kv_idx = kv_idx_base + (warp_idx / 2) * 16 + - lane_idx / THREADS_PER_BMATRIX_ROW_SET; - *k_smem_offset_r = (*k_smem_offset_r ^ (0x2 * (warp_idx % 2))) + - (warp_idx / 2) * 16 * UPCAST_STRIDE_K; -#pragma unroll - for (uint32_t i = 0; i < KTraits::NUM_MMA_KV / 2; ++i) { - uint32_t k_smem_offset_r_first_half = *k_smem_offset_r; - uint32_t mma_di = (warp_idx % 2); - k_smem->load_fragment(k_smem_offset_r_first_half, k_frag_local[0]); - uint32_t k_smem_offset_r_last_half = - k_smem->template advance_offset_by_column<4>( - k_smem_offset_r_first_half, 0); - k_smem->load_fragment(k_smem_offset_r_last_half, k_frag_local[1]); - k_frag_apply_llama_rope( - (DTypeKV *)k_frag_local[0], (DTypeKV *)k_frag_local[1], - rope_freq[mma_di], kv_idx); - k_smem->store_fragment(k_smem_offset_r_last_half, k_frag_local[1]); - k_smem->store_fragment(k_smem_offset_r_first_half, k_frag_local[0]); - *k_smem_offset_r += 32 * UPCAST_STRIDE_K; - kv_idx += 32; - } - *k_smem_offset_r = - (*k_smem_offset_r ^ (0x2 * (warp_idx % 2))) - - ((warp_idx / 2) + KTraits::NUM_MMA_KV) * 16 * UPCAST_STRIDE_K; + smem_t* k_smem, + uint32_t* k_smem_offset_r, float (*rope_freq)[4], const dim3 tid = threadIdx) { + using DTypeKV = typename KTraits::DTypeKV; + static_assert(sizeof(DTypeKV) == 2); + constexpr uint32_t UPCAST_STRIDE_K = KTraits::UPCAST_STRIDE_K; + constexpr uint32_t THREADS_PER_BMATRIX_ROW_SET = KTraits::THREADS_PER_BMATRIX_ROW_SET; + constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; + uint32_t k_frag_local[2][KTraits::INT32_ELEMS_PER_THREAD]; + const uint32_t lane_idx = tid.x; + if constexpr (KTraits::NUM_MMA_D_QK == 4 && KTraits::NUM_WARPS_Q == 4) { + static_assert(KTraits::NUM_WARPS_KV == 1); + const uint32_t warp_idx = get_warp_idx_q(tid.y); + // horizontal-axis: y + // vertical-axis: z + // | 1-16 | 16-32 | 32-48 | 48-64 | + // | 1-16 | warp_idx=0 | warp_idx=1 | 
warp_idx=0 | warp_idx=1 | + // | 16-32 | warp_idx=2 | warp_idx=3 | warp_idx=2 | warp_idx=3 | + static_assert(KTraits::NUM_MMA_KV % 2 == 0, + "when NUM_MMA_D_QK == 4, NUM_MMA_KV must be a multiple of 2"); + uint32_t kv_idx = kv_idx_base + (warp_idx / 2) * 16 + lane_idx / THREADS_PER_BMATRIX_ROW_SET; + *k_smem_offset_r = + (*k_smem_offset_r ^ (0x2 * (warp_idx % 2))) + (warp_idx / 2) * 16 * UPCAST_STRIDE_K; +#pragma unroll + for (uint32_t i = 0; i < KTraits::NUM_MMA_KV / 2; ++i) { + uint32_t k_smem_offset_r_first_half = *k_smem_offset_r; + uint32_t mma_di = (warp_idx % 2); + k_smem->load_fragment(k_smem_offset_r_first_half, k_frag_local[0]); + uint32_t k_smem_offset_r_last_half = + k_smem->template advance_offset_by_column<4>(k_smem_offset_r_first_half, 0); + k_smem->load_fragment(k_smem_offset_r_last_half, k_frag_local[1]); + k_frag_apply_llama_rope( + (DTypeKV*)k_frag_local[0], (DTypeKV*)k_frag_local[1], rope_freq[mma_di], kv_idx); + k_smem->store_fragment(k_smem_offset_r_last_half, k_frag_local[1]); + k_smem->store_fragment(k_smem_offset_r_first_half, k_frag_local[0]); + *k_smem_offset_r += 32 * UPCAST_STRIDE_K; + kv_idx += 32; } - else { - const uint32_t warp_idx_x = get_warp_idx_q(tid.y), - warp_idx_z = get_warp_idx_kv(tid.z); - static_assert(KTraits::NUM_MMA_D_QK % (2 * KTraits::NUM_WARPS_Q) == 0); - // horizontal axis: y - // vertical axis: z - // | (warp_idx_z, warp_idx_x) | 1-16 | 16-32 | 32-48 | 48-64 - // | ... | 1-16*NUM_MMA_KV | (0, 0) | (0, 1) | (0, 2) | (0, 3) - // | ... | 16*NUM_MMA_KV-32*NUM_MMA_KV | (1, 0) | (1, 1) | (1, 2) | (1, 3) - // | ... ... 
- uint32_t kv_idx = kv_idx_base + - (warp_idx_z * KTraits::NUM_MMA_KV * 16) + - lane_idx / THREADS_PER_BMATRIX_ROW_SET; - *k_smem_offset_r = *k_smem_offset_r ^ (0x2 * warp_idx_x); -#pragma unroll - for (uint32_t i = 0; i < KTraits::NUM_MMA_KV; ++i) { - uint32_t k_smem_offset_r_first_half = *k_smem_offset_r; -#pragma unroll - for (uint32_t j = 0; - j < KTraits::NUM_MMA_D_QK / (2 * KTraits::NUM_WARPS_Q); ++j) - { - uint32_t mma_di = warp_idx_x + j * KTraits::NUM_WARPS_Q; - k_smem->load_fragment(k_smem_offset_r_first_half, - k_frag_local[0]); - uint32_t k_smem_offset_r_last_half = - k_smem->template advance_offset_by_column< - KTraits::NUM_MMA_D_QK>(k_smem_offset_r_first_half, 0); - k_smem->load_fragment(k_smem_offset_r_last_half, - k_frag_local[1]); - k_frag_apply_llama_rope( - (DTypeKV *)k_frag_local[0], (DTypeKV *)k_frag_local[1], - rope_freq[mma_di], kv_idx); - k_smem->store_fragment(k_smem_offset_r_last_half, - k_frag_local[1]); - k_smem->store_fragment(k_smem_offset_r_first_half, - k_frag_local[0]); - k_smem_offset_r_first_half = - k_smem->template advance_offset_by_column< - 2 * KTraits::NUM_WARPS_Q>(k_smem_offset_r_first_half, - mma_di); - } - *k_smem_offset_r += 16 * UPCAST_STRIDE_K; - kv_idx += 16; - } - *k_smem_offset_r = (*k_smem_offset_r ^ (0x2 * warp_idx_x)) - - KTraits::NUM_MMA_KV * 16 * UPCAST_STRIDE_K; + *k_smem_offset_r = (*k_smem_offset_r ^ (0x2 * (warp_idx % 2))) - + ((warp_idx / 2) + KTraits::NUM_MMA_KV) * 16 * UPCAST_STRIDE_K; + } else { + const uint32_t warp_idx_x = get_warp_idx_q(tid.y), + warp_idx_z = get_warp_idx_kv(tid.z); + static_assert(KTraits::NUM_MMA_D_QK % (2 * KTraits::NUM_WARPS_Q) == 0); + // horizontal axis: y + // vertical axis: z + // | (warp_idx_z, warp_idx_x) | 1-16 | 16-32 | 32-48 | 48-64 + // | ... | 1-16*NUM_MMA_KV | (0, 0) | (0, 1) | (0, 2) | (0, 3) + // | ... | 16*NUM_MMA_KV-32*NUM_MMA_KV | (1, 0) | (1, 1) | (1, 2) | (1, 3) + // | ... ... 
+ uint32_t kv_idx = kv_idx_base + (warp_idx_z * KTraits::NUM_MMA_KV * 16) + + lane_idx / THREADS_PER_BMATRIX_ROW_SET; + *k_smem_offset_r = *k_smem_offset_r ^ (0x2 * warp_idx_x); +#pragma unroll + for (uint32_t i = 0; i < KTraits::NUM_MMA_KV; ++i) { + uint32_t k_smem_offset_r_first_half = *k_smem_offset_r; +#pragma unroll + for (uint32_t j = 0; j < KTraits::NUM_MMA_D_QK / (2 * KTraits::NUM_WARPS_Q); ++j) { + uint32_t mma_di = warp_idx_x + j * KTraits::NUM_WARPS_Q; + k_smem->load_fragment(k_smem_offset_r_first_half, k_frag_local[0]); + uint32_t k_smem_offset_r_last_half = + k_smem->template advance_offset_by_column( + k_smem_offset_r_first_half, 0); + k_smem->load_fragment(k_smem_offset_r_last_half, k_frag_local[1]); + k_frag_apply_llama_rope( + (DTypeKV*)k_frag_local[0], (DTypeKV*)k_frag_local[1], rope_freq[mma_di], kv_idx); + k_smem->store_fragment(k_smem_offset_r_last_half, k_frag_local[1]); + k_smem->store_fragment(k_smem_offset_r_first_half, k_frag_local[0]); + k_smem_offset_r_first_half = + k_smem->template advance_offset_by_column<2 * KTraits::NUM_WARPS_Q>( + k_smem_offset_r_first_half, mma_di); + } + *k_smem_offset_r += 16 * UPCAST_STRIDE_K; + kv_idx += 16; } + *k_smem_offset_r = + (*k_smem_offset_r ^ (0x2 * warp_idx_x)) - KTraits::NUM_MMA_KV * 16 * UPCAST_STRIDE_K; + } } template __device__ __forceinline__ void compute_qk( - smem_t *q_smem, - uint32_t *q_smem_offset_r, - smem_t *k_smem, - uint32_t *k_smem_offset_r, - typename KTraits::DTypeQKAccum ( - *s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD]) -{ - constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q; - constexpr uint32_t UPCAST_STRIDE_K = KTraits::UPCAST_STRIDE_K; - constexpr uint32_t QK_SMEM_COLUMN_ADVANCE = - 16 / KTraits::HALF_ELEMS_PER_THREAD; + smem_t* q_smem, + uint32_t* q_smem_offset_r, + smem_t* k_smem, + uint32_t* k_smem_offset_r, + typename KTraits::DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD]) { + constexpr uint32_t UPCAST_STRIDE_Q = 
KTraits::UPCAST_STRIDE_Q; + constexpr uint32_t UPCAST_STRIDE_K = KTraits::UPCAST_STRIDE_K; + constexpr uint32_t QK_SMEM_COLUMN_ADVANCE = 16 / KTraits::HALF_ELEMS_PER_THREAD; - uint32_t a_frag[KTraits::NUM_MMA_Q][KTraits::INT32_ELEMS_PER_THREAD], - b_frag[KTraits::INT32_ELEMS_PER_THREAD]; - // compute q*k^T + uint32_t a_frag[KTraits::NUM_MMA_Q][KTraits::INT32_ELEMS_PER_THREAD], + b_frag[KTraits::INT32_ELEMS_PER_THREAD]; + // compute q*k^T #pragma unroll - for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_QK; ++mma_d) { + for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_QK; ++mma_d) { #pragma unroll - for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { - q_smem->load_fragment(*q_smem_offset_r, a_frag[mma_q]); - *q_smem_offset_r = - q_smem->template advance_offset_by_row<16, UPCAST_STRIDE_Q>( - *q_smem_offset_r); - } + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { + q_smem->load_fragment(*q_smem_offset_r, a_frag[mma_q]); + *q_smem_offset_r = + q_smem->template advance_offset_by_row<16, UPCAST_STRIDE_Q>(*q_smem_offset_r); + } - *q_smem_offset_r = - q_smem->template advance_offset_by_column( - *q_smem_offset_r, mma_d) - - KTraits::NUM_MMA_Q * 16 * UPCAST_STRIDE_Q; + *q_smem_offset_r = + q_smem->template advance_offset_by_column(*q_smem_offset_r, mma_d) - + KTraits::NUM_MMA_Q * 16 * UPCAST_STRIDE_Q; #pragma unroll - for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { - if constexpr (sizeof(typename KTraits::DTypeKV) == 1) { + for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { + if constexpr (sizeof(typename KTraits::DTypeKV) == 1) { #if defined(PLATFORM_HIP_DEVICE) - static_assert(false, - "FP8 support not yet implemented for CDNA3"); + static_assert(false, "FP8 support not yet implemented for CDNA3"); #else - uint32_t b_frag_f8[2]; - if (mma_d % 2 == 0) { - k_smem->ldmatrix_m8n8x4_left_half(*k_smem_offset_r, - b_frag_f8); - } - else { - k_smem->ldmatrix_m8n8x4_right_half(*k_smem_offset_r, - b_frag_f8); 
- } - b_frag_f8[0] = frag_layout_swizzle_16b_to_8b(b_frag_f8[0]); - b_frag_f8[1] = frag_layout_swizzle_16b_to_8b(b_frag_f8[1]); - vec_cast:: - template cast<8>((typename KTraits::DTypeQ *)b_frag, - (typename KTraits::DTypeKV *)b_frag_f8); + uint32_t b_frag_f8[2]; + if (mma_d % 2 == 0) { + k_smem->ldmatrix_m8n8x4_left_half(*k_smem_offset_r, b_frag_f8); + } else { + k_smem->ldmatrix_m8n8x4_right_half(*k_smem_offset_r, b_frag_f8); + } + b_frag_f8[0] = frag_layout_swizzle_16b_to_8b(b_frag_f8[0]); + b_frag_f8[1] = frag_layout_swizzle_16b_to_8b(b_frag_f8[1]); + vec_cast::template cast<8>( + (typename KTraits::DTypeQ*)b_frag, (typename KTraits::DTypeKV*)b_frag_f8); #endif - } - else { - k_smem->load_fragment(*k_smem_offset_r, b_frag); - } - - *k_smem_offset_r = - k_smem->template advance_offset_by_row<16, UPCAST_STRIDE_K>( - *k_smem_offset_r); - -#pragma unroll - for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { - if constexpr (std::is_same_v) - { - if (mma_d == 0) { - mma::mma_sync_m16n16k16_row_col_f16f16f32< - typename KTraits::DTypeQ, MMAMode::kInit>( - s_frag[mma_q][mma_kv], a_frag[mma_q], b_frag); - } - else { - mma::mma_sync_m16n16k16_row_col_f16f16f32< - typename KTraits::DTypeQ>(s_frag[mma_q][mma_kv], - a_frag[mma_q], b_frag); - } - } + } else { + k_smem->load_fragment(*k_smem_offset_r, b_frag); + } + + *k_smem_offset_r = + k_smem->template advance_offset_by_row<16, UPCAST_STRIDE_K>(*k_smem_offset_r); + +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { + if constexpr (std::is_same_v) { + if (mma_d == 0) { + mma::mma_sync_m16n16k16_row_col_f16f16f32( + s_frag[mma_q][mma_kv], a_frag[mma_q], b_frag); + } else { + mma::mma_sync_m16n16k16_row_col_f16f16f32( + s_frag[mma_q][mma_kv], a_frag[mma_q], b_frag); + } + } - else if (std::is_same_v) { + else if (std::is_same_v) { #if defined(PLATFORM_HIP_DEVICE) - static_assert( - false, - "FP16 DTypeQKAccum not yet implemented for CDNA3"); + static_assert(false, "FP16 DTypeQKAccum 
not yet implemented for CDNA3"); #else - if (mma_d == 0) { - mma::mma_sync_m16n16k16_row_col_f16f16f16< - MMAMode::kInit>((uint32_t *)s_frag[mma_q][mma_kv], - a_frag[mma_q], b_frag); - } - else { - mma::mma_sync_m16n16k16_row_col_f16f16f16( - (uint32_t *)s_frag[mma_q][mma_kv], a_frag[mma_q], - b_frag); - } + if (mma_d == 0) { + mma::mma_sync_m16n16k16_row_col_f16f16f16( + (uint32_t*)s_frag[mma_q][mma_kv], a_frag[mma_q], b_frag); + } else { + mma::mma_sync_m16n16k16_row_col_f16f16f16((uint32_t*)s_frag[mma_q][mma_kv], + a_frag[mma_q], b_frag); + } #endif - } - } - } - if constexpr (sizeof(typename KTraits::DTypeKV) == 1) { - if (mma_d % 2 == 1) { - *k_smem_offset_r = k_smem->template advance_offset_by_column< - QK_SMEM_COLUMN_ADVANCE>(*k_smem_offset_r, mma_d / 2); - } - *k_smem_offset_r -= KTraits::NUM_MMA_KV * 16 * UPCAST_STRIDE_K; - } - else { - *k_smem_offset_r = - k_smem - ->template advance_offset_by_column( - *k_smem_offset_r, mma_d) - - KTraits::NUM_MMA_KV * 16 * UPCAST_STRIDE_K; } + } } - *q_smem_offset_r -= KTraits::NUM_MMA_D_QK * QK_SMEM_COLUMN_ADVANCE; + if constexpr (sizeof(typename KTraits::DTypeKV) == 1) { + if (mma_d % 2 == 1) { + *k_smem_offset_r = k_smem->template advance_offset_by_column( + *k_smem_offset_r, mma_d / 2); + } + *k_smem_offset_r -= KTraits::NUM_MMA_KV * 16 * UPCAST_STRIDE_K; + } else { + *k_smem_offset_r = k_smem->template advance_offset_by_column( + *k_smem_offset_r, mma_d) - + KTraits::NUM_MMA_KV * 16 * UPCAST_STRIDE_K; + } + } + *q_smem_offset_r -= KTraits::NUM_MMA_D_QK * QK_SMEM_COLUMN_ADVANCE; #if defined(PLATFORM_HIP_DEVICE) - *k_smem_offset_r -= KTraits::NUM_MMA_D_QK * (QK_SMEM_COLUMN_ADVANCE); + *k_smem_offset_r -= KTraits::NUM_MMA_D_QK * (QK_SMEM_COLUMN_ADVANCE); #elif defined(PLATFORM_CUDA_DEVICE) - *k_smem_offset_r -= - KTraits::NUM_MMA_D_QK * sizeof(typename KTraits::DTypeKV); + *k_smem_offset_r -= KTraits::NUM_MMA_D_QK * sizeof(typename KTraits::DTypeKV); #endif } template __device__ __forceinline__ void logits_transform( 
- const Params ¶ms, - typename KTraits::AttentionVariant variant, - const uint32_t batch_idx, - const uint32_t qo_packed_idx_base, - const uint32_t kv_idx_base, - const uint32_t qo_len, - const uint32_t kv_len, - const uint_fastdiv group_size, + const Params& params, typename KTraits::AttentionVariant variant, const uint32_t batch_idx, + const uint32_t qo_packed_idx_base, const uint32_t kv_idx_base, const uint32_t qo_len, + const uint32_t kv_len, const uint_fastdiv group_size, DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD], - const dim3 tid = threadIdx, - const uint32_t kv_head_idx = blockIdx.z) -{ - constexpr uint32_t TPR = KTraits::THREADS_PER_BMATRIX_ROW_SET; - constexpr uint32_t NAPTR = KTraits::NUM_ACCUM_ROWS_PER_THREAD; - constexpr uint32_t LIS = KTraits::LOGITS_INDEX_STRIDE; + const dim3 tid = threadIdx, const uint32_t kv_head_idx = blockIdx.z) { + constexpr uint32_t TPR = KTraits::THREADS_PER_BMATRIX_ROW_SET; + constexpr uint32_t NAPTR = KTraits::NUM_ACCUM_ROWS_PER_THREAD; + constexpr uint32_t LIS = KTraits::LOGITS_INDEX_STRIDE; - const uint32_t lane_idx = tid.x; - uint32_t q[KTraits::NUM_MMA_Q][NAPTR], r[KTraits::NUM_MMA_Q][NAPTR]; - float logits = 0., logitsTransformed = 0.; + const uint32_t lane_idx = tid.x; + uint32_t q[KTraits::NUM_MMA_Q][NAPTR], r[KTraits::NUM_MMA_Q][NAPTR]; + float logits = 0., logitsTransformed = 0.; #pragma unroll - for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll - for (uint32_t j = 0; j < NAPTR; ++j) { - group_size.divmod(qo_packed_idx_base + mma_q * 16 + lane_idx / TPR + - LIS * j, - q[mma_q][j], r[mma_q][j]); - } + for (uint32_t j = 0; j < NAPTR; ++j) { + group_size.divmod(qo_packed_idx_base + mma_q * 16 + lane_idx / TPR + LIS * j, q[mma_q][j], + r[mma_q][j]); } + } #pragma unroll - for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; 
++mma_q) { #pragma unroll - for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { + for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { #pragma unroll - for (uint32_t reg_id = 0; reg_id < KTraits::HALF_ELEMS_PER_THREAD; - ++reg_id) - { + for (uint32_t reg_id = 0; reg_id < KTraits::HALF_ELEMS_PER_THREAD; ++reg_id) { #if defined(PLATFORM_HIP_DEVICE) - const uint32_t q_idx = q[mma_q][reg_id % NAPTR]; - const uint32_t qo_head_idx = - kv_head_idx * group_size + r[mma_q][reg_id % NAPTR]; - const uint32_t kv_idx = - kv_idx_base + mma_kv * 16 + (lane_idx % TPR); + const uint32_t q_idx = q[mma_q][reg_id % NAPTR]; + const uint32_t qo_head_idx = kv_head_idx * group_size + r[mma_q][reg_id % NAPTR]; + const uint32_t kv_idx = kv_idx_base + mma_kv * 16 + (lane_idx % TPR); #else - const uint32_t q_idx = q[mma_q][(reg_id % 4) / 2], - kv_idx = kv_idx_base + mma_kv * 16 + - 2 * (lane_idx % 4) + 8 * (reg_id / 4) + - reg_id % 2; - const uint32_t qo_head_idx = - kv_head_idx * group_size + r[mma_q][(reg_id % 4) / 2]; + const uint32_t q_idx = q[mma_q][(reg_id % 4) / 2], kv_idx = kv_idx_base + mma_kv * 16 + + 2 * (lane_idx % 4) + + 8 * (reg_id / 4) + reg_id % 2; + const uint32_t qo_head_idx = kv_head_idx * group_size + r[mma_q][(reg_id % 4) / 2]; #endif #ifdef FP16_QK_REDUCTION_SUPPORTED - if constexpr (std::is_same::value) { - logits = std::bit_cast( - fp16_ieee_to_fp32_value(s_frag[mma_q][mma_kv][reg_id])); - } - else if constexpr (!std::is_same::value) { - logits = s_frag[mma_q][mma_kv][reg_id]; - } + if constexpr (std::is_same::value) { + logits = std::bit_cast(fp16_ieee_to_fp32_value(s_frag[mma_q][mma_kv][reg_id])); + } else if constexpr (!std::is_same::value) { + logits = s_frag[mma_q][mma_kv][reg_id]; + } #else - static_assert( - !std::is_same::value, - "Set -DFP16_QK_REDUCTION_SUPPORTED and install boost_math " - "then recompile to support fp16 reduction"); - logits = s_frag[mma_q][mma_kv][reg_id]; + static_assert(!std::is_same::value, + "Set 
-DFP16_QK_REDUCTION_SUPPORTED and install boost_math " + "then recompile to support fp16 reduction"); + logits = s_frag[mma_q][mma_kv][reg_id]; #endif - logitsTransformed = - variant.LogitsTransform(params, logits, batch_idx, q_idx, - kv_idx, qo_head_idx, kv_head_idx); + logitsTransformed = variant.LogitsTransform(params, logits, batch_idx, q_idx, kv_idx, + qo_head_idx, kv_head_idx); #if Debug1 - const uint32_t lane_idx = tid.x, - warp_idx = get_warp_idx(tid.y, tid.z); + const uint32_t lane_idx = tid.x, warp_idx = get_warp_idx(tid.y, tid.z); - if (warp_idx == 0 && lane_idx == 0) { - printf("logits : %f logitsTransformed: %f\n", float(logits), - float(logitsTransformed)); - } + if (warp_idx == 0 && lane_idx == 0) { + printf("logits : %f logitsTransformed: %f\n", float(logits), float(logitsTransformed)); + } #endif #ifdef FP16_QK_REDUCTION_SUPPORTED - if constexpr (std::is_same::value) { - s_frag[mma_q][mma_kv][reg_id] = std::bit_cast( - fp16_ieee_from_fp32_value(logitsTransformed)); - } - else if constexpr (!std::is_same::value) { - s_frag[mma_q][mma_kv][reg_id] = logitsTransformed; - } + if constexpr (std::is_same::value) { + s_frag[mma_q][mma_kv][reg_id] = + std::bit_cast(fp16_ieee_from_fp32_value(logitsTransformed)); + } else if constexpr (!std::is_same::value) { + s_frag[mma_q][mma_kv][reg_id] = logitsTransformed; + } #else - s_frag[mma_q][mma_kv][reg_id] = logitsTransformed; + s_frag[mma_q][mma_kv][reg_id] = logitsTransformed; #endif - } - } + } } + } } template -__device__ __forceinline__ void -logits_mask(const Params ¶ms, - typename KTraits::AttentionVariant variant, - const uint32_t batch_idx, - const uint32_t qo_packed_idx_base, - const uint32_t kv_idx_base, - const uint32_t qo_len, - const uint32_t kv_len, - const uint32_t chunk_end, - const uint_fastdiv group_size, - typename KTraits::DTypeQKAccum ( - *s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD], - const dim3 tid = threadIdx, - const uint32_t kv_head_idx = blockIdx.z) -{ - const uint32_t 
lane_idx = tid.x; - constexpr uint32_t NUM_MMA_Q = KTraits::NUM_MMA_Q; - constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; - constexpr MaskMode MASK_MODE = KTraits::MASK_MODE; - constexpr uint32_t TPR = KTraits::THREADS_PER_BMATRIX_ROW_SET; - constexpr uint32_t NAPTR = KTraits::NUM_ACCUM_ROWS_PER_THREAD; - constexpr uint32_t LIS = KTraits::LOGITS_INDEX_STRIDE; - - uint32_t q[NUM_MMA_Q][NAPTR], r[NUM_MMA_Q][NAPTR]; -#pragma unroll - for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { -#pragma unroll - for (uint32_t j = 0; j < NAPTR; ++j) { - group_size.divmod(qo_packed_idx_base + mma_q * 16 + lane_idx / TPR + - LIS * j, - q[mma_q][j], r[mma_q][j]); - } +__device__ __forceinline__ void logits_mask( + const Params& params, typename KTraits::AttentionVariant variant, const uint32_t batch_idx, + const uint32_t qo_packed_idx_base, const uint32_t kv_idx_base, const uint32_t qo_len, + const uint32_t kv_len, const uint32_t chunk_end, const uint_fastdiv group_size, + typename KTraits::DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD], + const dim3 tid = threadIdx, const uint32_t kv_head_idx = blockIdx.z) { + const uint32_t lane_idx = tid.x; + constexpr uint32_t NUM_MMA_Q = KTraits::NUM_MMA_Q; + constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; + constexpr MaskMode MASK_MODE = KTraits::MASK_MODE; + constexpr uint32_t TPR = KTraits::THREADS_PER_BMATRIX_ROW_SET; + constexpr uint32_t NAPTR = KTraits::NUM_ACCUM_ROWS_PER_THREAD; + constexpr uint32_t LIS = KTraits::LOGITS_INDEX_STRIDE; + + uint32_t q[NUM_MMA_Q][NAPTR], r[NUM_MMA_Q][NAPTR]; +#pragma unroll + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t j = 0; j < NAPTR; ++j) { + group_size.divmod(qo_packed_idx_base + mma_q * 16 + lane_idx / TPR + LIS * j, q[mma_q][j], + r[mma_q][j]); } + } #pragma unroll - for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { #pragma unroll - for (uint32_t mma_kv = 
0; mma_kv < NUM_MMA_KV; ++mma_kv) { + for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) { #pragma unroll - for (uint32_t reg_id = 0; reg_id < KTraits::HALF_ELEMS_PER_THREAD; - ++reg_id) - { + for (uint32_t reg_id = 0; reg_id < KTraits::HALF_ELEMS_PER_THREAD; ++reg_id) { #if defined(PLATFORM_HIP_DEVICE) - const uint32_t q_idx = q[mma_q][(reg_id % NAPTR)]; - const uint32_t kv_idx = - kv_idx_base + mma_kv * 16 + (lane_idx % TPR); - const uint32_t qo_head_idx = - kv_head_idx * group_size + r[mma_q][(reg_id % NAPTR)]; + const uint32_t q_idx = q[mma_q][(reg_id % NAPTR)]; + const uint32_t kv_idx = kv_idx_base + mma_kv * 16 + (lane_idx % TPR); + const uint32_t qo_head_idx = kv_head_idx * group_size + r[mma_q][(reg_id % NAPTR)]; #else - const uint32_t q_idx = q[mma_q][(reg_id % 4) / 2], - kv_idx = kv_idx_base + mma_kv * 16 + - 2 * (lane_idx % TPR) + - 8 * (reg_id / 4) + reg_id % 2; - const uint32_t qo_head_idx = - kv_head_idx * group_size + r[mma_q][(reg_id % 4) / 2]; + const uint32_t q_idx = q[mma_q][(reg_id % 4) / 2], kv_idx = kv_idx_base + mma_kv * 16 + + 2 * (lane_idx % TPR) + + 8 * (reg_id / 4) + reg_id % 2; + const uint32_t qo_head_idx = kv_head_idx * group_size + r[mma_q][(reg_id % 4) / 2]; #endif - const bool mask = - (!(MASK_MODE == MaskMode::kCausal - ? (kv_idx + qo_len > kv_len + q_idx || - (kv_idx >= chunk_end)) - : kv_idx >= chunk_end)) && - variant.LogitsMask(params, batch_idx, q_idx, kv_idx, - qo_head_idx, kv_head_idx); - s_frag[mma_q][mma_kv][reg_id] = - (mask) ? s_frag[mma_q][mma_kv][reg_id] - : (KTraits::MaskFillValue); - } - } + const bool mask = + (!(MASK_MODE == MaskMode::kCausal + ? (kv_idx + qo_len > kv_len + q_idx || (kv_idx >= chunk_end)) + : kv_idx >= chunk_end)) && + variant.LogitsMask(params, batch_idx, q_idx, kv_idx, qo_head_idx, kv_head_idx); + s_frag[mma_q][mma_kv][reg_id] = + (mask) ? 
s_frag[mma_q][mma_kv][reg_id] : (KTraits::MaskFillValue); + } } + } } template __device__ __forceinline__ void update_mdo_states( typename KTraits::AttentionVariant variant, - typename KTraits::DTypeQKAccum ( - *s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD], + typename KTraits::DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD], float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD], typename KTraits::DTypeQKAccum (*m)[KTraits::NUM_ACCUM_ROWS_PER_THREAD], - float (*d)[KTraits::NUM_ACCUM_ROWS_PER_THREAD], - uint32_t warp_idx = 0, - uint32_t lane_idx = 0) -{ - using DTypeQKAccum = typename KTraits::DTypeQKAccum; - using AttentionVariant = typename KTraits::AttentionVariant; - constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD = - KTraits::NUM_ACCUM_ROWS_PER_THREAD; - constexpr bool use_softmax = AttentionVariant::use_softmax; + float (*d)[KTraits::NUM_ACCUM_ROWS_PER_THREAD], uint32_t warp_idx = 0, uint32_t lane_idx = 0) { + using DTypeQKAccum = typename KTraits::DTypeQKAccum; + using AttentionVariant = typename KTraits::AttentionVariant; + constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD = KTraits::NUM_ACCUM_ROWS_PER_THREAD; + constexpr bool use_softmax = AttentionVariant::use_softmax; - if constexpr (use_softmax) { - const float sm_scale = variant.sm_scale_log2; - if constexpr (std::is_same_v) { + if constexpr (use_softmax) { + const float sm_scale = variant.sm_scale_log2; + if constexpr (std::is_same_v) { #pragma unroll - for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll - for (uint32_t j = 0; j < NUM_ACCUM_ROWS_PER_THREAD; ++j) { - float m_prev = m[mma_q][j]; + for (uint32_t j = 0; j < NUM_ACCUM_ROWS_PER_THREAD; ++j) { + float m_prev = m[mma_q][j]; #if defined(PLATFORM_HIP_DEVICE) #pragma unroll - for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; - ++mma_kv) - { - m[mma_q][j] = - max(m[mma_q][j], 
s_frag[mma_q][mma_kv][j]); - } - // Butterfly reduction across all threads in the band - m[mma_q][j] = - max(m[mma_q][j], gpu_iface::math::shfl_xor_sync( - m[mma_q][j], 0x8)); // 16 apart - m[mma_q][j] = - max(m[mma_q][j], gpu_iface::math::shfl_xor_sync( - m[mma_q][j], 0x4)); // 8 apart - m[mma_q][j] = - max(m[mma_q][j], gpu_iface::math::shfl_xor_sync( - m[mma_q][j], 0x2)); // 4 apart - m[mma_q][j] = - max(m[mma_q][j], gpu_iface::math::shfl_xor_sync( - m[mma_q][j], 0x1)); // 2 apart - float o_scale = gpu_iface::math::ptx_exp2( - m_prev * sm_scale - m[mma_q][j] * sm_scale); - - // Scale output fragments for this specific row -#pragma unroll - for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; - ++mma_d) - { - o_frag[mma_q][mma_d][j] *= o_scale; - } - - // Convert logits to probabilities for this row -#pragma unroll - for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; - ++mma_kv) - { - s_frag[mma_q][mma_kv][j] = gpu_iface::math::ptx_exp2( - s_frag[mma_q][mma_kv][j] * sm_scale - - m[mma_q][j] * sm_scale); - } + for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { + m[mma_q][j] = max(m[mma_q][j], s_frag[mma_q][mma_kv][j]); + } + // Butterfly reduction across all threads in the band + m[mma_q][j] = + max(m[mma_q][j], gpu_iface::math::shfl_xor_sync(m[mma_q][j], 0x8)); // 16 apart + m[mma_q][j] = + max(m[mma_q][j], gpu_iface::math::shfl_xor_sync(m[mma_q][j], 0x4)); // 8 apart + m[mma_q][j] = + max(m[mma_q][j], gpu_iface::math::shfl_xor_sync(m[mma_q][j], 0x2)); // 4 apart + m[mma_q][j] = + max(m[mma_q][j], gpu_iface::math::shfl_xor_sync(m[mma_q][j], 0x1)); // 2 apart + float o_scale = gpu_iface::math::ptx_exp2(m_prev * sm_scale - m[mma_q][j] * sm_scale); + + // Scale output fragments for this specific row +#pragma unroll + for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { + o_frag[mma_q][mma_d][j] *= o_scale; + } + + // Convert logits to probabilities for this row +#pragma unroll + for (uint32_t mma_kv = 0; mma_kv < 
KTraits::NUM_MMA_KV; ++mma_kv) { + s_frag[mma_q][mma_kv][j] = gpu_iface::math::ptx_exp2( + s_frag[mma_q][mma_kv][j] * sm_scale - m[mma_q][j] * sm_scale); + } #elif (PLATFORM_CUDA_DEVICE) #pragma unroll - for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; - ++mma_kv) - { - auto m_local = max(max(s_frag[mma_q][mma_kv][0], - s_frag[mma_q][mma_kv][1]), - max(s_frag[mma_q][mma_kv][2], - s_frag[mma_q][mma_kv][3])); - m[mma_q][j] = max(m[mma_q][j], m_local); - } - - m[mma_q][j] = - max(m[mma_q][j], - gpu_iface::math::shfl_xor_sync(m[mma_q][j], 0x2)); - m[mma_q][j] = - max(m[mma_q][j], - gpu_iface::math::shfl_xor_sync(m[mma_q][j], 0x1)); - float o_scale = gpu_iface::math::ptx_exp2( - m_prev * sm_scale - m[mma_q][j] * sm_scale); - d[mma_q][j] *= o_scale; -#pragma unroll - for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; - ++mma_d) - { - o_frag[mma_q][mma_d][j * 2 + 0] *= o_scale; - o_frag[mma_q][mma_d][j * 2 + 1] *= o_scale; - o_frag[mma_q][mma_d][j * 2 + 4] *= o_scale; - o_frag[mma_q][mma_d][j * 2 + 5] *= o_scale; - } -#pragma unroll - for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; - ++mma_kv) - { - s_frag[mma_q][mma_kv][j * 2 + 0] = - gpu_iface::math::ptx_exp2( - s_frag[mma_q][mma_kv][j * 2 + 0] * sm_scale - - m[mma_q][j] * sm_scale); - s_frag[mma_q][mma_kv][j * 2 + 1] = - gpu_iface::math::ptx_exp2( - s_frag[mma_q][mma_kv][j * 2 + 1] * sm_scale - - m[mma_q][j] * sm_scale); - s_frag[mma_q][mma_kv][j * 2 + 4] = - gpu_iface::math::ptx_exp2( - s_frag[mma_q][mma_kv][j * 2 + 4] * sm_scale - - m[mma_q][j] * sm_scale); - s_frag[mma_q][mma_kv][j * 2 + 5] = - gpu_iface::math::ptx_exp2( - s_frag[mma_q][mma_kv][j * 2 + 5] * sm_scale - - m[mma_q][j] * sm_scale); - } + for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { + auto m_local = max(max(s_frag[mma_q][mma_kv][0], s_frag[mma_q][mma_kv][1]), + max(s_frag[mma_q][mma_kv][2], s_frag[mma_q][mma_kv][3])); + m[mma_q][j] = max(m[mma_q][j], m_local); + } + + m[mma_q][j] = max(m[mma_q][j], 
gpu_iface::math::shfl_xor_sync(m[mma_q][j], 0x2)); + m[mma_q][j] = max(m[mma_q][j], gpu_iface::math::shfl_xor_sync(m[mma_q][j], 0x1)); + float o_scale = gpu_iface::math::ptx_exp2(m_prev * sm_scale - m[mma_q][j] * sm_scale); + d[mma_q][j] *= o_scale; +#pragma unroll + for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { + o_frag[mma_q][mma_d][j * 2 + 0] *= o_scale; + o_frag[mma_q][mma_d][j * 2 + 1] *= o_scale; + o_frag[mma_q][mma_d][j * 2 + 4] *= o_scale; + o_frag[mma_q][mma_d][j * 2 + 5] *= o_scale; + } +#pragma unroll + for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { + s_frag[mma_q][mma_kv][j * 2 + 0] = gpu_iface::math::ptx_exp2( + s_frag[mma_q][mma_kv][j * 2 + 0] * sm_scale - m[mma_q][j] * sm_scale); + s_frag[mma_q][mma_kv][j * 2 + 1] = gpu_iface::math::ptx_exp2( + s_frag[mma_q][mma_kv][j * 2 + 1] * sm_scale - m[mma_q][j] * sm_scale); + s_frag[mma_q][mma_kv][j * 2 + 4] = gpu_iface::math::ptx_exp2( + s_frag[mma_q][mma_kv][j * 2 + 4] * sm_scale - m[mma_q][j] * sm_scale); + s_frag[mma_q][mma_kv][j * 2 + 5] = gpu_iface::math::ptx_exp2( + s_frag[mma_q][mma_kv][j * 2 + 5] * sm_scale - m[mma_q][j] * sm_scale); + } #endif - } - } } - else if constexpr (std::is_same_v) { + } + } else if constexpr (std::is_same_v) { #if defined(PLATFORM_HIP_DEVICE) - static_assert( - false, - "Half precision accumulator not yet implemented for AMD"); + static_assert(false, "Half precision accumulator not yet implemented for AMD"); #else - const half2 sm_scale = __float2half2_rn(variant.sm_scale_log2); -#pragma unroll - for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { - half m_prev[2]; -#pragma unroll - for (uint32_t j = 0; j < 2; ++j) { - m_prev[j] = m[mma_q][j]; -#pragma unroll - for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; - ++mma_kv) - { - half2 m_local = gpu_iface::math::hmax2( - *(half2 *)&s_frag[mma_q][mma_kv][j * 2], - *(half2 *)&s_frag[mma_q][mma_kv][j * 2 + 4]); - m[mma_q][j] = - __hmax(m[mma_q][j], __hmax(m_local.x, 
m_local.y)); - } - } - *(half2 *)&m[mma_q] = gpu_iface::math::hmax2( - *(half2 *)&m[mma_q], - gpu_iface::math::shfl_xor_sync(*(half2 *)&m[mma_q], 0x2)); - *(half2 *)&m[mma_q] = gpu_iface::math::hmax2( - *(half2 *)&m[mma_q], - gpu_iface::math::shfl_xor_sync(*(half2 *)&m[mma_q], 0x1)); -#pragma unroll - for (uint32_t j = 0; j < 2; ++j) { - float o_scale = gpu_iface::math::ptx_exp2(float( - m_prev[j] * sm_scale.x - m[mma_q][j] * sm_scale.x)); - d[mma_q][j] *= o_scale; -#pragma unroll - for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; - ++mma_d) - { - o_frag[mma_q][mma_d][j * 2 + 0] *= o_scale; - o_frag[mma_q][mma_d][j * 2 + 1] *= o_scale; - o_frag[mma_q][mma_d][j * 2 + 4] *= o_scale; - o_frag[mma_q][mma_d][j * 2 + 5] *= o_scale; - } - half2 m2 = make_half2(m[mma_q][j], m[mma_q][j]); -#pragma unroll - for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; - ++mma_kv) - { - *(half2 *)&s_frag[mma_q][mma_kv][j * 2] = - gpu_iface::math::ptx_exp2( - *(half2 *)&s_frag[mma_q][mma_kv][j * 2] * - sm_scale - - m2 * sm_scale); - *(half2 *)&s_frag[mma_q][mma_kv][j * 2 + 4] = - gpu_iface::math::ptx_exp2( - *(half2 *)&s_frag[mma_q][mma_kv][j * 2 + 4] * - sm_scale - - m2 * sm_scale); - } - } - } -#endif + const half2 sm_scale = __float2half2_rn(variant.sm_scale_log2); +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { + half m_prev[2]; +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + m_prev[j] = m[mma_q][j]; +#pragma unroll + for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { + half2 m_local = gpu_iface::math::hmax2(*(half2*)&s_frag[mma_q][mma_kv][j * 2], + *(half2*)&s_frag[mma_q][mma_kv][j * 2 + 4]); + m[mma_q][j] = __hmax(m[mma_q][j], __hmax(m_local.x, m_local.y)); + } } + *(half2*)&m[mma_q] = gpu_iface::math::hmax2( + *(half2*)&m[mma_q], gpu_iface::math::shfl_xor_sync(*(half2*)&m[mma_q], 0x2)); + *(half2*)&m[mma_q] = gpu_iface::math::hmax2( + *(half2*)&m[mma_q], gpu_iface::math::shfl_xor_sync(*(half2*)&m[mma_q], 0x1)); 
+#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + float o_scale = + gpu_iface::math::ptx_exp2(float(m_prev[j] * sm_scale.x - m[mma_q][j] * sm_scale.x)); + d[mma_q][j] *= o_scale; +#pragma unroll + for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { + o_frag[mma_q][mma_d][j * 2 + 0] *= o_scale; + o_frag[mma_q][mma_d][j * 2 + 1] *= o_scale; + o_frag[mma_q][mma_d][j * 2 + 4] *= o_scale; + o_frag[mma_q][mma_d][j * 2 + 5] *= o_scale; + } + half2 m2 = make_half2(m[mma_q][j], m[mma_q][j]); +#pragma unroll + for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { + *(half2*)&s_frag[mma_q][mma_kv][j * 2] = gpu_iface::math::ptx_exp2( + *(half2*)&s_frag[mma_q][mma_kv][j * 2] * sm_scale - m2 * sm_scale); + *(half2*)&s_frag[mma_q][mma_kv][j * 2 + 4] = gpu_iface::math::ptx_exp2( + *(half2*)&s_frag[mma_q][mma_kv][j * 2 + 4] * sm_scale - m2 * sm_scale); + } + } + } +#endif } + } } template __device__ __forceinline__ void compute_sfm_v( - smem_t *v_smem, - uint32_t *v_smem_offset_r, - typename KTraits::DTypeQKAccum ( - *s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD], + smem_t* v_smem, + uint32_t* v_smem_offset_r, + typename KTraits::DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD], float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD], - float (*d)[KTraits::NUM_ACCUM_ROWS_PER_THREAD]) -{ - constexpr uint32_t UPCAST_STRIDE_V = KTraits::UPCAST_STRIDE_V; - constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; - constexpr uint32_t INT32_ELEMS_PER_THREAD = KTraits::INT32_ELEMS_PER_THREAD; - constexpr uint32_t V_SMEM_COLUMN_ADVANCE = - 16 / KTraits::HALF_ELEMS_PER_THREAD; - - typename KTraits::DTypeQ s_frag_f16[KTraits::NUM_MMA_Q][KTraits::NUM_MMA_KV] - [HALF_ELEMS_PER_THREAD]; - if constexpr (std::is_same_v) { -#pragma unroll - for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { -#pragma unroll - for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { - 
vec_cast::template cast< - HALF_ELEMS_PER_THREAD>(s_frag_f16[mma_q][mma_kv], - s_frag[mma_q][mma_kv]); - } - } + float (*d)[KTraits::NUM_ACCUM_ROWS_PER_THREAD]) { + constexpr uint32_t UPCAST_STRIDE_V = KTraits::UPCAST_STRIDE_V; + constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; + constexpr uint32_t INT32_ELEMS_PER_THREAD = KTraits::INT32_ELEMS_PER_THREAD; + constexpr uint32_t V_SMEM_COLUMN_ADVANCE = 16 / KTraits::HALF_ELEMS_PER_THREAD; + + typename KTraits::DTypeQ s_frag_f16[KTraits::NUM_MMA_Q][KTraits::NUM_MMA_KV] + [HALF_ELEMS_PER_THREAD]; + if constexpr (std::is_same_v) { +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { + vec_cast::template cast( + s_frag_f16[mma_q][mma_kv], s_frag[mma_q][mma_kv]); + } } + } - if constexpr (KTraits::AttentionVariant::use_softmax) { + if constexpr (KTraits::AttentionVariant::use_softmax) { #pragma unroll - for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll - for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { - if constexpr (std::is_same_v) - { - mma::m16k16_rowsum_f16f16f32(d[mma_q], - s_frag_f16[mma_q][mma_kv]); - } - else { + for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { + if constexpr (std::is_same_v) { + mma::m16k16_rowsum_f16f16f32(d[mma_q], s_frag_f16[mma_q][mma_kv]); + } else { #if defined(PLATFORM_HIP_DEVICE) - static_assert( - !std::is_same_v, + static_assert(!std::is_same_v, "FP16 reduction path not implemented for CDNA3"); #else - mma::m16k16_rowsum_f16f16f32(d[mma_q], - s_frag[mma_q][mma_kv]); + mma::m16k16_rowsum_f16f16f32(d[mma_q], s_frag[mma_q][mma_kv]); #endif - } - } } + } } + } #pragma unroll - for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { + for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { #pragma 
unroll - for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { - uint32_t b_frag[INT32_ELEMS_PER_THREAD]; - if constexpr (sizeof(typename KTraits::DTypeKV) == 1) { + for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { + uint32_t b_frag[INT32_ELEMS_PER_THREAD]; + if constexpr (sizeof(typename KTraits::DTypeKV) == 1) { #if defined(PLATFORM_HIP_DEVICE) - static_assert(false, - "FP8 V path not implemented for CDNA3 yet"); + static_assert(false, "FP8 V path not implemented for CDNA3 yet"); #else - uint32_t b_frag_f8[2]; - if (mma_d % 2 == 0) { - v_smem->ldmatrix_m8n8x4_trans_left_half(*v_smem_offset_r, - b_frag_f8); - } - else { - v_smem->ldmatrix_m8n8x4_trans_right_half(*v_smem_offset_r, - b_frag_f8); - } - b_frag_f8[0] = - frag_layout_swizzle_16b_to_8b_trans(b_frag_f8[0]); - b_frag_f8[1] = - frag_layout_swizzle_16b_to_8b_trans(b_frag_f8[1]); - vec_cast:: - template cast<8>((typename KTraits::DTypeQ *)b_frag, - (typename KTraits::DTypeKV *)b_frag_f8); - swap(b_frag[1], b_frag[2]); + uint32_t b_frag_f8[2]; + if (mma_d % 2 == 0) { + v_smem->ldmatrix_m8n8x4_trans_left_half(*v_smem_offset_r, b_frag_f8); + } else { + v_smem->ldmatrix_m8n8x4_trans_right_half(*v_smem_offset_r, b_frag_f8); + } + b_frag_f8[0] = frag_layout_swizzle_16b_to_8b_trans(b_frag_f8[0]); + b_frag_f8[1] = frag_layout_swizzle_16b_to_8b_trans(b_frag_f8[1]); + vec_cast::template cast<8>( + (typename KTraits::DTypeQ*)b_frag, (typename KTraits::DTypeKV*)b_frag_f8); + swap(b_frag[1], b_frag[2]); #endif - } - else { + } else { #if defined(PLATFORM_HIP_DEVICE) - v_smem->load_fragment_4x4_transposed(*v_smem_offset_r, b_frag); + v_smem->load_fragment_4x4_transposed(*v_smem_offset_r, b_frag); #else - v_smem->ldmatrix_m8n8x4_trans(*v_smem_offset_r, b_frag); + v_smem->ldmatrix_m8n8x4_trans(*v_smem_offset_r, b_frag); #endif - } -#pragma unroll - for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { - if constexpr (std::is_same_v) - { - mma::mma_sync_m16n16k16_row_col_f16f16f32< - 
typename KTraits::DTypeQ>( - o_frag[mma_q][mma_d], - (uint32_t *)s_frag_f16[mma_q][mma_kv], b_frag); - } - else { - mma::mma_sync_m16n16k16_row_col_f16f16f32< - typename KTraits::DTypeQ>( - o_frag[mma_q][mma_d], (uint32_t *)s_frag[mma_q][mma_kv], - b_frag); - } - } - if constexpr (sizeof(typename KTraits::DTypeKV) == 1) { - if (mma_d % 2 == 1) { - *v_smem_offset_r = - v_smem->template advance_offset_by_column< - V_SMEM_COLUMN_ADVANCE>(*v_smem_offset_r, mma_d / 2); - } - } - else { - *v_smem_offset_r = v_smem->template advance_offset_by_column< - V_SMEM_COLUMN_ADVANCE>(*v_smem_offset_r, mma_d); - } + } +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { + if constexpr (std::is_same_v) { + mma::mma_sync_m16n16k16_row_col_f16f16f32( + o_frag[mma_q][mma_d], (uint32_t*)s_frag_f16[mma_q][mma_kv], b_frag); + } else { + mma::mma_sync_m16n16k16_row_col_f16f16f32( + o_frag[mma_q][mma_d], (uint32_t*)s_frag[mma_q][mma_kv], b_frag); } - *v_smem_offset_r = - v_smem->template advance_offset_by_row<16, UPCAST_STRIDE_V>( - *v_smem_offset_r) - - sizeof(typename KTraits::DTypeKV) * KTraits::NUM_MMA_D_VO; + } + if constexpr (sizeof(typename KTraits::DTypeKV) == 1) { + if (mma_d % 2 == 1) { + *v_smem_offset_r = v_smem->template advance_offset_by_column( + *v_smem_offset_r, mma_d / 2); + } + } else { + *v_smem_offset_r = v_smem->template advance_offset_by_column( + *v_smem_offset_r, mma_d); + } } - *v_smem_offset_r -= 16 * KTraits::NUM_MMA_KV * UPCAST_STRIDE_V; + *v_smem_offset_r = + v_smem->template advance_offset_by_row<16, UPCAST_STRIDE_V>(*v_smem_offset_r) - + sizeof(typename KTraits::DTypeKV) * KTraits::NUM_MMA_D_VO; + } + *v_smem_offset_r -= 16 * KTraits::NUM_MMA_KV * UPCAST_STRIDE_V; } template __device__ __forceinline__ void normalize_d( float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD], typename KTraits::DTypeQKAccum (*m)[KTraits::NUM_ACCUM_ROWS_PER_THREAD], - float (*d)[KTraits::NUM_ACCUM_ROWS_PER_THREAD]) -{ - using 
AttentionVariant = typename KTraits::AttentionVariant; - constexpr uint32_t NAPTR = KTraits::NUM_ACCUM_ROWS_PER_THREAD; + float (*d)[KTraits::NUM_ACCUM_ROWS_PER_THREAD]) { + using AttentionVariant = typename KTraits::AttentionVariant; + constexpr uint32_t NAPTR = KTraits::NUM_ACCUM_ROWS_PER_THREAD; - if constexpr (AttentionVariant::use_softmax) { - float d_rcp[KTraits::NUM_MMA_Q][KTraits::NUM_ACCUM_ROWS_PER_THREAD]; - // compute reciprocal of d + if constexpr (AttentionVariant::use_softmax) { + float d_rcp[KTraits::NUM_MMA_Q][KTraits::NUM_ACCUM_ROWS_PER_THREAD]; + // compute reciprocal of d #pragma unroll - for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll - for (uint32_t j = 0; j < KTraits::NUM_ACCUM_ROWS_PER_THREAD; ++j) { - d_rcp[mma_q][j] = - (m[mma_q][j] != - typename KTraits::DTypeQKAccum(-gpu_iface::math::inf)) - ? gpu_iface::math::ptx_rcp(d[mma_q][j]) - : 0.f; - } - } + for (uint32_t j = 0; j < KTraits::NUM_ACCUM_ROWS_PER_THREAD; ++j) { + d_rcp[mma_q][j] = (m[mma_q][j] != typename KTraits::DTypeQKAccum(-gpu_iface::math::inf)) + ? 
gpu_iface::math::ptx_rcp(d[mma_q][j]) + : 0.f; + } + } #pragma unroll - for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll - for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { + for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { #pragma unroll - for (uint32_t reg_id = 0; - reg_id < KTraits::HALF_ELEMS_PER_THREAD; ++reg_id) - { + for (uint32_t reg_id = 0; reg_id < KTraits::HALF_ELEMS_PER_THREAD; ++reg_id) { #if defined(PLATFORM_HIP_DEVICE) - o_frag[mma_q][mma_d][reg_id] = - o_frag[mma_q][mma_d][reg_id] * - d_rcp[mma_q][reg_id % NAPTR]; + o_frag[mma_q][mma_d][reg_id] = + o_frag[mma_q][mma_d][reg_id] * d_rcp[mma_q][reg_id % NAPTR]; #else - o_frag[mma_q][mma_d][reg_id] = - o_frag[mma_q][mma_d][reg_id] * - d_rcp[mma_q][(reg_id % 4) / 2]; + o_frag[mma_q][mma_d][reg_id] = + o_frag[mma_q][mma_d][reg_id] * d_rcp[mma_q][(reg_id % 4) / 2]; #endif - } - } } + } } + } } template __device__ __forceinline__ void finalize_m( typename KTraits::AttentionVariant variant, - typename KTraits::DTypeQKAccum (*m)[KTraits::NUM_ACCUM_ROWS_PER_THREAD]) -{ - if constexpr (variant.use_softmax) { + typename KTraits::DTypeQKAccum (*m)[KTraits::NUM_ACCUM_ROWS_PER_THREAD]) { + if constexpr (variant.use_softmax) { #pragma unroll - for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll - for (uint32_t j = 0; j < KTraits::NUM_ACCUM_ROWS_PER_THREAD; ++j) { - if (m[mma_q][j] != - typename KTraits::DTypeQKAccum(-gpu_iface::math::inf)) - { - m[mma_q][j] *= variant.sm_scale_log2; - } - } + for (uint32_t j = 0; j < KTraits::NUM_ACCUM_ROWS_PER_THREAD; ++j) { + if (m[mma_q][j] != typename KTraits::DTypeQKAccum(-gpu_iface::math::inf)) { + m[mma_q][j] *= variant.sm_scale_log2; } + } } + } } /*! 
@@ -1692,354 +1330,285 @@ __device__ __forceinline__ void finalize_m( template __device__ __forceinline__ void threadblock_sync_mdo_states( float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD], - typename KTraits::SharedStorage *smem_storage, + typename KTraits::SharedStorage* smem_storage, typename KTraits::DTypeQKAccum (*m)[KTraits::NUM_ACCUM_ROWS_PER_THREAD], - float (*d)[KTraits::NUM_ACCUM_ROWS_PER_THREAD], - const uint32_t warp_idx, - const uint32_t lane_idx, - const dim3 tid = threadIdx) -{ - constexpr uint32_t TPR = KTraits::THREADS_PER_BMATRIX_ROW_SET; - constexpr uint32_t NARPT = KTraits::NUM_ACCUM_ROWS_PER_THREAD; - - static_assert(WARP_SIZE % TPR == 0, - "THREADS_PER_BMATRIX_ROW_SET must divide WARP_SIZE"); - constexpr uint32_t GROUPS_PER_WARP = WARP_SIZE / TPR; - const uint32_t lane_group_idx = lane_idx / TPR; - - // only necessary when blockDim.z > 1 - if constexpr (KTraits::NUM_WARPS_KV > 1) { - float *smem_o = smem_storage->cta_sync_o_smem; - float2 *smem_md = smem_storage->cta_sync_md_smem; - // o: [num_warps, - // NUM_MMA_Q, - // NUM_MMA_D_VO, - // WARP_SIZE, - // HALF_ELEMS_PER_THREAD] - // md: [num_warps, NUM_MMA_Q, 16, 2 (m/d)] -#pragma unroll - for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { -#pragma unroll - for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { - vec_t::memcpy( - smem_o + (((warp_idx * KTraits::NUM_MMA_Q + mma_q) * - KTraits::NUM_MMA_D_VO + - mma_d) * - WARP_SIZE + - lane_idx) * - KTraits::HALF_ELEMS_PER_THREAD, - o_frag[mma_q][mma_d]); - } - } - - if constexpr (KTraits::AttentionVariant::use_softmax) { + float (*d)[KTraits::NUM_ACCUM_ROWS_PER_THREAD], const uint32_t warp_idx, + const uint32_t lane_idx, const dim3 tid = threadIdx) { + constexpr uint32_t TPR = KTraits::THREADS_PER_BMATRIX_ROW_SET; + constexpr uint32_t NARPT = KTraits::NUM_ACCUM_ROWS_PER_THREAD; + + static_assert(WARP_SIZE % TPR == 0, "THREADS_PER_BMATRIX_ROW_SET must divide WARP_SIZE"); + constexpr uint32_t 
GROUPS_PER_WARP = WARP_SIZE / TPR; + const uint32_t lane_group_idx = lane_idx / TPR; + + // only necessary when blockDim.z > 1 + if constexpr (KTraits::NUM_WARPS_KV > 1) { + float* smem_o = smem_storage->cta_sync_o_smem; + float2* smem_md = smem_storage->cta_sync_md_smem; + // o: [num_warps, + // NUM_MMA_Q, + // NUM_MMA_D_VO, + // WARP_SIZE, + // HALF_ELEMS_PER_THREAD] + // md: [num_warps, NUM_MMA_Q, 16, 2 (m/d)] #pragma unroll - for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll - for (uint32_t j = 0; j < NARPT; ++j) { - smem_md[((warp_idx * KTraits::NUM_MMA_Q + mma_q) * NARPT + - j) * - GROUPS_PER_WARP + - lane_group_idx] = - make_float2(float(m[mma_q][j]), d[mma_q][j]); - } - } - - // synchronize m,d first - __syncthreads(); -#pragma unroll - for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { - float o_scale[NARPT][KTraits::NUM_WARPS_KV]; -#pragma unroll - for (uint32_t j = 0; j < NARPT; ++j) { - float m_new = -gpu_iface::math::inf, d_new = 1.f; -#pragma unroll - for (uint32_t i = 0; i < KTraits::NUM_WARPS_KV; ++i) { - float2 md = smem_md[(((i * KTraits::NUM_WARPS_Q + - get_warp_idx_q(tid.y)) * - KTraits::NUM_MMA_Q + - mma_q) * - NARPT + - j) * - GROUPS_PER_WARP + - lane_group_idx]; - float m_prev = m_new, d_prev = d_new; - m_new = max(m_new, md.x); - d_new = - d_prev * gpu_iface::math::ptx_exp2(m_prev - m_new) + - md.y * gpu_iface::math::ptx_exp2(md.x - m_new); - } + for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { + vec_t::memcpy( + smem_o + (((warp_idx * KTraits::NUM_MMA_Q + mma_q) * KTraits::NUM_MMA_D_VO + mma_d) * + WARP_SIZE + + lane_idx) * + KTraits::HALF_ELEMS_PER_THREAD, + o_frag[mma_q][mma_d]); + } + } + if constexpr (KTraits::AttentionVariant::use_softmax) { #pragma unroll - for (uint32_t i = 0; i < KTraits::NUM_WARPS_KV; ++i) { - float2 md = smem_md[(((i * KTraits::NUM_WARPS_Q + - get_warp_idx_q(tid.y)) * - KTraits::NUM_MMA_Q + 
- mma_q) * - NARPT + - j) * - GROUPS_PER_WARP + - lane_group_idx]; - float mi = md.x; - o_scale[j][i] = - gpu_iface::math::ptx_exp2(float(mi - m_new)); - } - m[mma_q][j] = typename KTraits::DTypeQKAccum(m_new); - d[mma_q][j] = d_new; - } - + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll - for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) - { - vec_t o_new; - o_new.fill(0.f); + for (uint32_t j = 0; j < NARPT; ++j) { + smem_md[((warp_idx * KTraits::NUM_MMA_Q + mma_q) * NARPT + j) * GROUPS_PER_WARP + + lane_group_idx] = make_float2(float(m[mma_q][j]), d[mma_q][j]); + } + } + + // synchronize m,d first + __syncthreads(); +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { + float o_scale[NARPT][KTraits::NUM_WARPS_KV]; +#pragma unroll + for (uint32_t j = 0; j < NARPT; ++j) { + float m_new = -gpu_iface::math::inf, d_new = 1.f; +#pragma unroll + for (uint32_t i = 0; i < KTraits::NUM_WARPS_KV; ++i) { + float2 md = smem_md[(((i * KTraits::NUM_WARPS_Q + get_warp_idx_q(tid.y)) * + KTraits::NUM_MMA_Q + + mma_q) * + NARPT + + j) * + GROUPS_PER_WARP + + lane_group_idx]; + float m_prev = m_new, d_prev = d_new; + m_new = max(m_new, md.x); + d_new = d_prev * gpu_iface::math::ptx_exp2(m_prev - m_new) + + md.y * gpu_iface::math::ptx_exp2(md.x - m_new); + } + +#pragma unroll + for (uint32_t i = 0; i < KTraits::NUM_WARPS_KV; ++i) { + float2 md = smem_md[(((i * KTraits::NUM_WARPS_Q + get_warp_idx_q(tid.y)) * + KTraits::NUM_MMA_Q + + mma_q) * + NARPT + + j) * + GROUPS_PER_WARP + + lane_group_idx]; + float mi = md.x; + o_scale[j][i] = gpu_iface::math::ptx_exp2(float(mi - m_new)); + } + m[mma_q][j] = typename KTraits::DTypeQKAccum(m_new); + d[mma_q][j] = d_new; + } + #pragma unroll - for (uint32_t i = 0; i < KTraits::NUM_WARPS_KV; ++i) { - vec_t oi; - oi.load(smem_o + ((((i * KTraits::NUM_WARPS_Q + - get_warp_idx_q(tid.y)) * - KTraits::NUM_MMA_Q + - mma_q) * - KTraits::NUM_MMA_D_VO + - mma_d) * - WARP_SIZE + - 
lane_idx) * - KTraits::HALF_ELEMS_PER_THREAD); + for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { + vec_t o_new; + o_new.fill(0.f); +#pragma unroll + for (uint32_t i = 0; i < KTraits::NUM_WARPS_KV; ++i) { + vec_t oi; + oi.load(smem_o + ((((i * KTraits::NUM_WARPS_Q + get_warp_idx_q(tid.y)) * + KTraits::NUM_MMA_Q + + mma_q) * + KTraits::NUM_MMA_D_VO + + mma_d) * + WARP_SIZE + + lane_idx) * + KTraits::HALF_ELEMS_PER_THREAD); #pragma unroll - for (uint32_t reg_id = 0; - reg_id < KTraits::HALF_ELEMS_PER_THREAD; ++reg_id) - { + for (uint32_t reg_id = 0; reg_id < KTraits::HALF_ELEMS_PER_THREAD; ++reg_id) { #if defined(PLATFORM_HIP_DEVICE) - // CDNA3: Direct mapping - each reg_id corresponds - // to one accumulator row - o_new[reg_id] += oi[reg_id] * o_scale[reg_id][i]; + // CDNA3: Direct mapping - each reg_id corresponds + // to one accumulator row + o_new[reg_id] += oi[reg_id] * o_scale[reg_id][i]; #else - // CUDA: Grouped mapping - 2 elements per - // accumulator row - o_new[reg_id] += - oi[reg_id] * o_scale[(reg_id % 4) / 2][i]; + // CUDA: Grouped mapping - 2 elements per + // accumulator row + o_new[reg_id] += oi[reg_id] * o_scale[(reg_id % 4) / 2][i]; #endif - } - } - o_new.store(o_frag[mma_q][mma_d]); - } } + } + o_new.store(o_frag[mma_q][mma_d]); } - else { - // synchronize m,d first - __syncthreads(); -#pragma unroll - for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { -#pragma unroll - for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) - { - vec_t o_new; - o_new.fill(0.f); -#pragma unroll - for (uint32_t i = 0; i < KTraits::NUM_WARPS_KV; ++i) { - vec_t oi; - oi.load(smem_o + ((((i * KTraits::NUM_WARPS_Q + - get_warp_idx_q(tid.y)) * - KTraits::NUM_MMA_Q + - mma_q) * - KTraits::NUM_MMA_D_VO + - mma_d) * - WARP_SIZE + - lane_idx) * - KTraits::HALF_ELEMS_PER_THREAD); -#pragma unroll - for (uint32_t reg_id = 0; - reg_id < KTraits::HALF_ELEMS_PER_THREAD; ++reg_id) - { - o_new[reg_id] += oi[reg_id]; - } - } - 
o_new.store(o_frag[mma_q][mma_d]); - } + } + } else { + // synchronize m,d first + __syncthreads(); +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { + vec_t o_new; + o_new.fill(0.f); +#pragma unroll + for (uint32_t i = 0; i < KTraits::NUM_WARPS_KV; ++i) { + vec_t oi; + oi.load(smem_o + ((((i * KTraits::NUM_WARPS_Q + get_warp_idx_q(tid.y)) * + KTraits::NUM_MMA_Q + + mma_q) * + KTraits::NUM_MMA_D_VO + + mma_d) * + WARP_SIZE + + lane_idx) * + KTraits::HALF_ELEMS_PER_THREAD); +#pragma unroll + for (uint32_t reg_id = 0; reg_id < KTraits::HALF_ELEMS_PER_THREAD; ++reg_id) { + o_new[reg_id] += oi[reg_id]; } + } + o_new.store(o_frag[mma_q][mma_d]); } + } } + } } template __device__ __forceinline__ void write_o_reg_gmem( float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD], - smem_t *o_smem, - typename KTraits::DTypeO *o_ptr_base, - const uint32_t o_packed_idx_base, - const uint32_t qo_upper_bound, - const uint32_t o_stride_n, - const uint32_t o_stride_h, - const uint_fastdiv group_size, - const dim3 tid = threadIdx) -{ - using DTypeO = typename KTraits::DTypeO; - constexpr uint32_t UPCAST_STRIDE_O = KTraits::UPCAST_STRIDE_O; - constexpr uint32_t TPR = KTraits::THREADS_PER_BMATRIX_ROW_SET; - constexpr uint32_t NAPTR = KTraits::NUM_ACCUM_ROWS_PER_THREAD; - constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; - constexpr uint32_t WARP_THREAD_COLS = KTraits::WARP_THREAD_COLS; - constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH; - - const uint32_t warp_idx_x = get_warp_idx_q(tid.y); - const uint32_t lane_idx = tid.x; - - if constexpr (sizeof(DTypeO) == 4) { + smem_t* o_smem, + typename KTraits::DTypeO* o_ptr_base, const uint32_t o_packed_idx_base, + const uint32_t qo_upper_bound, const uint32_t o_stride_n, const uint32_t o_stride_h, + const uint_fastdiv group_size, const dim3 tid = threadIdx) { + using DTypeO 
= typename KTraits::DTypeO; + constexpr uint32_t UPCAST_STRIDE_O = KTraits::UPCAST_STRIDE_O; + constexpr uint32_t TPR = KTraits::THREADS_PER_BMATRIX_ROW_SET; + constexpr uint32_t NAPTR = KTraits::NUM_ACCUM_ROWS_PER_THREAD; + constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; + constexpr uint32_t WARP_THREAD_COLS = KTraits::WARP_THREAD_COLS; + constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH; + + const uint32_t warp_idx_x = get_warp_idx_q(tid.y); + const uint32_t lane_idx = tid.x; + + if constexpr (sizeof(DTypeO) == 4) { #pragma unroll - for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll - for (uint32_t j = 0; j < NAPTR; ++j) { - uint32_t q, r; - group_size.divmod(o_packed_idx_base + lane_idx / TPR + - mma_q * 16 + j * 8, - q, r); - const uint32_t o_idx = q; + for (uint32_t j = 0; j < NAPTR; ++j) { + uint32_t q, r; + group_size.divmod(o_packed_idx_base + lane_idx / TPR + mma_q * 16 + j * 8, q, r); + const uint32_t o_idx = q; #pragma unroll - for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) - { - if (o_idx < qo_upper_bound) { - auto base_addr = o_ptr_base + q * o_stride_n + - r * o_stride_h + mma_d * 16; - auto col_offset = lane_idx % 16; + for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { + if (o_idx < qo_upper_bound) { + auto base_addr = o_ptr_base + q * o_stride_n + r * o_stride_h + mma_d * 16; + auto col_offset = lane_idx % 16; #if defined(PLATFORM_HIP_DEVICE) - *(base_addr + col_offset) = o_frag[mma_q][mma_d][j]; + *(base_addr + col_offset) = o_frag[mma_q][mma_d][j]; #else - *reinterpret_cast(base_addr + - col_offset * 2) = - *reinterpret_cast( - &o_frag[mma_q][mma_d][j * 2]); - - *reinterpret_cast(base_addr + 8 + - col_offset * 2) = - *reinterpret_cast( - &o_frag[mma_q][mma_d][$ + j * 2]); + *reinterpret_cast(base_addr + col_offset * 2) = + *reinterpret_cast(&o_frag[mma_q][mma_d][j * 2]); + + 
*reinterpret_cast(base_addr + 8 + col_offset * 2) = + *reinterpret_cast(&o_frag[mma_q][mma_d][$ + j * 2]); #endif - } - } - } + } } + } } - else { - if (get_warp_idx_kv(tid.z) == 0) { + } else { + if (get_warp_idx_kv(tid.z) == 0) { #pragma unroll - for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { #pragma unroll - for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) - { - uint32_t o_frag_f16[HALF_ELEMS_PER_THREAD / 2]; - vec_cast::template cast< - HALF_ELEMS_PER_THREAD>((DTypeO *)o_frag_f16, - o_frag[mma_q][mma_d]); + for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { + uint32_t o_frag_f16[HALF_ELEMS_PER_THREAD / 2]; + vec_cast::template cast((DTypeO*)o_frag_f16, + o_frag[mma_q][mma_d]); #ifdef FLASHINFER_STMATRIX_M8N8X4_ENABLED - uint32_t o_smem_offset_w = - o_smem->template get_permuted_offset( - (warp_idx_x * KTraits::NUM_MMA_Q + mma_q) * 16 + - lane_idx % 16, - mma_d * 2 + lane_idx / 16); - o_smem->stmatrix_m8n8x4(o_smem_offset_w, o_frag_f16); + uint32_t o_smem_offset_w = o_smem->template get_permuted_offset( + (warp_idx_x * KTraits::NUM_MMA_Q + mma_q) * 16 + lane_idx % 16, + mma_d * 2 + lane_idx / 16); + o_smem->stmatrix_m8n8x4(o_smem_offset_w, o_frag_f16); #else - uint32_t o_smem_offset_w = - o_smem->template get_permuted_offset( - (warp_idx_x * KTraits::NUM_MMA_Q + mma_q) * 16 + - lane_idx / TPR, - mma_d * 2); + uint32_t o_smem_offset_w = o_smem->template get_permuted_offset( + (warp_idx_x * KTraits::NUM_MMA_Q + mma_q) * 16 + lane_idx / TPR, mma_d * 2); #if defined(PLATFORM_HIP_DEVICE) - ((uint32_t *)(o_smem->base + - o_smem_offset_w))[lane_idx % TPR] = - o_frag_f16[0]; - // Move 2 elements forward in the same row - uint32_t offset_2 = o_smem_offset_w + 2; - ((uint32_t *)(o_smem->base + offset_2))[lane_idx % 16] = - o_frag_f16[1]; + ((uint32_t*)(o_smem->base + o_smem_offset_w))[lane_idx % TPR] = o_frag_f16[0]; + // Move 2 elements forward in the same 
row + uint32_t offset_2 = o_smem_offset_w + 2; + ((uint32_t*)(o_smem->base + offset_2))[lane_idx % 16] = o_frag_f16[1]; #else - ((uint32_t *)(o_smem->base + - o_smem_offset_w))[lane_idx % TPR] = - o_frag_f16[0]; - ((uint32_t *)(o_smem->base + o_smem_offset_w + - 8 * UPCAST_STRIDE_O))[lane_idx % 4] = - o_frag_f16[1]; - ((uint32_t *)(o_smem->base + - (o_smem_offset_w ^ 0x1)))[lane_idx % TPR] = - o_frag_f16[2]; - ((uint32_t *)(o_smem->base + (o_smem_offset_w ^ 0x1) + - 8 * UPCAST_STRIDE_O))[lane_idx % 4] = - o_frag_f16[3]; + ((uint32_t*)(o_smem->base + o_smem_offset_w))[lane_idx % TPR] = o_frag_f16[0]; + ((uint32_t*)(o_smem->base + o_smem_offset_w + 8 * UPCAST_STRIDE_O))[lane_idx % 4] = + o_frag_f16[1]; + ((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1)))[lane_idx % TPR] = o_frag_f16[2]; + ((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1) + + 8 * UPCAST_STRIDE_O))[lane_idx % 4] = o_frag_f16[3]; #endif #endif - } - } + } + } - uint32_t o_smem_offset_w = - o_smem->template get_permuted_offset( - warp_idx_x * KTraits::NUM_MMA_Q * 16 + - lane_idx / WARP_THREAD_COLS, - lane_idx % WARP_THREAD_COLS); - -#pragma unroll - for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { -#pragma unroll - for (uint32_t j = 0; j < 2 * 2; ++j) { - uint32_t q, r; - group_size.divmod(o_packed_idx_base + - lane_idx / WARP_THREAD_COLS + - mma_q * 16 + j * 4, - q, r); - const uint32_t o_idx = q; - DTypeO *o_ptr = o_ptr_base + q * o_stride_n + - r * o_stride_h + - (lane_idx % WARP_THREAD_COLS) * - upcast_size(); -#pragma unroll - for (uint32_t mma_do = 0; - mma_do < KTraits::NUM_MMA_D_VO / 4; ++mma_do) - { - if (o_idx < qo_upper_bound) { - o_smem->store_vector(o_smem_offset_w, o_ptr); - } - o_ptr += WARP_THREAD_COLS * - upcast_size(); - o_smem_offset_w = - o_smem->template advance_offset_by_column< - WARP_THREAD_COLS>(o_smem_offset_w, mma_do); - } - o_smem_offset_w = o_smem->template advance_offset_by_row< - 4, UPCAST_STRIDE_O>(o_smem_offset_w) - - 2 * KTraits::NUM_MMA_D_VO; - } + 
uint32_t o_smem_offset_w = o_smem->template get_permuted_offset( + warp_idx_x * KTraits::NUM_MMA_Q * 16 + lane_idx / WARP_THREAD_COLS, + lane_idx % WARP_THREAD_COLS); + +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t j = 0; j < 2 * 2; ++j) { + uint32_t q, r; + group_size.divmod(o_packed_idx_base + lane_idx / WARP_THREAD_COLS + mma_q * 16 + j * 4, q, + r); + const uint32_t o_idx = q; + DTypeO* o_ptr = o_ptr_base + q * o_stride_n + r * o_stride_h + + (lane_idx % WARP_THREAD_COLS) * upcast_size(); +#pragma unroll + for (uint32_t mma_do = 0; mma_do < KTraits::NUM_MMA_D_VO / 4; ++mma_do) { + if (o_idx < qo_upper_bound) { + o_smem->store_vector(o_smem_offset_w, o_ptr); } + o_ptr += WARP_THREAD_COLS * upcast_size(); + o_smem_offset_w = o_smem->template advance_offset_by_column( + o_smem_offset_w, mma_do); + } + o_smem_offset_w = + o_smem->template advance_offset_by_row<4, UPCAST_STRIDE_O>(o_smem_offset_w) - + 2 * KTraits::NUM_MMA_D_VO; } + } } + } } -} // namespace +} // namespace template __device__ __forceinline__ void debug_write_sfrag_to_scratch( - typename KTraits::DTypeQKAccum ( - *s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD], - const dim3 tid = threadIdx, - uint32_t debug_thread_id = 0, - uint32_t debug_warp_id = 0) -{ - using DTypeQKAccum = typename KTraits::DTypeQKAccum; - constexpr uint32_t NUM_MMA_Q = KTraits::NUM_MMA_Q; - constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; - const uint32_t warp_idx = get_warp_idx(tid.y, tid.z), - lane_idx = tid.x; - - // Write all thread's fragments to shared memory - for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { - for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) { - if (lane_idx == debug_thread_id && warp_idx == debug_warp_id) { - printf("%.6f %.6f %.6f %.6f\n", s_frag[mma_q][mma_kv][0], - s_frag[mma_q][mma_kv][1], s_frag[mma_q][mma_kv][2], - s_frag[mma_q][mma_kv][3]); - } - } + typename KTraits::DTypeQKAccum 
(*s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD], + const dim3 tid = threadIdx, uint32_t debug_thread_id = 0, uint32_t debug_warp_id = 0) { + using DTypeQKAccum = typename KTraits::DTypeQKAccum; + constexpr uint32_t NUM_MMA_Q = KTraits::NUM_MMA_Q; + constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; + const uint32_t warp_idx = get_warp_idx(tid.y, tid.z), lane_idx = tid.x; + + // Write all thread's fragments to shared memory + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { + for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) { + if (lane_idx == debug_thread_id && warp_idx == debug_warp_id) { + printf("%.6f %.6f %.6f %.6f\n", s_frag[mma_q][mma_kv][0], s_frag[mma_q][mma_kv][1], + s_frag[mma_q][mma_kv][2], s_frag[mma_q][mma_kv][3]); + } } - __syncthreads(); + } + __syncthreads(); } /*! @@ -2066,243 +1635,184 @@ __device__ __forceinline__ void debug_write_sfrag_to_scratch( * used in RoPE. */ template -__device__ __forceinline__ void -SinglePrefillWithKVCacheDevice(const Params params, - typename KTraits::SharedStorage &smem_storage, - const dim3 tid = threadIdx, - const uint32_t bx = blockIdx.x, - const uint32_t chunk_idx = blockIdx.y, - const uint32_t kv_head_idx = blockIdx.z, - const uint32_t num_chunks = gridDim.y, - const uint32_t num_kv_heads = gridDim.z) -{ - using DTypeQ = typename Params::DTypeQ; +__device__ __forceinline__ void SinglePrefillWithKVCacheDevice( + const Params params, typename KTraits::SharedStorage& smem_storage, const dim3 tid = threadIdx, + const uint32_t bx = blockIdx.x, const uint32_t chunk_idx = blockIdx.y, + const uint32_t kv_head_idx = blockIdx.z, const uint32_t num_chunks = gridDim.y, + const uint32_t num_kv_heads = gridDim.z) { + using DTypeQ = typename Params::DTypeQ; #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) - if constexpr (std::is_same_v) { - FLASHINFER_RUNTIME_ASSERT( - "Prefill kernels do not support bf16 on sm75."); - } - else { + if constexpr (std::is_same_v) { + 
FLASHINFER_RUNTIME_ASSERT("Prefill kernels do not support bf16 on sm75."); + } else { #endif - using DTypeKV = typename Params::DTypeKV; - using DTypeO = typename Params::DTypeO; - using DTypeQKAccum = typename KTraits::DTypeQKAccum; - using AttentionVariant = typename KTraits::AttentionVariant; - [[maybe_unused]] constexpr uint32_t NUM_MMA_Q = KTraits::NUM_MMA_Q; - [[maybe_unused]] constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; - [[maybe_unused]] constexpr uint32_t NUM_MMA_D_QK = - KTraits::NUM_MMA_D_QK; - [[maybe_unused]] constexpr uint32_t NUM_MMA_D_VO = - KTraits::NUM_MMA_D_VO; - [[maybe_unused]] constexpr uint32_t HEAD_DIM_QK = KTraits::HEAD_DIM_QK; - [[maybe_unused]] constexpr uint32_t HEAD_DIM_VO = KTraits::HEAD_DIM_VO; - [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_Q = - KTraits::UPCAST_STRIDE_Q; - [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_K = - KTraits::UPCAST_STRIDE_K; - [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_V = - KTraits::UPCAST_STRIDE_V; - [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_O = - KTraits::UPCAST_STRIDE_O; - [[maybe_unused]] constexpr uint32_t CTA_TILE_Q = KTraits::CTA_TILE_Q; - [[maybe_unused]] constexpr uint32_t CTA_TILE_KV = KTraits::CTA_TILE_KV; - [[maybe_unused]] constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q; - [[maybe_unused]] constexpr uint32_t NUM_WARPS_KV = - KTraits::NUM_WARPS_KV; - [[maybe_unused]] constexpr SwizzleMode SWIZZLE_MODE_Q = - KTraits::SWIZZLE_MODE_Q; - [[maybe_unused]] constexpr SwizzleMode SWIZZLE_MODE_KV = - KTraits::SWIZZLE_MODE_KV; - [[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_ROW = - KTraits::KV_THR_LAYOUT_ROW; - [[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_COL = - KTraits::KV_THR_LAYOUT_COL; - [[maybe_unused]] constexpr MaskMode MASK_MODE = KTraits::MASK_MODE; - [[maybe_unused]] constexpr uint32_t HALF_ELEMS_PER_THREAD = - KTraits::HALF_ELEMS_PER_THREAD; - [[maybe_unused]] constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD = - KTraits::NUM_ACCUM_ROWS_PER_THREAD; - 
[[maybe_unused]] constexpr uint32_t LOGITS_INDEX_STRIDE = - KTraits::LOGITS_INDEX_STRIDE; - [[maybe_unused]] constexpr uint32_t THREADS_PER_BMATRIX_ROW_SET = - KTraits::THREADS_PER_BMATRIX_ROW_SET; - [[maybe_unused]] constexpr uint32_t VECTOR_BIT_WIDTH = - KTraits::VECTOR_BIT_WIDTH; - - DTypeQ *q = params.q; - DTypeKV *k = params.k; - DTypeKV *v = params.v; - DTypeO *o = params.o; - float *lse = params.lse; - const uint32_t qo_len = params.qo_len; - const uint32_t kv_len = params.kv_len; - const bool partition_kv = params.partition_kv; - const uint32_t q_stride_n = params.q_stride_n; - const uint32_t q_stride_h = params.q_stride_h; - const uint32_t k_stride_n = params.k_stride_n; - const uint32_t k_stride_h = params.k_stride_h; - const uint32_t v_stride_n = params.v_stride_n; - const uint32_t v_stride_h = params.v_stride_h; - const uint_fastdiv &group_size = params.group_size; - - static_assert(sizeof(DTypeQ) == 2); - const uint32_t lane_idx = tid.x, - warp_idx = get_warp_idx(tid.y, tid.z); - const uint32_t num_qo_heads = num_kv_heads * group_size; - - const uint32_t max_chunk_size = - partition_kv ? ceil_div(kv_len, num_chunks) : kv_len; - const uint32_t chunk_start = - partition_kv ? chunk_idx * max_chunk_size : 0; - const uint32_t chunk_end = - partition_kv ? 
min((chunk_idx + 1) * max_chunk_size, kv_len) - : kv_len; - const uint32_t chunk_size = chunk_end - chunk_start; - - auto block = cg::this_thread_block(); - auto smem = reinterpret_cast(&smem_storage); - AttentionVariant variant(params, /*batch_idx=*/0, smem); - const uint32_t window_left = variant.window_left; - - DTypeQKAccum s_frag[NUM_MMA_Q][NUM_MMA_KV][HALF_ELEMS_PER_THREAD]; - alignas( - 16) float o_frag[NUM_MMA_Q][NUM_MMA_D_VO][HALF_ELEMS_PER_THREAD]; - DTypeQKAccum m[NUM_MMA_Q][NUM_ACCUM_ROWS_PER_THREAD]; - float d[NUM_MMA_Q][NUM_ACCUM_ROWS_PER_THREAD]; - float rope_freq[NUM_MMA_D_QK / 2][4]; - if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) - { - const float rope_rcp_scale = params.rope_rcp_scale; - const float rope_rcp_theta = params.rope_rcp_theta; - init_rope_freq(rope_freq, rope_rcp_scale, rope_rcp_theta, - tid.x); - } - init_states(variant, o_frag, m, d); - - // cooperative fetch q fragment from gmem to reg - const uint32_t qo_packed_idx_base = - (bx * NUM_WARPS_Q + get_warp_idx_q(tid.y)) * NUM_MMA_Q * - 16; - smem_t qo_smem( - smem_storage.q_smem); - const uint32_t o_stride_n = num_qo_heads * HEAD_DIM_VO, - o_stride_h = HEAD_DIM_VO; - DTypeQ *q_ptr_base = q + (kv_head_idx * group_size) * q_stride_h; - DTypeO *o_ptr_base = partition_kv - ? 
o + chunk_idx * o_stride_n + - (kv_head_idx * group_size) * o_stride_h - : o + (kv_head_idx * group_size) * o_stride_h; - - load_q_global_smem(qo_packed_idx_base, qo_len, q_ptr_base, - q_stride_n, q_stride_h, group_size, - &qo_smem, tid); - - uint32_t q_smem_offset_r = - qo_smem.template get_permuted_offset( - get_warp_idx_q(tid.y) * NUM_MMA_Q * 16 + lane_idx % 16, - lane_idx / 16); - - memory::commit_group(); - if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) - { - memory::wait_group<0>(); - block.sync(); - q_smem_inplace_apply_rotary( - qo_packed_idx_base, qo_len, kv_len, group_size, &qo_smem, - &q_smem_offset_r, rope_freq, tid); - block.sync(); - } + using DTypeKV = typename Params::DTypeKV; + using DTypeO = typename Params::DTypeO; + using DTypeQKAccum = typename KTraits::DTypeQKAccum; + using AttentionVariant = typename KTraits::AttentionVariant; + [[maybe_unused]] constexpr uint32_t NUM_MMA_Q = KTraits::NUM_MMA_Q; + [[maybe_unused]] constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; + [[maybe_unused]] constexpr uint32_t NUM_MMA_D_QK = KTraits::NUM_MMA_D_QK; + [[maybe_unused]] constexpr uint32_t NUM_MMA_D_VO = KTraits::NUM_MMA_D_VO; + [[maybe_unused]] constexpr uint32_t HEAD_DIM_QK = KTraits::HEAD_DIM_QK; + [[maybe_unused]] constexpr uint32_t HEAD_DIM_VO = KTraits::HEAD_DIM_VO; + [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q; + [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_K = KTraits::UPCAST_STRIDE_K; + [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_V = KTraits::UPCAST_STRIDE_V; + [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_O = KTraits::UPCAST_STRIDE_O; + [[maybe_unused]] constexpr uint32_t CTA_TILE_Q = KTraits::CTA_TILE_Q; + [[maybe_unused]] constexpr uint32_t CTA_TILE_KV = KTraits::CTA_TILE_KV; + [[maybe_unused]] constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q; + [[maybe_unused]] constexpr uint32_t NUM_WARPS_KV = KTraits::NUM_WARPS_KV; + [[maybe_unused]] constexpr SwizzleMode 
SWIZZLE_MODE_Q = KTraits::SWIZZLE_MODE_Q; + [[maybe_unused]] constexpr SwizzleMode SWIZZLE_MODE_KV = KTraits::SWIZZLE_MODE_KV; + [[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_ROW = KTraits::KV_THR_LAYOUT_ROW; + [[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_COL = KTraits::KV_THR_LAYOUT_COL; + [[maybe_unused]] constexpr MaskMode MASK_MODE = KTraits::MASK_MODE; + [[maybe_unused]] constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; + [[maybe_unused]] constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD = + KTraits::NUM_ACCUM_ROWS_PER_THREAD; + [[maybe_unused]] constexpr uint32_t LOGITS_INDEX_STRIDE = KTraits::LOGITS_INDEX_STRIDE; + [[maybe_unused]] constexpr uint32_t THREADS_PER_BMATRIX_ROW_SET = + KTraits::THREADS_PER_BMATRIX_ROW_SET; + [[maybe_unused]] constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH; + + DTypeQ* q = params.q; + DTypeKV* k = params.k; + DTypeKV* v = params.v; + DTypeO* o = params.o; + float* lse = params.lse; + const uint32_t qo_len = params.qo_len; + const uint32_t kv_len = params.kv_len; + const bool partition_kv = params.partition_kv; + const uint32_t q_stride_n = params.q_stride_n; + const uint32_t q_stride_h = params.q_stride_h; + const uint32_t k_stride_n = params.k_stride_n; + const uint32_t k_stride_h = params.k_stride_h; + const uint32_t v_stride_n = params.v_stride_n; + const uint32_t v_stride_h = params.v_stride_h; + const uint_fastdiv& group_size = params.group_size; + + static_assert(sizeof(DTypeQ) == 2); + const uint32_t lane_idx = tid.x, warp_idx = get_warp_idx(tid.y, tid.z); + const uint32_t num_qo_heads = num_kv_heads * group_size; + + const uint32_t max_chunk_size = partition_kv ? ceil_div(kv_len, num_chunks) : kv_len; + const uint32_t chunk_start = partition_kv ? chunk_idx * max_chunk_size : 0; + const uint32_t chunk_end = + partition_kv ? 
min((chunk_idx + 1) * max_chunk_size, kv_len) : kv_len; + const uint32_t chunk_size = chunk_end - chunk_start; + + auto block = cg::this_thread_block(); + auto smem = reinterpret_cast(&smem_storage); + AttentionVariant variant(params, /*batch_idx=*/0, smem); + const uint32_t window_left = variant.window_left; + + DTypeQKAccum s_frag[NUM_MMA_Q][NUM_MMA_KV][HALF_ELEMS_PER_THREAD]; + alignas(16) float o_frag[NUM_MMA_Q][NUM_MMA_D_VO][HALF_ELEMS_PER_THREAD]; + DTypeQKAccum m[NUM_MMA_Q][NUM_ACCUM_ROWS_PER_THREAD]; + float d[NUM_MMA_Q][NUM_ACCUM_ROWS_PER_THREAD]; + float rope_freq[NUM_MMA_D_QK / 2][4]; + if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { + const float rope_rcp_scale = params.rope_rcp_scale; + const float rope_rcp_theta = params.rope_rcp_theta; + init_rope_freq(rope_freq, rope_rcp_scale, rope_rcp_theta, tid.x); + } + init_states(variant, o_frag, m, d); + + // cooperative fetch q fragment from gmem to reg + const uint32_t qo_packed_idx_base = + (bx * NUM_WARPS_Q + get_warp_idx_q(tid.y)) * NUM_MMA_Q * 16; + smem_t qo_smem(smem_storage.q_smem); + const uint32_t o_stride_n = num_qo_heads * HEAD_DIM_VO, o_stride_h = HEAD_DIM_VO; + DTypeQ* q_ptr_base = q + (kv_head_idx * group_size) * q_stride_h; + DTypeO* o_ptr_base = partition_kv + ? 
o + chunk_idx * o_stride_n + (kv_head_idx * group_size) * o_stride_h + : o + (kv_head_idx * group_size) * o_stride_h; + + load_q_global_smem(qo_packed_idx_base, qo_len, q_ptr_base, q_stride_n, q_stride_h, + group_size, &qo_smem, tid); + + uint32_t q_smem_offset_r = qo_smem.template get_permuted_offset( + get_warp_idx_q(tid.y) * NUM_MMA_Q * 16 + lane_idx % 16, lane_idx / 16); + + memory::commit_group(); + if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { + memory::wait_group<0>(); + block.sync(); + q_smem_inplace_apply_rotary(qo_packed_idx_base, qo_len, kv_len, group_size, &qo_smem, + &q_smem_offset_r, rope_freq, tid); + block.sync(); + } - smem_t k_smem( - smem_storage.k_smem); - smem_t v_smem( - smem_storage.v_smem); - - const uint32_t num_iterations = ceil_div( - MASK_MODE == MaskMode::kCausal - ? min(chunk_size, sub_if_greater_or_zero( - kv_len - qo_len + - ((bx + 1) * CTA_TILE_Q) / group_size, - chunk_start)) - : chunk_size, - CTA_TILE_KV); - - const uint32_t window_iteration = ceil_div( - sub_if_greater_or_zero(kv_len + (bx + 1) * CTA_TILE_Q / group_size, - qo_len + window_left + chunk_start), - CTA_TILE_KV); - - const uint32_t mask_iteration = - (MASK_MODE == MaskMode::kCausal - ? min(chunk_size, - sub_if_greater_or_zero( - kv_len + (bx * CTA_TILE_Q) / group_size - qo_len, - chunk_start)) - : chunk_size) / - CTA_TILE_KV; - - DTypeKV *k_ptr = k + - (chunk_start + warp_idx * KV_THR_LAYOUT_ROW + - lane_idx / KV_THR_LAYOUT_COL) * - k_stride_n + - kv_head_idx * k_stride_h + - (lane_idx % KV_THR_LAYOUT_COL) * - upcast_size(); - DTypeKV *v_ptr = v + - (chunk_start + warp_idx * KV_THR_LAYOUT_ROW + - lane_idx / KV_THR_LAYOUT_COL) * - v_stride_n + - kv_head_idx * v_stride_h + - (lane_idx % KV_THR_LAYOUT_COL) * - upcast_size(); + smem_t k_smem(smem_storage.k_smem); + smem_t v_smem(smem_storage.v_smem); + + const uint32_t num_iterations = + ceil_div(MASK_MODE == MaskMode::kCausal + ? 
min(chunk_size, + sub_if_greater_or_zero( + kv_len - qo_len + ((bx + 1) * CTA_TILE_Q) / group_size, chunk_start)) + : chunk_size, + CTA_TILE_KV); + + const uint32_t window_iteration = + ceil_div(sub_if_greater_or_zero(kv_len + (bx + 1) * CTA_TILE_Q / group_size, + qo_len + window_left + chunk_start), + CTA_TILE_KV); + + const uint32_t mask_iteration = + (MASK_MODE == MaskMode::kCausal + ? min(chunk_size, sub_if_greater_or_zero( + kv_len + (bx * CTA_TILE_Q) / group_size - qo_len, chunk_start)) + : chunk_size) / + CTA_TILE_KV; + + DTypeKV* k_ptr = + k + + (chunk_start + warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL) * k_stride_n + + kv_head_idx * k_stride_h + + (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(); + DTypeKV* v_ptr = + v + + (chunk_start + warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL) * v_stride_n + + kv_head_idx * v_stride_h + + (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(); #if defined(PLATFORM_HIP_DEVICE) - uint32_t k_smem_offset_r = - k_smem.template get_permuted_offset( - get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + - lane_idx % 16, - (lane_idx / 16)); + uint32_t k_smem_offset_r = k_smem.template get_permuted_offset( + get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + lane_idx % 16, (lane_idx / 16)); #elif defined(PLATFORM_CUDA_DEVICE) - uint32_t k_smem_offset_r = - k_smem.template get_permuted_offset( - get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + - 8 * (lane_idx / 16) + lane_idx % 8, - (lane_idx % 16) / 8); + uint32_t k_smem_offset_r = k_smem.template get_permuted_offset( + get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + 8 * (lane_idx / 16) + lane_idx % 8, + (lane_idx % 16) / 8); #endif - uint32_t v_smem_offset_r = - v_smem.template get_permuted_offset( - get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + - lane_idx % 16, - lane_idx / 16), - k_smem_offset_w = - k_smem.template get_permuted_offset( - warp_idx * KV_THR_LAYOUT_ROW + - lane_idx / KV_THR_LAYOUT_COL, - lane_idx % KV_THR_LAYOUT_COL), - v_smem_offset_w = - v_smem.template 
get_permuted_offset( - warp_idx * KV_THR_LAYOUT_ROW + - lane_idx / KV_THR_LAYOUT_COL, - lane_idx % KV_THR_LAYOUT_COL); - produce_kv( - k_smem, &k_smem_offset_w, &k_ptr, k_stride_n, 0, chunk_size, tid); - memory::commit_group(); - produce_kv( - v_smem, &v_smem_offset_w, &v_ptr, v_stride_n, 0, chunk_size, tid); - memory::commit_group(); + uint32_t v_smem_offset_r = v_smem.template get_permuted_offset( + get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + lane_idx % 16, lane_idx / 16), + k_smem_offset_w = k_smem.template get_permuted_offset( + warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL, + lane_idx % KV_THR_LAYOUT_COL), + v_smem_offset_w = v_smem.template get_permuted_offset( + warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL, + lane_idx % KV_THR_LAYOUT_COL); + produce_kv(k_smem, &k_smem_offset_w, &k_ptr, + k_stride_n, 0, chunk_size, tid); + memory::commit_group(); + produce_kv(v_smem, &v_smem_offset_w, &v_ptr, + v_stride_n, 0, chunk_size, tid); + memory::commit_group(); #if Debug - smem_t scratch( - smem_storage.qk_scratch); + smem_t scratch(smem_storage.qk_scratch); - // if (warp_idx == 0 && lane_idx == 0) { - // printf("partition_kv : %d\n", partition_kv); - // printf("kv_len : %d\n", kv_len); - // printf("max_chunk_size : %d\n", max_chunk_size); - // printf("chunk_end : %d\n", chunk_end); - // printf("chunk_start : %d\n", chunk_start); - // } + // if (warp_idx == 0 && lane_idx == 0) { + // printf("partition_kv : %d\n", partition_kv); + // printf("kv_len : %d\n", kv_len); + // printf("max_chunk_size : %d\n", max_chunk_size); + // printf("chunk_end : %d\n", chunk_end); + // printf("chunk_start : %d\n", chunk_start); + // } #if 0 // Test Q if (warp_idx == 0 && lane_idx == 0) { @@ -2326,20 +1836,20 @@ SinglePrefillWithKVCacheDevice(const Params params, } } #endif - // Test K Global values: - // Prints the (NUM_MMA_KV*16) x (NUM_MMA_D*16) matrix from global mem. 
- - if (warp_idx == 0 && lane_idx == 0) { - // printf("\n DEBUG K Global (HIP):\n"); - // printf("k_stride_n : %d\n", k_stride_n); - // printf("k_stride_h : %d\n", k_stride_h); - // printf("kv_head_idx : %d\n", kv_head_idx); - // printf("num_qo_heads : %d\n", num_qo_heads); - // printf("num_kv_heads : %d\n", num_kv_heads); - // printf("k_stride_n : %d\n", k_stride_n); - // printf("KTraits::NUM_MMA_D_QK : %d\n", KTraits::NUM_MMA_D_QK); - // printf("NUM_MMA_KV : %d\n", NUM_MMA_KV); - // printf("NUM_MMA_Q : %d\n", NUM_MMA_Q); + // Test K Global values: + // Prints the (NUM_MMA_KV*16) x (NUM_MMA_D*16) matrix from global mem. + + if (warp_idx == 0 && lane_idx == 0) { + // printf("\n DEBUG K Global (HIP):\n"); + // printf("k_stride_n : %d\n", k_stride_n); + // printf("k_stride_h : %d\n", k_stride_h); + // printf("kv_head_idx : %d\n", kv_head_idx); + // printf("num_qo_heads : %d\n", num_qo_heads); + // printf("num_kv_heads : %d\n", num_kv_heads); + // printf("k_stride_n : %d\n", k_stride_n); + // printf("KTraits::NUM_MMA_D_QK : %d\n", KTraits::NUM_MMA_D_QK); + // printf("NUM_MMA_KV : %d\n", NUM_MMA_KV); + // printf("NUM_MMA_Q : %d\n", NUM_MMA_Q); #if 0 DTypeKV *k_ptr_tmp = k + (chunk_start + warp_idx * KV_THR_LAYOUT_ROW + @@ -2357,29 +1867,26 @@ SinglePrefillWithKVCacheDevice(const Params params, printf("\n"); } #endif - } + } - // Test K LDS values: - // Prints the (NUM_MMA_KV*16) x (NUM_MMA_D*16) matrix from shared mem. - // Note that LDS is loaded collaboratively by all warps and not each - // warp accesses the whole K matrix loaded into LDS. Each warp will - // only access 1/4 of the K values loaded into LDS. + // Test K LDS values: + // Prints the (NUM_MMA_KV*16) x (NUM_MMA_D*16) matrix from shared mem. + // Note that LDS is loaded collaboratively by all warps and not each + // warp accesses the whole K matrix loaded into LDS. Each warp will + // only access 1/4 of the K values loaded into LDS. 
#endif #pragma unroll 1 - for (uint32_t iter = 0; iter < num_iterations; ++iter) { - // for (uint32_t iter = 0; iter < 1; ++iter) { - memory::wait_group<1>(); - block.sync(); - - if constexpr (KTraits::POS_ENCODING_MODE == - PosEncodingMode::kRoPELlama) - { - k_smem_inplace_apply_rotary( - chunk_start + iter * CTA_TILE_KV, &k_smem, &k_smem_offset_r, - rope_freq, tid); - block.sync(); - } + for (uint32_t iter = 0; iter < num_iterations; ++iter) { + // for (uint32_t iter = 0; iter < 1; ++iter) { + memory::wait_group<1>(); + block.sync(); + + if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { + k_smem_inplace_apply_rotary(chunk_start + iter * CTA_TILE_KV, &k_smem, + &k_smem_offset_r, rope_freq, tid); + block.sync(); + } #if Debug1 #if 0 @@ -2406,1372 +1913,1077 @@ SinglePrefillWithKVCacheDevice(const Params params, #endif #if 1 - if (warp_idx == 0 && lane_idx == 0) { - uint32_t b_frag[KTraits::INT32_ELEMS_PER_THREAD]; - k_smem.load_fragment(k_smem_offset_r, b_frag); - auto frag_T = reinterpret_cast<__half *>(b_frag); - for (auto reg_id = 0ul; reg_id < 4; ++reg_id) { - for (auto i = 0ul; i < 4; ++i) { - printf("%f ", (float)(*(frag_T + i))); - } - } - printf("\n------------\n"); - k_smem.load_fragment(k_smem_offset_r, b_frag); - frag_T = reinterpret_cast<__half *>(b_frag); - for (auto reg_id = 0ul; reg_id < 4; ++reg_id) { - for (auto i = 0ul; i < 4; ++i) { - printf("%f ", (float)(*(frag_T + i))); - } - } - printf("\n-----===============-------\n"); - } + if (warp_idx == 0 && lane_idx == 0) { + uint32_t b_frag[KTraits::INT32_ELEMS_PER_THREAD]; + k_smem.load_fragment(k_smem_offset_r, b_frag); + auto frag_T = reinterpret_cast<__half*>(b_frag); + for (auto reg_id = 0ul; reg_id < 4; ++reg_id) { + for (auto i = 0ul; i < 4; ++i) { + printf("%f ", (float)(*(frag_T + i))); + } + } + printf("\n------------\n"); + k_smem.load_fragment(k_smem_offset_r, b_frag); + frag_T = reinterpret_cast<__half*>(b_frag); + for (auto reg_id = 0ul; reg_id < 4; ++reg_id) { + 
for (auto i = 0ul; i < 4; ++i) { + printf("%f ", (float)(*(frag_T + i))); + } + } + printf("\n-----===============-------\n"); + } #endif #endif - // compute attention score - compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, - &k_smem_offset_r, s_frag); + // compute attention score + compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); #if Debug1 - debug_write_sfrag_to_scratch( - s_frag, tid, params.debug_thread_id, params.debug_warp_id); + debug_write_sfrag_to_scratch(s_frag, tid, params.debug_thread_id, + params.debug_warp_id); #endif - logits_transform( - params, variant, /*batch_idx=*/0, qo_packed_idx_base, - chunk_start + - (iter * NUM_WARPS_KV + get_warp_idx_kv(tid.z)) * - NUM_MMA_KV * 16, - qo_len, kv_len, group_size, s_frag, tid, kv_head_idx); + logits_transform( + params, variant, /*batch_idx=*/0, qo_packed_idx_base, + chunk_start + (iter * NUM_WARPS_KV + get_warp_idx_kv(tid.z)) * NUM_MMA_KV * 16, + qo_len, kv_len, group_size, s_frag, tid, kv_head_idx); #if Debug1 - debug_write_sfrag_to_scratch( - s_frag, tid, params.debug_thread_id, params.debug_warp_id); + debug_write_sfrag_to_scratch(s_frag, tid, params.debug_thread_id, + params.debug_warp_id); #endif #if Debug1 - debug_write_sfrag_to_scratch(s_frag, &scratch, tid); + debug_write_sfrag_to_scratch(s_frag, &scratch, tid); #endif - // apply mask - if (MASK_MODE == MaskMode::kCustom || - (iter >= mask_iteration || iter < window_iteration)) - { - logits_mask( - params, variant, /*batch_idx=*/0, qo_packed_idx_base, - chunk_start + (iter * NUM_WARPS_KV + - get_warp_idx_kv(tid.z)) * - NUM_MMA_KV * 16, - qo_len, kv_len, chunk_end, group_size, s_frag, tid, - kv_head_idx); - } + // apply mask + if (MASK_MODE == MaskMode::kCustom || (iter >= mask_iteration || iter < window_iteration)) { + logits_mask( + params, variant, /*batch_idx=*/0, qo_packed_idx_base, + chunk_start + (iter * NUM_WARPS_KV + get_warp_idx_kv(tid.z)) * NUM_MMA_KV * 16, + qo_len, kv_len, chunk_end, group_size, s_frag, tid, 
kv_head_idx); + } #if Debug1 - debug_write_sfrag_to_scratch( - s_frag, tid, params.debug_thread_id, params.debug_warp_id); + debug_write_sfrag_to_scratch(s_frag, tid, params.debug_thread_id, + params.debug_warp_id); #endif - // compute m,d states in online softmax - update_mdo_states(variant, s_frag, o_frag, m, d, warp_idx, - lane_idx); + // compute m,d states in online softmax + update_mdo_states(variant, s_frag, o_frag, m, d, warp_idx, lane_idx); #if Debug1 - debug_write_sfrag_to_scratch( - s_frag, tid, params.debug_thread_id, params.debug_warp_id); + debug_write_sfrag_to_scratch(s_frag, tid, params.debug_thread_id, + params.debug_warp_id); #endif - block.sync(); - produce_kv( - k_smem, &k_smem_offset_w, &k_ptr, k_stride_n, - (iter + 1) * CTA_TILE_KV, chunk_size, tid); - memory::commit_group(); - memory::wait_group<1>(); - block.sync(); - - // compute sfm*v - compute_sfm_v(&v_smem, &v_smem_offset_r, s_frag, o_frag, - d); + block.sync(); + produce_kv( + k_smem, &k_smem_offset_w, &k_ptr, k_stride_n, (iter + 1) * CTA_TILE_KV, chunk_size, tid); + memory::commit_group(); + memory::wait_group<1>(); + block.sync(); + + // compute sfm*v + compute_sfm_v(&v_smem, &v_smem_offset_r, s_frag, o_frag, d); #if Debug - if (lane_idx == params.debug_thread_id && - warp_idx == params.debug_warp_id) - { - for (auto mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { - printf("%f\n", d[mma_q][0]); - printf("%f\n", d[mma_q][1]); - printf("%f\n", d[mma_q][2]); - printf("%f\n", d[mma_q][3]); - } - } + if (lane_idx == params.debug_thread_id && warp_idx == params.debug_warp_id) { + for (auto mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { + printf("%f\n", d[mma_q][0]); + printf("%f\n", d[mma_q][1]); + printf("%f\n", d[mma_q][2]); + printf("%f\n", d[mma_q][3]); + } + } #endif - block.sync(); - produce_kv( - v_smem, &v_smem_offset_w, &v_ptr, v_stride_n, - (iter + 1) * CTA_TILE_KV, chunk_size, tid); - memory::commit_group(); - } - memory::wait_group<0>(); - block.sync(); + block.sync(); + produce_kv( + v_smem, 
&v_smem_offset_w, &v_ptr, v_stride_n, (iter + 1) * CTA_TILE_KV, chunk_size, tid); + memory::commit_group(); + } + memory::wait_group<0>(); + block.sync(); - finalize_m(variant, m); + finalize_m(variant, m); - // threadblock synchronization - threadblock_sync_mdo_states(o_frag, &smem_storage, m, d, - warp_idx, lane_idx, tid); + // threadblock synchronization + threadblock_sync_mdo_states(o_frag, &smem_storage, m, d, warp_idx, lane_idx, tid); - // normalize d - normalize_d(o_frag, m, d); + // normalize d + normalize_d(o_frag, m, d); - // write back - write_o_reg_gmem( - o_frag, &qo_smem, o_ptr_base, qo_packed_idx_base, qo_len, - /*o_stride_n=*/ - partition_kv ? num_chunks * o_stride_n : o_stride_n, - /*o_stride_h=*/o_stride_h, group_size, tid); + // write back + write_o_reg_gmem(o_frag, &qo_smem, o_ptr_base, qo_packed_idx_base, qo_len, + /*o_stride_n=*/ + partition_kv ? num_chunks * o_stride_n : o_stride_n, + /*o_stride_h=*/o_stride_h, group_size, tid); - // write lse - if constexpr (variant.use_softmax) { - if (lse != nullptr || partition_kv) { - if (get_warp_idx_kv(tid.z) == 0) { + // write lse + if constexpr (variant.use_softmax) { + if (lse != nullptr || partition_kv) { + if (get_warp_idx_kv(tid.z) == 0) { #pragma unroll - for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { #pragma unroll - for (uint32_t j = 0; j < NUM_ACCUM_ROWS_PER_THREAD; ++j) - { - uint32_t q, r; - group_size.divmod( - qo_packed_idx_base + - lane_idx / THREADS_PER_BMATRIX_ROW_SET + + for (uint32_t j = 0; j < NUM_ACCUM_ROWS_PER_THREAD; ++j) { + uint32_t q, r; + group_size.divmod(qo_packed_idx_base + lane_idx / THREADS_PER_BMATRIX_ROW_SET + j * LOGITS_INDEX_STRIDE + mma_q * 16, q, r); - const uint32_t qo_head_idx = - kv_head_idx * group_size + r; - const uint32_t qo_idx = q; - if (qo_idx < qo_len) { - if (partition_kv) { - lse[(qo_idx * num_chunks + chunk_idx) * - num_qo_heads + - qo_head_idx] = - gpu_iface::math::ptx_log2(d[mma_q][j]) 
+ - float(m[mma_q][j]); - } - else { - lse[qo_idx * num_qo_heads + qo_head_idx] = - gpu_iface::math::ptx_log2(d[mma_q][j]) + - float(m[mma_q][j]); - } - } - } - } + const uint32_t qo_head_idx = kv_head_idx * group_size + r; + const uint32_t qo_idx = q; + if (qo_idx < qo_len) { + if (partition_kv) { + lse[(qo_idx * num_chunks + chunk_idx) * num_qo_heads + qo_head_idx] = + gpu_iface::math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]); + } else { + lse[qo_idx * num_qo_heads + qo_head_idx] = + gpu_iface::math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]); } + } } + } } -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) + } } +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) + } #endif } template -__global__ -__launch_bounds__(KTraits::NUM_THREADS) void SinglePrefillWithKVCacheKernel( - const __grid_constant__ Params params) -{ - extern __shared__ uint8_t smem[]; - auto &smem_storage = - reinterpret_cast(smem); - SinglePrefillWithKVCacheDevice(params, smem_storage); +__global__ __launch_bounds__(KTraits::NUM_THREADS) void SinglePrefillWithKVCacheKernel( + const __grid_constant__ Params params) { + extern __shared__ uint8_t smem[]; + auto& smem_storage = reinterpret_cast(smem); + SinglePrefillWithKVCacheDevice(params, smem_storage); } -template -gpuError_t SinglePrefillWithKVCacheDispatched(Params params, - typename Params::DTypeO *tmp, - gpuStream_t stream) -{ - using DTypeQ = typename Params::DTypeQ; - using DTypeKV = typename Params::DTypeKV; - using DTypeO = typename Params::DTypeO; - const uint32_t num_qo_heads = params.num_qo_heads; - const uint32_t num_kv_heads = params.num_kv_heads; - const uint32_t qo_len = params.qo_len; - const uint32_t kv_len = params.kv_len; - if (kv_len < qo_len && MASK_MODE == MaskMode::kCausal) { +gpuError_t SinglePrefillWithKVCacheDispatched(Params params, typename Params::DTypeO* tmp, + gpuStream_t stream) { + using DTypeQ = typename Params::DTypeQ; + using DTypeKV = typename Params::DTypeKV; + using DTypeO = typename Params::DTypeO; + 
const uint32_t num_qo_heads = params.num_qo_heads; + const uint32_t num_kv_heads = params.num_kv_heads; + const uint32_t qo_len = params.qo_len; + const uint32_t kv_len = params.kv_len; + if (kv_len < qo_len && MASK_MODE == MaskMode::kCausal) { + std::ostringstream err_msg; + err_msg << "When mask_mode is set to MaskMode::kCausal, kv_len must be " + "greater than or equal to qo_len, got kv_len" + << kv_len << " and qo_len " << qo_len; + FLASHINFER_ERROR(err_msg.str()); + } + + const uint32_t group_size = num_qo_heads / num_kv_heads; + constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM_QK / 16; + constexpr uint32_t NUM_MMA_D_VO = HEAD_DIM_VO / 16; + int64_t packed_qo_len = qo_len * group_size; + uint32_t cta_tile_q = FA2DetermineCtaTileQ(packed_qo_len, HEAD_DIM_VO); + + DISPATCH_CTA_TILE_Q(cta_tile_q, CTA_TILE_Q, { + constexpr uint32_t NUM_WARPS_Q = get_num_warps_q(CTA_TILE_Q); + constexpr uint32_t NUM_WARPS_KV = get_num_warps_kv(CTA_TILE_Q); + constexpr uint32_t NUM_MMA_Q = get_num_mma_q(CTA_TILE_Q); + + using DTypeQKAccum = + typename std::conditional, half, + float>::type; + + int dev_id = 0; + FI_GPU_CALL(gpuGetDevice(&dev_id)); + int max_smem_per_sm = getMaxSharedMemPerMultiprocessor(dev_id); + // we expect each sm execute two threadblocks + const int num_ctas_per_sm = + max_smem_per_sm >= 2 * (CTA_TILE_Q * HEAD_DIM_QK * sizeof(DTypeQ) + + (HEAD_DIM_QK + HEAD_DIM_VO) * 16 * NUM_WARPS_KV * sizeof(DTypeKV)) + ? 2 + : 1; + const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm; + + const uint32_t max_num_mma_kv_reg = + (HEAD_DIM_VO >= 128 && NUM_MMA_Q == 2 && POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && + !USE_FP16_QK_REDUCTION) + ? 
2 + : (8 / NUM_MMA_Q); + const uint32_t max_num_mma_kv_smem = + (max_smem_per_threadblock - CTA_TILE_Q * HEAD_DIM_QK * sizeof(DTypeQ)) / + ((HEAD_DIM_QK + HEAD_DIM_VO) * 16 * NUM_WARPS_KV * sizeof(DTypeKV)); + + // control NUM_MMA_KV for maximum warp occupancy + DISPATCH_NUM_MMA_KV(min(max_num_mma_kv_smem, max_num_mma_kv_reg), NUM_MMA_KV, { + using KTraits = + KernelTraits; + if constexpr (KTraits::IsInvalid()) { + // Invalid configuration, skip std::ostringstream err_msg; - err_msg << "When mask_mode is set to MaskMode::kCausal, kv_len must be " - "greater than or equal to qo_len, got kv_len" - << kv_len << " and qo_len " << qo_len; + err_msg << "FlashInfer Internal Error: Invalid " + "configuration : NUM_MMA_Q=" + << NUM_MMA_Q << " NUM_MMA_D_QK=" << NUM_MMA_D_QK << " NUM_MMA_D_VO=" << NUM_MMA_D_VO + << " NUM_MMA_KV=" << NUM_MMA_KV << " NUM_WARPS_Q=" << NUM_WARPS_Q + << " NUM_WARPS_KV=" << NUM_WARPS_KV + << " please create an issue " + "(https://github.com/flashinfer-ai/flashinfer/" + "issues)" + " and report the issue to the developers."; FLASHINFER_ERROR(err_msg.str()); - } - - const uint32_t group_size = num_qo_heads / num_kv_heads; - constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM_QK / 16; - constexpr uint32_t NUM_MMA_D_VO = HEAD_DIM_VO / 16; - int64_t packed_qo_len = qo_len * group_size; - uint32_t cta_tile_q = FA2DetermineCtaTileQ(packed_qo_len, HEAD_DIM_VO); - - DISPATCH_CTA_TILE_Q(cta_tile_q, CTA_TILE_Q, { - constexpr uint32_t NUM_WARPS_Q = get_num_warps_q(CTA_TILE_Q); - constexpr uint32_t NUM_WARPS_KV = get_num_warps_kv(CTA_TILE_Q); - constexpr uint32_t NUM_MMA_Q = get_num_mma_q(CTA_TILE_Q); - - using DTypeQKAccum = - typename std::conditional, - half, float>::type; - - int dev_id = 0; - FI_GPU_CALL(gpuGetDevice(&dev_id)); - int max_smem_per_sm = getMaxSharedMemPerMultiprocessor(dev_id); - // we expect each sm execute two threadblocks - const int num_ctas_per_sm = - max_smem_per_sm >= 2 * (CTA_TILE_Q * HEAD_DIM_QK * sizeof(DTypeQ) + - (HEAD_DIM_QK + 
HEAD_DIM_VO) * 16 * - NUM_WARPS_KV * sizeof(DTypeKV)) - ? 2 - : 1; - const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm; - - const uint32_t max_num_mma_kv_reg = - (HEAD_DIM_VO >= 128 && NUM_MMA_Q == 2 && - POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && - !USE_FP16_QK_REDUCTION) - ? 2 - : (8 / NUM_MMA_Q); - const uint32_t max_num_mma_kv_smem = - (max_smem_per_threadblock - - CTA_TILE_Q * HEAD_DIM_QK * sizeof(DTypeQ)) / - ((HEAD_DIM_QK + HEAD_DIM_VO) * 16 * NUM_WARPS_KV * sizeof(DTypeKV)); - - // control NUM_MMA_KV for maximum warp occupancy - DISPATCH_NUM_MMA_KV( - min(max_num_mma_kv_smem, max_num_mma_kv_reg), NUM_MMA_KV, { - using KTraits = - KernelTraits; - if constexpr (KTraits::IsInvalid()) { - // Invalid configuration, skip - std::ostringstream err_msg; - err_msg << "FlashInfer Internal Error: Invalid " - "configuration : NUM_MMA_Q=" - << NUM_MMA_Q << " NUM_MMA_D_QK=" << NUM_MMA_D_QK - << " NUM_MMA_D_VO=" << NUM_MMA_D_VO - << " NUM_MMA_KV=" << NUM_MMA_KV - << " NUM_WARPS_Q=" << NUM_WARPS_Q - << " NUM_WARPS_KV=" << NUM_WARPS_KV - << " please create an issue " - "(https://github.com/flashinfer-ai/flashinfer/" - "issues)" - " and report the issue to the developers."; - FLASHINFER_ERROR(err_msg.str()); - } - else { - constexpr uint32_t num_threads = - (NUM_WARPS_Q * NUM_WARPS_KV) * WARP_SIZE; - auto kernel = - SinglePrefillWithKVCacheKernel; - size_t smem_size = sizeof(typename KTraits::SharedStorage); - FI_GPU_CALL(gpuFuncSetAttribute( - kernel, gpuFuncAttributeMaxDynamicSharedMemorySize, - smem_size)); - int num_blocks_per_sm = 0; - int num_sm = 0; - FI_GPU_CALL(gpuDeviceGetAttribute( - &num_sm, gpuDevAttrMultiProcessorCount, dev_id)); - FI_GPU_CALL(gpuOccupancyMaxActiveBlocksPerMultiprocessor( - &num_blocks_per_sm, kernel, num_threads, smem_size)); - uint32_t max_num_kv_chunks = - (num_blocks_per_sm * num_sm) / - (num_kv_heads * - ceil_div(qo_len * group_size, CTA_TILE_Q)); - uint32_t num_chunks; - if (max_num_kv_chunks > 0) { - uint32_t 
chunk_size = - max(ceil_div(kv_len, max_num_kv_chunks), 256); - num_chunks = ceil_div(kv_len, chunk_size); - } - else { - num_chunks = 0; - } + } else { + constexpr uint32_t num_threads = (NUM_WARPS_Q * NUM_WARPS_KV) * WARP_SIZE; + auto kernel = SinglePrefillWithKVCacheKernel; + size_t smem_size = sizeof(typename KTraits::SharedStorage); + FI_GPU_CALL( + gpuFuncSetAttribute(kernel, gpuFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + int num_blocks_per_sm = 0; + int num_sm = 0; + FI_GPU_CALL(gpuDeviceGetAttribute(&num_sm, gpuDevAttrMultiProcessorCount, dev_id)); + FI_GPU_CALL(gpuOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel, + num_threads, smem_size)); + uint32_t max_num_kv_chunks = (num_blocks_per_sm * num_sm) / + (num_kv_heads * ceil_div(qo_len * group_size, CTA_TILE_Q)); + uint32_t num_chunks; + if (max_num_kv_chunks > 0) { + uint32_t chunk_size = max(ceil_div(kv_len, max_num_kv_chunks), 256); + num_chunks = ceil_div(kv_len, chunk_size); + } else { + num_chunks = 0; + } - if (num_chunks <= 1 || tmp == nullptr) { - // Enough parallelism, do not split-kv - params.partition_kv = false; - void *args[] = {(void *)¶ms}; - dim3 nblks(ceil_div(qo_len * group_size, CTA_TILE_Q), 1, - num_kv_heads); - dim3 nthrs(WARP_SIZE, NUM_WARPS_Q, NUM_WARPS_KV); - FI_GPU_CALL(gpuLaunchKernel((void *)kernel, nblks, - nthrs, args, smem_size, - stream)); - } - else { - // Use cooperative groups to increase occupancy - params.partition_kv = true; - float *tmp_lse = - (float *)(tmp + num_chunks * qo_len * num_qo_heads * - HEAD_DIM_VO); - auto o = params.o; - auto lse = params.lse; - params.o = tmp; - params.lse = tmp_lse; - void *args[] = {(void *)¶ms}; - dim3 nblks(ceil_div(qo_len * group_size, CTA_TILE_Q), - num_chunks, num_kv_heads); - dim3 nthrs(WARP_SIZE, NUM_WARPS_Q, NUM_WARPS_KV); - FI_GPU_CALL(gpuLaunchKernel((void *)kernel, nblks, - nthrs, args, smem_size, - stream)); - if constexpr (AttentionVariant::use_softmax) { - FI_GPU_CALL(MergeStates( - tmp, 
tmp_lse, o, lse, num_chunks, qo_len, - num_qo_heads, HEAD_DIM_VO, stream)); - } - else { - FI_GPU_CALL(AttentionSum(tmp, o, num_chunks, qo_len, - num_qo_heads, HEAD_DIM_VO, - stream)); - } - } - } - }) - }); - return gpuSuccess; + if (num_chunks <= 1 || tmp == nullptr) { + // Enough parallelism, do not split-kv + params.partition_kv = false; + void* args[] = {(void*)¶ms}; + dim3 nblks(ceil_div(qo_len * group_size, CTA_TILE_Q), 1, num_kv_heads); + dim3 nthrs(WARP_SIZE, NUM_WARPS_Q, NUM_WARPS_KV); + FI_GPU_CALL(gpuLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + } else { + // Use cooperative groups to increase occupancy + params.partition_kv = true; + float* tmp_lse = (float*)(tmp + num_chunks * qo_len * num_qo_heads * HEAD_DIM_VO); + auto o = params.o; + auto lse = params.lse; + params.o = tmp; + params.lse = tmp_lse; + void* args[] = {(void*)¶ms}; + dim3 nblks(ceil_div(qo_len * group_size, CTA_TILE_Q), num_chunks, num_kv_heads); + dim3 nthrs(WARP_SIZE, NUM_WARPS_Q, NUM_WARPS_KV); + FI_GPU_CALL(gpuLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + if constexpr (AttentionVariant::use_softmax) { + FI_GPU_CALL(MergeStates(tmp, tmp_lse, o, lse, num_chunks, qo_len, num_qo_heads, + HEAD_DIM_VO, stream)); + } else { + FI_GPU_CALL( + AttentionSum(tmp, o, num_chunks, qo_len, num_qo_heads, HEAD_DIM_VO, stream)); + } + } + } + }) + }); + return gpuSuccess; } template -__global__ -__launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKVCacheKernel( - const __grid_constant__ Params params) -{ - using DTypeQ = typename Params::DTypeQ; +__global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKVCacheKernel( + const __grid_constant__ Params params) { + using DTypeQ = typename Params::DTypeQ; #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) - if constexpr (std::is_same_v) { - FLASHINFER_RUNTIME_ASSERT( - "Prefill kernels do not support bf16 on sm75."); - } - else { + if constexpr (std::is_same_v) { + 
FLASHINFER_RUNTIME_ASSERT("Prefill kernels do not support bf16 on sm75."); + } else { #endif - using DTypeKV = typename Params::DTypeKV; - using DTypeO = typename Params::DTypeO; - using IdType = typename Params::IdType; - using DTypeQKAccum = typename KTraits::DTypeQKAccum; - using AttentionVariant = typename KTraits::AttentionVariant; - [[maybe_unused]] constexpr uint32_t NUM_MMA_Q = KTraits::NUM_MMA_Q; - [[maybe_unused]] constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; - [[maybe_unused]] constexpr uint32_t NUM_MMA_D_QK = - KTraits::NUM_MMA_D_QK; - [[maybe_unused]] constexpr uint32_t NUM_MMA_D_VO = - KTraits::NUM_MMA_D_VO; - [[maybe_unused]] constexpr uint32_t HEAD_DIM_QK = KTraits::HEAD_DIM_QK; - [[maybe_unused]] constexpr uint32_t HEAD_DIM_VO = KTraits::HEAD_DIM_VO; - [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_Q = - KTraits::UPCAST_STRIDE_Q; - [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_K = - KTraits::UPCAST_STRIDE_K; - [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_V = - KTraits::UPCAST_STRIDE_V; - [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_O = - KTraits::UPCAST_STRIDE_O; - [[maybe_unused]] constexpr uint32_t CTA_TILE_Q = KTraits::CTA_TILE_Q; - [[maybe_unused]] constexpr uint32_t CTA_TILE_KV = KTraits::CTA_TILE_KV; - [[maybe_unused]] constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q; - [[maybe_unused]] constexpr uint32_t NUM_WARPS_KV = - KTraits::NUM_WARPS_KV; - [[maybe_unused]] constexpr SwizzleMode SWIZZLE_MODE_Q = - KTraits::SWIZZLE_MODE_Q; - [[maybe_unused]] constexpr SwizzleMode SWIZZLE_MODE_KV = - KTraits::SWIZZLE_MODE_KV; - [[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_ROW = - KTraits::KV_THR_LAYOUT_ROW; - [[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_COL = - KTraits::KV_THR_LAYOUT_COL; - [[maybe_unused]] constexpr MaskMode MASK_MODE = KTraits::MASK_MODE; - [[maybe_unused]] constexpr uint32_t VECTOR_BIT_WIDTH = - KTraits::VECTOR_BIT_WIDTH; - - DTypeQ *q = params.q; - IdType *request_indices = 
params.request_indices; - IdType *qo_tile_indices = params.qo_tile_indices; - IdType *kv_tile_indices = params.kv_tile_indices; - IdType *q_indptr = params.q_indptr; - IdType *kv_indptr = params.kv_indptr; - DTypeKV *k = params.k; - DTypeKV *v = params.v; - IdType *o_indptr = params.o_indptr; - DTypeO *o = params.o; - float *lse = params.lse; - bool *block_valid_mask = params.block_valid_mask; - const bool partition_kv = params.partition_kv; - const uint32_t q_stride_n = params.q_stride_n; - const uint32_t q_stride_h = params.q_stride_h; - const uint32_t k_stride_n = params.k_stride_n; - const uint32_t k_stride_h = params.k_stride_h; - const uint32_t v_stride_n = params.v_stride_n; - const uint32_t v_stride_h = params.v_stride_h; - const uint_fastdiv &group_size = params.group_size; - - static_assert(sizeof(DTypeQ) == 2); - const uint32_t kv_chunk_size = *(params.kv_chunk_size_ptr); - const dim3 &tid = threadIdx; - - auto block = cg::this_thread_block(); - const uint32_t bx = blockIdx.x, lane_idx = tid.x, - warp_idx = get_warp_idx(tid.y, tid.z), - kv_head_idx = blockIdx.z; - if (block_valid_mask && !block_valid_mask[bx]) { - return; - } - const uint32_t num_kv_heads = gridDim.z, - num_qo_heads = group_size * num_kv_heads; - const uint32_t request_idx = request_indices[bx], - qo_tile_idx = qo_tile_indices[bx], - kv_tile_idx = kv_tile_indices[bx]; - extern __shared__ uint8_t smem[]; - auto &smem_storage = - reinterpret_cast(smem); - AttentionVariant variant(params, /*batch_idx=*/request_idx, smem); - const uint32_t qo_len = variant.qo_len, kv_len = variant.kv_len, - window_left = variant.window_left; - const uint32_t kv_len_safe = kv_len > 0 ? kv_len : 1; - const uint32_t max_chunk_size = partition_kv ? kv_chunk_size : kv_len; - const uint32_t chunk_start = - partition_kv ? kv_tile_idx * max_chunk_size : 0; - const uint32_t chunk_end = - partition_kv ? 
min((kv_tile_idx + 1) * max_chunk_size, kv_len) - : kv_len; - const uint32_t chunk_size = chunk_end - chunk_start; - const uint32_t qo_upper_bound = - min(qo_len, ceil_div((qo_tile_idx + 1) * CTA_TILE_Q, group_size)); - - DTypeQKAccum s_frag[NUM_MMA_Q][NUM_MMA_KV][8]; - alignas(16) float o_frag[NUM_MMA_Q][NUM_MMA_D_VO][8]; - DTypeQKAccum m[NUM_MMA_Q][2]; - float d[NUM_MMA_Q][2]; - float rope_freq[NUM_MMA_D_QK / 2][4]; - - if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) - { - const float rope_rcp_scale = params.rope_rcp_scale; - const float rope_rcp_theta = params.rope_rcp_theta; - init_rope_freq(rope_freq, rope_rcp_scale, rope_rcp_theta, - tid.x); - } - init_states(variant, o_frag, m, d); - - const uint32_t qo_packed_idx_base = - (qo_tile_idx * NUM_WARPS_Q + get_warp_idx_q(tid.y)) * - NUM_MMA_Q * 16; - smem_t qo_smem(smem_storage.q_smem); - const uint32_t o_stride_n = num_qo_heads * HEAD_DIM_VO, - o_stride_h = HEAD_DIM_VO; - - DTypeQ *q_ptr_base = q + q_indptr[request_idx] * q_stride_n + - kv_head_idx * group_size * q_stride_h; - - DTypeO *o_ptr_base = - partition_kv - ? 
o + (o_indptr[request_idx] + kv_tile_idx) * o_stride_n + - (kv_head_idx * group_size) * o_stride_h - : o + o_indptr[request_idx] * o_stride_n + - (kv_head_idx * group_size) * o_stride_h; - - uint32_t q_smem_offset_r = - qo_smem.template get_permuted_offset( - get_warp_idx_q(tid.y) * NUM_MMA_Q * 16 + lane_idx % 16, - lane_idx / 16); - - load_q_global_smem(qo_packed_idx_base, qo_upper_bound, - q_ptr_base, q_stride_n, q_stride_h, - group_size, &qo_smem, tid); - - memory::commit_group(); - - if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) - { - memory::wait_group<0>(); - block.sync(); - IdType *q_rope_offset = nullptr; - - if constexpr (has_maybe_q_rope_offset_v) { - q_rope_offset = params.maybe_q_rope_offset; - } - if (!q_rope_offset) { - q_smem_inplace_apply_rotary( - qo_packed_idx_base, qo_len, kv_len, group_size, &qo_smem, - &q_smem_offset_r, rope_freq, tid); - } - else { - q_smem_inplace_apply_rotary_with_pos( - qo_packed_idx_base, q_rope_offset + q_indptr[request_idx], - &qo_smem, group_size, &q_smem_offset_r, rope_freq, tid); - } - block.sync(); - } + using DTypeKV = typename Params::DTypeKV; + using DTypeO = typename Params::DTypeO; + using IdType = typename Params::IdType; + using DTypeQKAccum = typename KTraits::DTypeQKAccum; + using AttentionVariant = typename KTraits::AttentionVariant; + [[maybe_unused]] constexpr uint32_t NUM_MMA_Q = KTraits::NUM_MMA_Q; + [[maybe_unused]] constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; + [[maybe_unused]] constexpr uint32_t NUM_MMA_D_QK = KTraits::NUM_MMA_D_QK; + [[maybe_unused]] constexpr uint32_t NUM_MMA_D_VO = KTraits::NUM_MMA_D_VO; + [[maybe_unused]] constexpr uint32_t HEAD_DIM_QK = KTraits::HEAD_DIM_QK; + [[maybe_unused]] constexpr uint32_t HEAD_DIM_VO = KTraits::HEAD_DIM_VO; + [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q; + [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_K = KTraits::UPCAST_STRIDE_K; + [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_V 
= KTraits::UPCAST_STRIDE_V; + [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_O = KTraits::UPCAST_STRIDE_O; + [[maybe_unused]] constexpr uint32_t CTA_TILE_Q = KTraits::CTA_TILE_Q; + [[maybe_unused]] constexpr uint32_t CTA_TILE_KV = KTraits::CTA_TILE_KV; + [[maybe_unused]] constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q; + [[maybe_unused]] constexpr uint32_t NUM_WARPS_KV = KTraits::NUM_WARPS_KV; + [[maybe_unused]] constexpr SwizzleMode SWIZZLE_MODE_Q = KTraits::SWIZZLE_MODE_Q; + [[maybe_unused]] constexpr SwizzleMode SWIZZLE_MODE_KV = KTraits::SWIZZLE_MODE_KV; + [[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_ROW = KTraits::KV_THR_LAYOUT_ROW; + [[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_COL = KTraits::KV_THR_LAYOUT_COL; + [[maybe_unused]] constexpr MaskMode MASK_MODE = KTraits::MASK_MODE; + [[maybe_unused]] constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH; + + DTypeQ* q = params.q; + IdType* request_indices = params.request_indices; + IdType* qo_tile_indices = params.qo_tile_indices; + IdType* kv_tile_indices = params.kv_tile_indices; + IdType* q_indptr = params.q_indptr; + IdType* kv_indptr = params.kv_indptr; + DTypeKV* k = params.k; + DTypeKV* v = params.v; + IdType* o_indptr = params.o_indptr; + DTypeO* o = params.o; + float* lse = params.lse; + bool* block_valid_mask = params.block_valid_mask; + const bool partition_kv = params.partition_kv; + const uint32_t q_stride_n = params.q_stride_n; + const uint32_t q_stride_h = params.q_stride_h; + const uint32_t k_stride_n = params.k_stride_n; + const uint32_t k_stride_h = params.k_stride_h; + const uint32_t v_stride_n = params.v_stride_n; + const uint32_t v_stride_h = params.v_stride_h; + const uint_fastdiv& group_size = params.group_size; + + static_assert(sizeof(DTypeQ) == 2); + const uint32_t kv_chunk_size = *(params.kv_chunk_size_ptr); + const dim3& tid = threadIdx; + + auto block = cg::this_thread_block(); + const uint32_t bx = blockIdx.x, lane_idx = tid.x, + warp_idx = 
get_warp_idx(tid.y, tid.z), kv_head_idx = blockIdx.z; + if (block_valid_mask && !block_valid_mask[bx]) { + return; + } + const uint32_t num_kv_heads = gridDim.z, num_qo_heads = group_size * num_kv_heads; + const uint32_t request_idx = request_indices[bx], qo_tile_idx = qo_tile_indices[bx], + kv_tile_idx = kv_tile_indices[bx]; + extern __shared__ uint8_t smem[]; + auto& smem_storage = reinterpret_cast(smem); + AttentionVariant variant(params, /*batch_idx=*/request_idx, smem); + const uint32_t qo_len = variant.qo_len, kv_len = variant.kv_len, + window_left = variant.window_left; + const uint32_t kv_len_safe = kv_len > 0 ? kv_len : 1; + const uint32_t max_chunk_size = partition_kv ? kv_chunk_size : kv_len; + const uint32_t chunk_start = partition_kv ? kv_tile_idx * max_chunk_size : 0; + const uint32_t chunk_end = + partition_kv ? min((kv_tile_idx + 1) * max_chunk_size, kv_len) : kv_len; + const uint32_t chunk_size = chunk_end - chunk_start; + const uint32_t qo_upper_bound = + min(qo_len, ceil_div((qo_tile_idx + 1) * CTA_TILE_Q, group_size)); + + DTypeQKAccum s_frag[NUM_MMA_Q][NUM_MMA_KV][8]; + alignas(16) float o_frag[NUM_MMA_Q][NUM_MMA_D_VO][8]; + DTypeQKAccum m[NUM_MMA_Q][2]; + float d[NUM_MMA_Q][2]; + float rope_freq[NUM_MMA_D_QK / 2][4]; + + if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { + const float rope_rcp_scale = params.rope_rcp_scale; + const float rope_rcp_theta = params.rope_rcp_theta; + init_rope_freq(rope_freq, rope_rcp_scale, rope_rcp_theta, tid.x); + } + init_states(variant, o_frag, m, d); + + const uint32_t qo_packed_idx_base = + (qo_tile_idx * NUM_WARPS_Q + get_warp_idx_q(tid.y)) * NUM_MMA_Q * 16; + smem_t qo_smem(smem_storage.q_smem); + const uint32_t o_stride_n = num_qo_heads * HEAD_DIM_VO, o_stride_h = HEAD_DIM_VO; + + DTypeQ* q_ptr_base = + q + q_indptr[request_idx] * q_stride_n + kv_head_idx * group_size * q_stride_h; + + DTypeO* o_ptr_base = partition_kv ? 
o + (o_indptr[request_idx] + kv_tile_idx) * o_stride_n + + (kv_head_idx * group_size) * o_stride_h + : o + o_indptr[request_idx] * o_stride_n + + (kv_head_idx * group_size) * o_stride_h; + + uint32_t q_smem_offset_r = qo_smem.template get_permuted_offset( + get_warp_idx_q(tid.y) * NUM_MMA_Q * 16 + lane_idx % 16, lane_idx / 16); + + load_q_global_smem(qo_packed_idx_base, qo_upper_bound, q_ptr_base, q_stride_n, + q_stride_h, group_size, &qo_smem, tid); + + memory::commit_group(); + + if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { + memory::wait_group<0>(); + block.sync(); + IdType* q_rope_offset = nullptr; + + if constexpr (has_maybe_q_rope_offset_v) { + q_rope_offset = params.maybe_q_rope_offset; + } + if (!q_rope_offset) { + q_smem_inplace_apply_rotary(qo_packed_idx_base, qo_len, kv_len, group_size, + &qo_smem, &q_smem_offset_r, rope_freq, tid); + } else { + q_smem_inplace_apply_rotary_with_pos( + qo_packed_idx_base, q_rope_offset + q_indptr[request_idx], &qo_smem, group_size, + &q_smem_offset_r, rope_freq, tid); + } + block.sync(); + } - const uint32_t num_iterations = ceil_div( - (MASK_MODE == MaskMode::kCausal - ? min(chunk_size, - sub_if_greater_or_zero( - kv_len - qo_len + - ((qo_tile_idx + 1) * CTA_TILE_Q) / group_size, - chunk_start)) - : chunk_size), - CTA_TILE_KV); - - const uint32_t window_iteration = - ceil_div(sub_if_greater_or_zero( - kv_len + (qo_tile_idx + 1) * CTA_TILE_Q / group_size, - qo_len + window_left + chunk_start), - CTA_TILE_KV); - - const uint32_t mask_iteration = - (MASK_MODE == MaskMode::kCausal - ? 
min(chunk_size, - sub_if_greater_or_zero( - kv_len + (qo_tile_idx * CTA_TILE_Q) / group_size - - qo_len, - chunk_start)) - : chunk_size) / - CTA_TILE_KV; - - smem_t k_smem(smem_storage.k_smem), - v_smem(smem_storage.v_smem); - - uint32_t k_smem_offset_r = - k_smem.template get_permuted_offset( - get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + - 8 * (lane_idx / 16) + lane_idx % 8, - (lane_idx % 16) / 8), - v_smem_offset_r = - v_smem.template get_permuted_offset( - get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + - lane_idx % 16, - lane_idx / 16), - k_smem_offset_w = - k_smem.template get_permuted_offset( - warp_idx * KV_THR_LAYOUT_ROW + - lane_idx / KV_THR_LAYOUT_COL, - lane_idx % KV_THR_LAYOUT_COL), - v_smem_offset_w = - v_smem.template get_permuted_offset( - warp_idx * KV_THR_LAYOUT_ROW + - lane_idx / KV_THR_LAYOUT_COL, - lane_idx % KV_THR_LAYOUT_COL); - - DTypeKV *k_ptr = - k + - (kv_indptr[request_idx] + chunk_start + - warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL) * - k_stride_n + - kv_head_idx * k_stride_h + - (lane_idx % KV_THR_LAYOUT_COL) * - upcast_size(); - DTypeKV *v_ptr = - v + - (kv_indptr[request_idx] + chunk_start + - warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL) * - v_stride_n + - kv_head_idx * v_stride_h + - (lane_idx % KV_THR_LAYOUT_COL) * - upcast_size(); - - produce_kv( - k_smem, &k_smem_offset_w, &k_ptr, k_stride_n, 0, chunk_size, tid); - memory::commit_group(); - produce_kv( - v_smem, &v_smem_offset_w, &v_ptr, v_stride_n, 0, chunk_size, tid); - memory::commit_group(); + const uint32_t num_iterations = ceil_div( + (MASK_MODE == MaskMode::kCausal + ? 
min(chunk_size, sub_if_greater_or_zero( + kv_len - qo_len + ((qo_tile_idx + 1) * CTA_TILE_Q) / group_size, + chunk_start)) + : chunk_size), + CTA_TILE_KV); + + const uint32_t window_iteration = + ceil_div(sub_if_greater_or_zero(kv_len + (qo_tile_idx + 1) * CTA_TILE_Q / group_size, + qo_len + window_left + chunk_start), + CTA_TILE_KV); + + const uint32_t mask_iteration = + (MASK_MODE == MaskMode::kCausal + ? min(chunk_size, + sub_if_greater_or_zero(kv_len + (qo_tile_idx * CTA_TILE_Q) / group_size - qo_len, + chunk_start)) + : chunk_size) / + CTA_TILE_KV; + + smem_t k_smem(smem_storage.k_smem), v_smem(smem_storage.v_smem); + + uint32_t k_smem_offset_r = k_smem.template get_permuted_offset( + get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + 8 * (lane_idx / 16) + + lane_idx % 8, + (lane_idx % 16) / 8), + v_smem_offset_r = v_smem.template get_permuted_offset( + get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + lane_idx % 16, lane_idx / 16), + k_smem_offset_w = k_smem.template get_permuted_offset( + warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL, + lane_idx % KV_THR_LAYOUT_COL), + v_smem_offset_w = v_smem.template get_permuted_offset( + warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL, + lane_idx % KV_THR_LAYOUT_COL); + + DTypeKV* k_ptr = k + + (kv_indptr[request_idx] + chunk_start + warp_idx * KV_THR_LAYOUT_ROW + + lane_idx / KV_THR_LAYOUT_COL) * + k_stride_n + + kv_head_idx * k_stride_h + + (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(); + DTypeKV* v_ptr = v + + (kv_indptr[request_idx] + chunk_start + warp_idx * KV_THR_LAYOUT_ROW + + lane_idx / KV_THR_LAYOUT_COL) * + v_stride_n + + kv_head_idx * v_stride_h + + (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(); + + produce_kv(k_smem, &k_smem_offset_w, &k_ptr, + k_stride_n, 0, chunk_size, tid); + memory::commit_group(); + produce_kv(v_smem, &v_smem_offset_w, &v_ptr, + v_stride_n, 0, chunk_size, tid); + memory::commit_group(); #pragma unroll 1 - for (uint32_t iter = 0; iter < num_iterations; ++iter) { - 
memory::wait_group<1>(); - block.sync(); - - if constexpr (KTraits::POS_ENCODING_MODE == - PosEncodingMode::kRoPELlama) - { - IdType *k_rope_offset = nullptr; - if constexpr (has_maybe_k_rope_offset_v) { - k_rope_offset = params.maybe_k_rope_offset; - } - k_smem_inplace_apply_rotary( - (k_rope_offset == nullptr ? 0 - : k_rope_offset[request_idx]) + - chunk_start + iter * CTA_TILE_KV, - &k_smem, &k_smem_offset_r, rope_freq, tid); - block.sync(); - } - - // compute attention score - compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, - &k_smem_offset_r, s_frag); - - logits_transform( - params, variant, /*batch_idx=*/request_idx, qo_packed_idx_base, - chunk_start + - (iter * NUM_WARPS_KV + get_warp_idx_kv(tid.z)) * - NUM_MMA_KV * 16, - qo_len, kv_len, group_size, s_frag, tid, kv_head_idx); - - // apply mask - if (MASK_MODE == MaskMode::kCustom || - (iter >= mask_iteration || iter < window_iteration)) - { - logits_mask(params, variant, /*batch_idx=*/request_idx, - qo_packed_idx_base, - chunk_start + - (iter * NUM_WARPS_KV + - get_warp_idx_kv(tid.z)) * - NUM_MMA_KV * 16, - qo_len, kv_len, chunk_end, group_size, - s_frag, tid, kv_head_idx); - } - - // compute m,d states in online softmax - update_mdo_states(variant, s_frag, o_frag, m, d); - - block.sync(); - produce_kv( - k_smem, &k_smem_offset_w, &k_ptr, k_stride_n, - (iter + 1) * CTA_TILE_KV, chunk_size, tid); - memory::commit_group(); - memory::wait_group<1>(); - block.sync(); - - // compute sfm*v - compute_sfm_v(&v_smem, &v_smem_offset_r, s_frag, o_frag, - d); - - block.sync(); - produce_kv( - v_smem, &v_smem_offset_w, &v_ptr, v_stride_n, - (iter + 1) * CTA_TILE_KV, chunk_size, tid); - memory::commit_group(); + for (uint32_t iter = 0; iter < num_iterations; ++iter) { + memory::wait_group<1>(); + block.sync(); + + if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { + IdType* k_rope_offset = nullptr; + if constexpr (has_maybe_k_rope_offset_v) { + k_rope_offset = params.maybe_k_rope_offset; } - 
memory::wait_group<0>(); + k_smem_inplace_apply_rotary( + (k_rope_offset == nullptr ? 0 : k_rope_offset[request_idx]) + chunk_start + + iter * CTA_TILE_KV, + &k_smem, &k_smem_offset_r, rope_freq, tid); block.sync(); + } + + // compute attention score + compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); + + logits_transform( + params, variant, /*batch_idx=*/request_idx, qo_packed_idx_base, + chunk_start + (iter * NUM_WARPS_KV + get_warp_idx_kv(tid.z)) * NUM_MMA_KV * 16, + qo_len, kv_len, group_size, s_frag, tid, kv_head_idx); + + // apply mask + if (MASK_MODE == MaskMode::kCustom || (iter >= mask_iteration || iter < window_iteration)) { + logits_mask( + params, variant, /*batch_idx=*/request_idx, qo_packed_idx_base, + chunk_start + (iter * NUM_WARPS_KV + get_warp_idx_kv(tid.z)) * NUM_MMA_KV * 16, + qo_len, kv_len, chunk_end, group_size, s_frag, tid, kv_head_idx); + } + + // compute m,d states in online softmax + update_mdo_states(variant, s_frag, o_frag, m, d); + + block.sync(); + produce_kv( + k_smem, &k_smem_offset_w, &k_ptr, k_stride_n, (iter + 1) * CTA_TILE_KV, chunk_size, tid); + memory::commit_group(); + memory::wait_group<1>(); + block.sync(); + + // compute sfm*v + compute_sfm_v(&v_smem, &v_smem_offset_r, s_frag, o_frag, d); + + block.sync(); + produce_kv( + v_smem, &v_smem_offset_w, &v_ptr, v_stride_n, (iter + 1) * CTA_TILE_KV, chunk_size, tid); + memory::commit_group(); + } + memory::wait_group<0>(); + block.sync(); - finalize_m(variant, m); + finalize_m(variant, m); - // threadblock synchronization - threadblock_sync_mdo_states(o_frag, &smem_storage, m, d, - warp_idx, lane_idx, tid); + // threadblock synchronization + threadblock_sync_mdo_states(o_frag, &smem_storage, m, d, warp_idx, lane_idx, tid); - // normalize d - normalize_d(o_frag, m, d); + // normalize d + normalize_d(o_frag, m, d); - const uint32_t num_kv_chunks = - (kv_len_safe + kv_chunk_size - 1) / kv_chunk_size; + const uint32_t num_kv_chunks = (kv_len_safe + 
kv_chunk_size - 1) / kv_chunk_size; - // write back - write_o_reg_gmem( - o_frag, &qo_smem, o_ptr_base, qo_packed_idx_base, qo_len, - /*o_stride_n=*/ - partition_kv ? num_kv_chunks * o_stride_n : o_stride_n, - /*o_stride_h=*/o_stride_h, group_size, tid); + // write back + write_o_reg_gmem(o_frag, &qo_smem, o_ptr_base, qo_packed_idx_base, qo_len, + /*o_stride_n=*/ + partition_kv ? num_kv_chunks * o_stride_n : o_stride_n, + /*o_stride_h=*/o_stride_h, group_size, tid); - // write lse - if constexpr (AttentionVariant::use_softmax) { - if (lse != nullptr) { - if (get_warp_idx_kv(tid.z) == 0) { -#pragma unroll - for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { -#pragma unroll - for (uint32_t j = 0; j < 2; ++j) { - uint32_t q, r; - group_size.divmod(qo_packed_idx_base + - lane_idx / 4 + j * 8 + - mma_q * 16, - q, r); - const uint32_t qo_head_idx = - kv_head_idx * group_size + r; - const uint32_t qo_idx = q; - if (qo_idx < qo_len) { - if (partition_kv) { - lse[(o_indptr[request_idx] + - qo_idx * num_kv_chunks + kv_tile_idx) * - num_qo_heads + - qo_head_idx] = - gpu_iface::math::ptx_log2(d[mma_q][j]) + - float(m[mma_q][j]); - } - else { - lse[(o_indptr[request_idx] + qo_idx) * - num_qo_heads + - qo_head_idx] = - gpu_iface::math::ptx_log2(d[mma_q][j]) + - float(m[mma_q][j]); - } - } - } - } + // write lse + if constexpr (AttentionVariant::use_softmax) { + if (lse != nullptr) { + if (get_warp_idx_kv(tid.z) == 0) { +#pragma unroll + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + uint32_t q, r; + group_size.divmod(qo_packed_idx_base + lane_idx / 4 + j * 8 + mma_q * 16, q, r); + const uint32_t qo_head_idx = kv_head_idx * group_size + r; + const uint32_t qo_idx = q; + if (qo_idx < qo_len) { + if (partition_kv) { + lse[(o_indptr[request_idx] + qo_idx * num_kv_chunks + kv_tile_idx) * + num_qo_heads + + qo_head_idx] = gpu_iface::math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]); + } else { + 
lse[(o_indptr[request_idx] + qo_idx) * num_qo_heads + qo_head_idx] = + gpu_iface::math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]); } + } } + } } -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) + } } +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) + } #endif } template __device__ __forceinline__ void BatchPrefillWithPagedKVCacheDevice( - const Params params, - typename KTraits::SharedStorage &smem_storage, - const dim3 tid = threadIdx, - const uint32_t bx = blockIdx.x, - const uint32_t kv_head_idx = blockIdx.z, - const uint32_t num_kv_heads = gridDim.z) -{ - using DTypeQ = typename Params::DTypeQ; + const Params params, typename KTraits::SharedStorage& smem_storage, const dim3 tid = threadIdx, + const uint32_t bx = blockIdx.x, const uint32_t kv_head_idx = blockIdx.z, + const uint32_t num_kv_heads = gridDim.z) { + using DTypeQ = typename Params::DTypeQ; #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) - if constexpr (std::is_same_v) { - FLASHINFER_RUNTIME_ASSERT( - "Prefill kernels do not support bf16 on sm75."); - } - else { + if constexpr (std::is_same_v) { + FLASHINFER_RUNTIME_ASSERT("Prefill kernels do not support bf16 on sm75."); + } else { #endif - using DTypeKV = typename Params::DTypeKV; - using DTypeO = typename Params::DTypeO; - using IdType = typename Params::IdType; - using DTypeQKAccum = typename KTraits::DTypeQKAccum; - using AttentionVariant = typename KTraits::AttentionVariant; - [[maybe_unused]] constexpr uint32_t NUM_MMA_Q = KTraits::NUM_MMA_Q; - [[maybe_unused]] constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; - [[maybe_unused]] constexpr uint32_t NUM_MMA_D_QK = - KTraits::NUM_MMA_D_QK; - [[maybe_unused]] constexpr uint32_t NUM_MMA_D_VO = - KTraits::NUM_MMA_D_VO; - [[maybe_unused]] constexpr uint32_t HEAD_DIM_QK = KTraits::HEAD_DIM_QK; - [[maybe_unused]] constexpr uint32_t HEAD_DIM_VO = KTraits::HEAD_DIM_VO; - [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_Q = - KTraits::UPCAST_STRIDE_Q; - [[maybe_unused]] constexpr uint32_t 
UPCAST_STRIDE_K = - KTraits::UPCAST_STRIDE_K; - [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_V = - KTraits::UPCAST_STRIDE_V; - [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_O = - KTraits::UPCAST_STRIDE_O; - [[maybe_unused]] constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q; - [[maybe_unused]] constexpr uint32_t NUM_WARPS_KV = - KTraits::NUM_WARPS_KV; - [[maybe_unused]] constexpr SwizzleMode SWIZZLE_MODE_Q = - KTraits::SWIZZLE_MODE_Q; - [[maybe_unused]] constexpr SwizzleMode SWIZZLE_MODE_KV = - KTraits::SWIZZLE_MODE_KV; - [[maybe_unused]] constexpr uint32_t CTA_TILE_Q = KTraits::CTA_TILE_Q; - [[maybe_unused]] constexpr uint32_t CTA_TILE_KV = KTraits::CTA_TILE_KV; - [[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_ROW = - KTraits::KV_THR_LAYOUT_ROW; - [[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_COL = - KTraits::KV_THR_LAYOUT_COL; - [[maybe_unused]] constexpr MaskMode MASK_MODE = KTraits::MASK_MODE; - [[maybe_unused]] constexpr uint32_t VECTOR_BIT_WIDTH = - KTraits::VECTOR_BIT_WIDTH; - - IdType *request_indices = params.request_indices; - IdType *qo_tile_indices = params.qo_tile_indices; - IdType *kv_tile_indices = params.kv_tile_indices; - DTypeQ *q = params.q; - IdType *q_indptr = params.q_indptr; - IdType *o_indptr = params.o_indptr; - DTypeO *o = params.o; - float *lse = params.lse; - bool *block_valid_mask = params.block_valid_mask; - const paged_kv_t &paged_kv = params.paged_kv; - const bool partition_kv = params.partition_kv; - const uint_fastdiv &group_size = params.group_size; - - static_assert(sizeof(DTypeQ) == 2); - auto block = cg::this_thread_block(); - const uint32_t kv_chunk_size = *(params.kv_chunk_size_ptr); - - const uint32_t lane_idx = tid.x, - warp_idx = get_warp_idx(tid.y, tid.z); - if (block_valid_mask && !block_valid_mask[bx]) { - return; - } - const uint32_t num_qo_heads = num_kv_heads * group_size; - - const uint32_t request_idx = request_indices[bx], - qo_tile_idx = qo_tile_indices[bx], - kv_tile_idx = kv_tile_indices[bx]; 
- auto smem = reinterpret_cast(&smem_storage); - AttentionVariant variant(params, /*batch_idx=*/request_idx, smem); - const uint32_t qo_len = variant.qo_len, kv_len = variant.kv_len, - window_left = variant.window_left; - const uint32_t kv_len_safe = kv_len > 0 ? kv_len : 1; - const uint32_t max_chunk_size = partition_kv ? kv_chunk_size : kv_len; - const uint32_t chunk_start = - partition_kv ? kv_tile_idx * max_chunk_size : 0; - const uint32_t chunk_end = - partition_kv ? min((kv_tile_idx + 1) * max_chunk_size, kv_len) - : kv_len; - const uint32_t chunk_size = chunk_end - chunk_start; - const uint32_t qo_upper_bound = - min(qo_len, ceil_div((qo_tile_idx + 1) * CTA_TILE_Q, group_size)); - - DTypeQKAccum s_frag[NUM_MMA_Q][NUM_MMA_KV][8]; - alignas(16) float o_frag[NUM_MMA_Q][NUM_MMA_D_VO][8]; - DTypeQKAccum m[NUM_MMA_Q][2]; - float d[NUM_MMA_Q][2]; - float rope_freq[NUM_MMA_D_QK / 2][4]; - - if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) - { - const float rope_rcp_scale = params.rope_rcp_scale; - const float rope_rcp_theta = params.rope_rcp_theta; - init_rope_freq(rope_freq, rope_rcp_scale, rope_rcp_theta, - tid.x); - } - init_states(variant, o_frag, m, d); - - const uint32_t qo_packed_idx_base = - (qo_tile_idx * NUM_WARPS_Q + get_warp_idx_q(tid.y)) * - NUM_MMA_Q * 16; - const uint32_t q_stride_n = params.q_stride_n, - q_stride_h = params.q_stride_h; - smem_t qo_smem(smem_storage.q_smem); - const uint32_t o_stride_n = num_qo_heads * HEAD_DIM_VO, - o_stride_h = HEAD_DIM_VO; - DTypeQ *q_ptr_base = q + q_indptr[request_idx] * q_stride_n + - (kv_head_idx * group_size) * q_stride_h; - DTypeO *o_ptr_base = - partition_kv - ? 
o + (o_indptr[request_idx] + kv_tile_idx) * o_stride_n + - (kv_head_idx * group_size) * o_stride_h - : o + o_indptr[request_idx] * o_stride_n + - (kv_head_idx * group_size) * o_stride_h; - uint32_t q_smem_offset_r = - qo_smem.template get_permuted_offset( - get_warp_idx_q(tid.y) * NUM_MMA_Q * 16 + lane_idx % 16, - lane_idx / 16); - - load_q_global_smem(qo_packed_idx_base, qo_upper_bound, - q_ptr_base, q_stride_n, q_stride_h, - group_size, &qo_smem, tid); - - memory::commit_group(); - - if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) - { - memory::wait_group<0>(); - block.sync(); - IdType *q_rope_offset = nullptr; - if constexpr (has_maybe_q_rope_offset_v) { - q_rope_offset = params.maybe_q_rope_offset; - } - if (q_rope_offset == nullptr) { - q_smem_inplace_apply_rotary( - qo_packed_idx_base, qo_len, kv_len, group_size, &qo_smem, - &q_smem_offset_r, rope_freq, tid); - } - else { - q_smem_inplace_apply_rotary_with_pos( - qo_packed_idx_base, q_rope_offset + q_indptr[request_idx], - &qo_smem, group_size, &q_smem_offset_r, rope_freq, tid); - } - block.sync(); - } + using DTypeKV = typename Params::DTypeKV; + using DTypeO = typename Params::DTypeO; + using IdType = typename Params::IdType; + using DTypeQKAccum = typename KTraits::DTypeQKAccum; + using AttentionVariant = typename KTraits::AttentionVariant; + [[maybe_unused]] constexpr uint32_t NUM_MMA_Q = KTraits::NUM_MMA_Q; + [[maybe_unused]] constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; + [[maybe_unused]] constexpr uint32_t NUM_MMA_D_QK = KTraits::NUM_MMA_D_QK; + [[maybe_unused]] constexpr uint32_t NUM_MMA_D_VO = KTraits::NUM_MMA_D_VO; + [[maybe_unused]] constexpr uint32_t HEAD_DIM_QK = KTraits::HEAD_DIM_QK; + [[maybe_unused]] constexpr uint32_t HEAD_DIM_VO = KTraits::HEAD_DIM_VO; + [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q; + [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_K = KTraits::UPCAST_STRIDE_K; + [[maybe_unused]] constexpr uint32_t 
UPCAST_STRIDE_V = KTraits::UPCAST_STRIDE_V; + [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_O = KTraits::UPCAST_STRIDE_O; + [[maybe_unused]] constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q; + [[maybe_unused]] constexpr uint32_t NUM_WARPS_KV = KTraits::NUM_WARPS_KV; + [[maybe_unused]] constexpr SwizzleMode SWIZZLE_MODE_Q = KTraits::SWIZZLE_MODE_Q; + [[maybe_unused]] constexpr SwizzleMode SWIZZLE_MODE_KV = KTraits::SWIZZLE_MODE_KV; + [[maybe_unused]] constexpr uint32_t CTA_TILE_Q = KTraits::CTA_TILE_Q; + [[maybe_unused]] constexpr uint32_t CTA_TILE_KV = KTraits::CTA_TILE_KV; + [[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_ROW = KTraits::KV_THR_LAYOUT_ROW; + [[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_COL = KTraits::KV_THR_LAYOUT_COL; + [[maybe_unused]] constexpr MaskMode MASK_MODE = KTraits::MASK_MODE; + [[maybe_unused]] constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH; + + IdType* request_indices = params.request_indices; + IdType* qo_tile_indices = params.qo_tile_indices; + IdType* kv_tile_indices = params.kv_tile_indices; + DTypeQ* q = params.q; + IdType* q_indptr = params.q_indptr; + IdType* o_indptr = params.o_indptr; + DTypeO* o = params.o; + float* lse = params.lse; + bool* block_valid_mask = params.block_valid_mask; + const paged_kv_t& paged_kv = params.paged_kv; + const bool partition_kv = params.partition_kv; + const uint_fastdiv& group_size = params.group_size; + + static_assert(sizeof(DTypeQ) == 2); + auto block = cg::this_thread_block(); + const uint32_t kv_chunk_size = *(params.kv_chunk_size_ptr); + + const uint32_t lane_idx = tid.x, warp_idx = get_warp_idx(tid.y, tid.z); + if (block_valid_mask && !block_valid_mask[bx]) { + return; + } + const uint32_t num_qo_heads = num_kv_heads * group_size; + + const uint32_t request_idx = request_indices[bx], qo_tile_idx = qo_tile_indices[bx], + kv_tile_idx = kv_tile_indices[bx]; + auto smem = reinterpret_cast(&smem_storage); + AttentionVariant variant(params, 
/*batch_idx=*/request_idx, smem); + const uint32_t qo_len = variant.qo_len, kv_len = variant.kv_len, + window_left = variant.window_left; + const uint32_t kv_len_safe = kv_len > 0 ? kv_len : 1; + const uint32_t max_chunk_size = partition_kv ? kv_chunk_size : kv_len; + const uint32_t chunk_start = partition_kv ? kv_tile_idx * max_chunk_size : 0; + const uint32_t chunk_end = + partition_kv ? min((kv_tile_idx + 1) * max_chunk_size, kv_len) : kv_len; + const uint32_t chunk_size = chunk_end - chunk_start; + const uint32_t qo_upper_bound = + min(qo_len, ceil_div((qo_tile_idx + 1) * CTA_TILE_Q, group_size)); + + DTypeQKAccum s_frag[NUM_MMA_Q][NUM_MMA_KV][8]; + alignas(16) float o_frag[NUM_MMA_Q][NUM_MMA_D_VO][8]; + DTypeQKAccum m[NUM_MMA_Q][2]; + float d[NUM_MMA_Q][2]; + float rope_freq[NUM_MMA_D_QK / 2][4]; + + if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { + const float rope_rcp_scale = params.rope_rcp_scale; + const float rope_rcp_theta = params.rope_rcp_theta; + init_rope_freq(rope_freq, rope_rcp_scale, rope_rcp_theta, tid.x); + } + init_states(variant, o_frag, m, d); + + const uint32_t qo_packed_idx_base = + (qo_tile_idx * NUM_WARPS_Q + get_warp_idx_q(tid.y)) * NUM_MMA_Q * 16; + const uint32_t q_stride_n = params.q_stride_n, q_stride_h = params.q_stride_h; + smem_t qo_smem(smem_storage.q_smem); + const uint32_t o_stride_n = num_qo_heads * HEAD_DIM_VO, o_stride_h = HEAD_DIM_VO; + DTypeQ* q_ptr_base = + q + q_indptr[request_idx] * q_stride_n + (kv_head_idx * group_size) * q_stride_h; + DTypeO* o_ptr_base = partition_kv ? 
o + (o_indptr[request_idx] + kv_tile_idx) * o_stride_n + + (kv_head_idx * group_size) * o_stride_h + : o + o_indptr[request_idx] * o_stride_n + + (kv_head_idx * group_size) * o_stride_h; + uint32_t q_smem_offset_r = qo_smem.template get_permuted_offset( + get_warp_idx_q(tid.y) * NUM_MMA_Q * 16 + lane_idx % 16, lane_idx / 16); + + load_q_global_smem(qo_packed_idx_base, qo_upper_bound, q_ptr_base, q_stride_n, + q_stride_h, group_size, &qo_smem, tid); + + memory::commit_group(); + + if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { + memory::wait_group<0>(); + block.sync(); + IdType* q_rope_offset = nullptr; + if constexpr (has_maybe_q_rope_offset_v) { + q_rope_offset = params.maybe_q_rope_offset; + } + if (q_rope_offset == nullptr) { + q_smem_inplace_apply_rotary(qo_packed_idx_base, qo_len, kv_len, group_size, + &qo_smem, &q_smem_offset_r, rope_freq, tid); + } else { + q_smem_inplace_apply_rotary_with_pos( + qo_packed_idx_base, q_rope_offset + q_indptr[request_idx], &qo_smem, group_size, + &q_smem_offset_r, rope_freq, tid); + } + block.sync(); + } - smem_t k_smem(smem_storage.k_smem), - v_smem(smem_storage.v_smem); - size_t thr_local_kv_offset[NUM_MMA_KV * KV_THR_LAYOUT_COL / 2 / - NUM_WARPS_Q]; - - uint32_t k_smem_offset_r = - k_smem.template get_permuted_offset( - get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + - 8 * (lane_idx / 16) + lane_idx % 8, - (lane_idx % 16) / 8), - v_smem_offset_r = - v_smem.template get_permuted_offset( - get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + - lane_idx % 16, - lane_idx / 16), - k_smem_offset_w = - k_smem.template get_permuted_offset( - warp_idx * KV_THR_LAYOUT_ROW + - lane_idx / KV_THR_LAYOUT_COL, - lane_idx % KV_THR_LAYOUT_COL), - v_smem_offset_w = - v_smem.template get_permuted_offset( - warp_idx * KV_THR_LAYOUT_ROW + - lane_idx / KV_THR_LAYOUT_COL, - lane_idx % KV_THR_LAYOUT_COL); - const IdType last_indptr = paged_kv.indptr[paged_kv.batch_size]; - - uint32_t packed_page_iter_base = - 
paged_kv.indptr[request_idx] * paged_kv.page_size + chunk_start; -#pragma unroll - for (uint32_t i = 0; - i < NUM_MMA_KV * (SWIZZLE_MODE_KV == SwizzleMode::k128B ? 4 : 2) / - NUM_WARPS_Q; - ++i) - { - uint32_t page_iter, entry_idx; - paged_kv.page_size.divmod( - packed_page_iter_base + warp_idx * KV_THR_LAYOUT_ROW + - lane_idx / KV_THR_LAYOUT_COL + - KV_THR_LAYOUT_ROW * NUM_WARPS_Q * NUM_WARPS_KV * i, - page_iter, entry_idx); - thr_local_kv_offset[i] = paged_kv.protective_get_kv_offset( - page_iter, kv_head_idx, entry_idx, - (lane_idx % KV_THR_LAYOUT_COL) * - upcast_size(), - last_indptr); - } - page_produce_kv(k_smem, &k_smem_offset_w, paged_kv, 0, - thr_local_kv_offset, chunk_size, tid); - memory::commit_group(); - page_produce_kv(v_smem, &v_smem_offset_w, paged_kv, 0, - thr_local_kv_offset, chunk_size, tid); - memory::commit_group(); - - const uint32_t num_iterations = ceil_div( - (MASK_MODE == MaskMode::kCausal - ? min(chunk_size, - sub_if_greater_or_zero( - kv_len - qo_len + - ((qo_tile_idx + 1) * CTA_TILE_Q) / group_size, - chunk_start)) - : chunk_size), - CTA_TILE_KV); - - const uint32_t window_iteration = - ceil_div(sub_if_greater_or_zero( - kv_len + (qo_tile_idx + 1) * CTA_TILE_Q / group_size, - qo_len + window_left + chunk_start), - CTA_TILE_KV); - - const uint32_t mask_iteration = - (MASK_MODE == MaskMode::kCausal - ? 
min(chunk_size, - sub_if_greater_or_zero( - kv_len + (qo_tile_idx * CTA_TILE_Q) / group_size - - qo_len, - chunk_start)) - : chunk_size) / - CTA_TILE_KV; + smem_t k_smem(smem_storage.k_smem), v_smem(smem_storage.v_smem); + size_t thr_local_kv_offset[NUM_MMA_KV * KV_THR_LAYOUT_COL / 2 / NUM_WARPS_Q]; + + uint32_t k_smem_offset_r = k_smem.template get_permuted_offset( + get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + 8 * (lane_idx / 16) + + lane_idx % 8, + (lane_idx % 16) / 8), + v_smem_offset_r = v_smem.template get_permuted_offset( + get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + lane_idx % 16, lane_idx / 16), + k_smem_offset_w = k_smem.template get_permuted_offset( + warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL, + lane_idx % KV_THR_LAYOUT_COL), + v_smem_offset_w = v_smem.template get_permuted_offset( + warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL, + lane_idx % KV_THR_LAYOUT_COL); + const IdType last_indptr = paged_kv.indptr[paged_kv.batch_size]; + + uint32_t packed_page_iter_base = + paged_kv.indptr[request_idx] * paged_kv.page_size + chunk_start; +#pragma unroll + for (uint32_t i = 0; + i < NUM_MMA_KV * (SWIZZLE_MODE_KV == SwizzleMode::k128B ? 4 : 2) / NUM_WARPS_Q; ++i) { + uint32_t page_iter, entry_idx; + paged_kv.page_size.divmod(packed_page_iter_base + warp_idx * KV_THR_LAYOUT_ROW + + lane_idx / KV_THR_LAYOUT_COL + + KV_THR_LAYOUT_ROW * NUM_WARPS_Q * NUM_WARPS_KV * i, + page_iter, entry_idx); + thr_local_kv_offset[i] = paged_kv.protective_get_kv_offset( + page_iter, kv_head_idx, entry_idx, + (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(), last_indptr); + } + page_produce_kv(k_smem, &k_smem_offset_w, paged_kv, 0, thr_local_kv_offset, + chunk_size, tid); + memory::commit_group(); + page_produce_kv(v_smem, &v_smem_offset_w, paged_kv, 0, thr_local_kv_offset, + chunk_size, tid); + memory::commit_group(); + + const uint32_t num_iterations = ceil_div( + (MASK_MODE == MaskMode::kCausal + ? 
min(chunk_size, sub_if_greater_or_zero( + kv_len - qo_len + ((qo_tile_idx + 1) * CTA_TILE_Q) / group_size, + chunk_start)) + : chunk_size), + CTA_TILE_KV); + + const uint32_t window_iteration = + ceil_div(sub_if_greater_or_zero(kv_len + (qo_tile_idx + 1) * CTA_TILE_Q / group_size, + qo_len + window_left + chunk_start), + CTA_TILE_KV); + + const uint32_t mask_iteration = + (MASK_MODE == MaskMode::kCausal + ? min(chunk_size, + sub_if_greater_or_zero(kv_len + (qo_tile_idx * CTA_TILE_Q) / group_size - qo_len, + chunk_start)) + : chunk_size) / + CTA_TILE_KV; #pragma unroll 1 - for (uint32_t iter = 0; iter < num_iterations; ++iter) { - packed_page_iter_base += CTA_TILE_KV; -#pragma unroll - for (uint32_t i = 0; - i < NUM_MMA_KV * - (SWIZZLE_MODE_KV == SwizzleMode::k128B ? 4 : 2) / - NUM_WARPS_Q; - ++i) - { - uint32_t page_iter, entry_idx; - paged_kv.page_size.divmod( - packed_page_iter_base + warp_idx * KV_THR_LAYOUT_ROW + - lane_idx / KV_THR_LAYOUT_COL + - KV_THR_LAYOUT_ROW * NUM_WARPS_Q * NUM_WARPS_KV * i, - page_iter, entry_idx); - thr_local_kv_offset[i] = paged_kv.protective_get_kv_offset( - page_iter, kv_head_idx, entry_idx, - (lane_idx % KV_THR_LAYOUT_COL) * - upcast_size(), - last_indptr); - } - memory::wait_group<1>(); - block.sync(); - - if constexpr (KTraits::POS_ENCODING_MODE == - PosEncodingMode::kRoPELlama) - { - k_smem_inplace_apply_rotary( - (paged_kv.rope_pos_offset == nullptr - ? 0 - : paged_kv.rope_pos_offset[request_idx]) + - chunk_start + iter * CTA_TILE_KV, - &k_smem, &k_smem_offset_r, rope_freq, tid); - block.sync(); - } + for (uint32_t iter = 0; iter < num_iterations; ++iter) { + packed_page_iter_base += CTA_TILE_KV; +#pragma unroll + for (uint32_t i = 0; + i < NUM_MMA_KV * (SWIZZLE_MODE_KV == SwizzleMode::k128B ? 
4 : 2) / NUM_WARPS_Q; ++i) { + uint32_t page_iter, entry_idx; + paged_kv.page_size.divmod(packed_page_iter_base + warp_idx * KV_THR_LAYOUT_ROW + + lane_idx / KV_THR_LAYOUT_COL + + KV_THR_LAYOUT_ROW * NUM_WARPS_Q * NUM_WARPS_KV * i, + page_iter, entry_idx); + thr_local_kv_offset[i] = paged_kv.protective_get_kv_offset( + page_iter, kv_head_idx, entry_idx, + (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(), last_indptr); + } + memory::wait_group<1>(); + block.sync(); + + if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { + k_smem_inplace_apply_rotary( + (paged_kv.rope_pos_offset == nullptr ? 0 : paged_kv.rope_pos_offset[request_idx]) + + chunk_start + iter * CTA_TILE_KV, + &k_smem, &k_smem_offset_r, rope_freq, tid); + block.sync(); + } + + // compute attention score + compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); + + logits_transform( + params, variant, /*batch_idx=*/request_idx, qo_packed_idx_base, + chunk_start + (iter * NUM_WARPS_KV + get_warp_idx_kv(tid.z)) * NUM_MMA_KV * 16, + qo_len, kv_len, group_size, s_frag, tid, kv_head_idx); + + // apply mask + if (MASK_MODE == MaskMode::kCustom || (iter >= mask_iteration || iter < window_iteration)) { + logits_mask( + params, variant, /*batch_idx=*/request_idx, qo_packed_idx_base, + chunk_start + (iter * NUM_WARPS_KV + get_warp_idx_kv(tid.z)) * NUM_MMA_KV * 16, + qo_len, kv_len, chunk_end, group_size, s_frag, tid, kv_head_idx); + } + + // compute m,d states in online softmax + update_mdo_states(variant, s_frag, o_frag, m, d); + + block.sync(); + page_produce_kv(k_smem, &k_smem_offset_w, paged_kv, (iter + 1) * CTA_TILE_KV, + thr_local_kv_offset, chunk_size, tid); + memory::commit_group(); + memory::wait_group<1>(); + block.sync(); + + // compute sfm*v + compute_sfm_v(&v_smem, &v_smem_offset_r, s_frag, o_frag, d); + + block.sync(); + page_produce_kv(v_smem, &v_smem_offset_w, paged_kv, (iter + 1) * CTA_TILE_KV, + thr_local_kv_offset, chunk_size, tid); + 
memory::commit_group(); + } + memory::wait_group<0>(); + block.sync(); - // compute attention score - compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, - &k_smem_offset_r, s_frag); - - logits_transform( - params, variant, /*batch_idx=*/request_idx, qo_packed_idx_base, - chunk_start + - (iter * NUM_WARPS_KV + get_warp_idx_kv(tid.z)) * - NUM_MMA_KV * 16, - qo_len, kv_len, group_size, s_frag, tid, kv_head_idx); - - // apply mask - if (MASK_MODE == MaskMode::kCustom || - (iter >= mask_iteration || iter < window_iteration)) - { - logits_mask(params, variant, /*batch_idx=*/request_idx, - qo_packed_idx_base, - chunk_start + - (iter * NUM_WARPS_KV + - get_warp_idx_kv(tid.z)) * - NUM_MMA_KV * 16, - qo_len, kv_len, chunk_end, group_size, - s_frag, tid, kv_head_idx); - } + finalize_m(variant, m); - // compute m,d states in online softmax - update_mdo_states(variant, s_frag, o_frag, m, d); - - block.sync(); - page_produce_kv( - k_smem, &k_smem_offset_w, paged_kv, (iter + 1) * CTA_TILE_KV, - thr_local_kv_offset, chunk_size, tid); - memory::commit_group(); - memory::wait_group<1>(); - block.sync(); - - // compute sfm*v - compute_sfm_v(&v_smem, &v_smem_offset_r, s_frag, o_frag, - d); - - block.sync(); - page_produce_kv( - v_smem, &v_smem_offset_w, paged_kv, (iter + 1) * CTA_TILE_KV, - thr_local_kv_offset, chunk_size, tid); - memory::commit_group(); - } - memory::wait_group<0>(); - block.sync(); + // threadblock synchronization + threadblock_sync_mdo_states(o_frag, &smem_storage, m, d, warp_idx, lane_idx, tid); - finalize_m(variant, m); - - // threadblock synchronization - threadblock_sync_mdo_states(o_frag, &smem_storage, m, d, - warp_idx, lane_idx, tid); - - // normalize d - normalize_d(o_frag, m, d); - - const uint32_t num_kv_chunks = - (kv_len_safe + kv_chunk_size - 1) / kv_chunk_size; - - // write_back - write_o_reg_gmem( - o_frag, &qo_smem, o_ptr_base, qo_packed_idx_base, qo_len, - /*o_stride_n=*/ - partition_kv ? 
num_kv_chunks * o_stride_n : o_stride_n, - /*o_stride_h=*/o_stride_h, group_size, tid); - - // write lse - if constexpr (variant.use_softmax) { - if (lse != nullptr) { - if (get_warp_idx_kv(tid.z) == 0) { -#pragma unroll - for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { -#pragma unroll - for (uint32_t j = 0; j < 2; ++j) { - uint32_t q, r; - group_size.divmod(qo_packed_idx_base + - lane_idx / 4 + j * 8 + - mma_q * 16, - q, r); - const uint32_t qo_head_idx = - kv_head_idx * group_size + r; - const uint32_t qo_idx = q; - if (qo_idx < qo_upper_bound) { - if (partition_kv) { - lse[(o_indptr[request_idx] + - qo_idx * num_kv_chunks + kv_tile_idx) * - num_qo_heads + - qo_head_idx] = - gpu_iface::math::ptx_log2(d[mma_q][j]) + - float(m[mma_q][j]); - } - else { - lse[(o_indptr[request_idx] + qo_idx) * - num_qo_heads + - qo_head_idx] = - gpu_iface::math::ptx_log2(d[mma_q][j]) + - float(m[mma_q][j]); - } - } - } - } + // normalize d + normalize_d(o_frag, m, d); + + const uint32_t num_kv_chunks = (kv_len_safe + kv_chunk_size - 1) / kv_chunk_size; + + // write_back + write_o_reg_gmem(o_frag, &qo_smem, o_ptr_base, qo_packed_idx_base, qo_len, + /*o_stride_n=*/ + partition_kv ? 
num_kv_chunks * o_stride_n : o_stride_n, + /*o_stride_h=*/o_stride_h, group_size, tid); + + // write lse + if constexpr (variant.use_softmax) { + if (lse != nullptr) { + if (get_warp_idx_kv(tid.z) == 0) { +#pragma unroll + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + uint32_t q, r; + group_size.divmod(qo_packed_idx_base + lane_idx / 4 + j * 8 + mma_q * 16, q, r); + const uint32_t qo_head_idx = kv_head_idx * group_size + r; + const uint32_t qo_idx = q; + if (qo_idx < qo_upper_bound) { + if (partition_kv) { + lse[(o_indptr[request_idx] + qo_idx * num_kv_chunks + kv_tile_idx) * + num_qo_heads + + qo_head_idx] = gpu_iface::math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]); + } else { + lse[(o_indptr[request_idx] + qo_idx) * num_qo_heads + qo_head_idx] = + gpu_iface::math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]); } + } } + } } -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) + } } +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) + } #endif } template -__global__ -__launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithPagedKVCacheKernel( - const __grid_constant__ Params params) -{ - extern __shared__ uint8_t smem[]; - auto &smem_storage = - reinterpret_cast(smem); - BatchPrefillWithPagedKVCacheDevice(params, smem_storage); +__global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithPagedKVCacheKernel( + const __grid_constant__ Params params) { + extern __shared__ uint8_t smem[]; + auto& smem_storage = reinterpret_cast(smem); + BatchPrefillWithPagedKVCacheDevice(params, smem_storage); } -template -gpuError_t -BatchPrefillWithRaggedKVCacheDispatched(Params params, - typename Params::DTypeO *tmp_v, - float *tmp_s, - gpuStream_t stream) -{ - using DTypeQ = typename Params::DTypeQ; - using DTypeKV = typename Params::DTypeKV; - using DTypeO = typename Params::DTypeO; - const uint32_t padded_batch_size = params.padded_batch_size; - const uint32_t num_qo_heads = params.num_qo_heads; - 
const uint32_t num_kv_heads = params.num_kv_heads; - constexpr uint32_t NUM_MMA_Q = get_num_mma_q(CTA_TILE_Q); - constexpr uint32_t NUM_WARPS_Q = get_num_warps_q(CTA_TILE_Q); - constexpr uint32_t NUM_WARPS_KV = get_num_warps_kv(CTA_TILE_Q); - - if (padded_batch_size == 0) { - // No request, skip - // this won't happen in CUDAGraph mode because we fixed the - // padded_batch_size - return gpuSuccess; - } - - dim3 nblks(padded_batch_size, 1, num_kv_heads); - dim3 nthrs(WARP_SIZE, NUM_WARPS_Q, NUM_WARPS_KV); - constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM_QK / 16; - constexpr uint32_t NUM_MMA_D_VO = HEAD_DIM_VO / 16; - using DTypeQKAccum = - typename std::conditional, - half, float>::type; - - int dev_id = 0; - FI_GPU_CALL(gpuGetDevice(&dev_id)); - int max_smem_per_sm = getMaxSharedMemPerMultiprocessor(dev_id); - // we expect each sm execute two threadblocks - const int num_ctas_per_sm = - max_smem_per_sm >= 2 * (CTA_TILE_Q * HEAD_DIM_QK * sizeof(DTypeQ) + - (HEAD_DIM_QK + HEAD_DIM_VO) * 16 * - NUM_WARPS_KV * sizeof(DTypeKV)) - ? 2 - : 1; - const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm; - - const uint32_t max_num_mma_kv_reg = - (HEAD_DIM_VO >= 128 && NUM_MMA_Q == 2 && - POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && - !USE_FP16_QK_REDUCTION) - ? 
2 - : (8 / NUM_MMA_Q); - const uint32_t max_num_mma_kv_smem = - (max_smem_per_threadblock - CTA_TILE_Q * HEAD_DIM_QK * sizeof(DTypeQ)) / - ((HEAD_DIM_QK + HEAD_DIM_VO) * 16 * NUM_WARPS_KV * sizeof(DTypeKV)); - - DISPATCH_NUM_MMA_KV( - min(max_num_mma_kv_smem, max_num_mma_kv_reg), NUM_MMA_KV, { - using KTraits = - KernelTraits; - if constexpr (KTraits::IsInvalid()) { - // Invalid configuration, skip - std::ostringstream err_msg; - err_msg - << "FlashInfer Internal Error: Invalid configuration : " - "NUM_MMA_Q=" - << NUM_MMA_Q << " NUM_MMA_D_QK=" << NUM_MMA_D_QK - << " NUM_MMA_D_VO=" << NUM_MMA_D_VO - << " NUM_MMA_KV=" << NUM_MMA_KV - << " NUM_WARPS_Q=" << NUM_WARPS_Q - << " NUM_WARPS_KV=" << NUM_WARPS_KV - << " please create an issue " - "(https://github.com/flashinfer-ai/flashinfer/issues)" - " and report the issue to the developers."; - FLASHINFER_ERROR(err_msg.str()); - } - else { - size_t smem_size = sizeof(typename KTraits::SharedStorage); - auto kernel = - BatchPrefillWithRaggedKVCacheKernel; - FI_GPU_CALL(gpuFuncSetAttribute( - kernel, gpuFuncAttributeMaxDynamicSharedMemorySize, - smem_size)); - if (tmp_v == nullptr) { - // do not partition kv - params.partition_kv = false; - void *args[] = {(void *)¶ms}; - FI_GPU_CALL(gpuLaunchKernel((void *)kernel, nblks, nthrs, - args, smem_size, stream)); - } - else { - // partition kv - params.partition_kv = true; - auto o = params.o; - auto lse = params.lse; - params.o = tmp_v; - params.lse = tmp_s; - void *args[] = {(void *)¶ms}; - FI_GPU_CALL(gpuLaunchKernel((void *)kernel, nblks, nthrs, - args, smem_size, stream)); - if constexpr (AttentionVariant::use_softmax) { - FI_GPU_CALL(VariableLengthMergeStates( - tmp_v, tmp_s, params.merge_indptr, o, lse, - params.max_total_num_rows, params.total_num_rows, - num_qo_heads, HEAD_DIM_VO, stream)); - } - else { - FI_GPU_CALL(VariableLengthAttentionSum( - tmp_v, params.merge_indptr, o, - params.max_total_num_rows, params.total_num_rows, - num_qo_heads, HEAD_DIM_VO, stream)); - } 
- } - } - }); +template +gpuError_t BatchPrefillWithRaggedKVCacheDispatched(Params params, typename Params::DTypeO* tmp_v, + float* tmp_s, gpuStream_t stream) { + using DTypeQ = typename Params::DTypeQ; + using DTypeKV = typename Params::DTypeKV; + using DTypeO = typename Params::DTypeO; + const uint32_t padded_batch_size = params.padded_batch_size; + const uint32_t num_qo_heads = params.num_qo_heads; + const uint32_t num_kv_heads = params.num_kv_heads; + constexpr uint32_t NUM_MMA_Q = get_num_mma_q(CTA_TILE_Q); + constexpr uint32_t NUM_WARPS_Q = get_num_warps_q(CTA_TILE_Q); + constexpr uint32_t NUM_WARPS_KV = get_num_warps_kv(CTA_TILE_Q); + + if (padded_batch_size == 0) { + // No request, skip + // this won't happen in CUDAGraph mode because we fixed the + // padded_batch_size return gpuSuccess; -} - -template -gpuError_t -BatchPrefillWithPagedKVCacheDispatched(Params params, - typename Params::DTypeO *tmp_v, - float *tmp_s, - gpuStream_t stream) -{ - using DTypeQ = typename Params::DTypeQ; - using DTypeKV = typename Params::DTypeKV; - using DTypeO = typename Params::DTypeO; - const uint32_t padded_batch_size = params.padded_batch_size; - const uint32_t num_qo_heads = params.num_qo_heads; - const uint32_t num_kv_heads = params.paged_kv.num_heads; - constexpr uint32_t NUM_MMA_Q = get_num_mma_q(CTA_TILE_Q); - constexpr uint32_t NUM_WARPS_Q = get_num_warps_q(CTA_TILE_Q); - constexpr uint32_t NUM_WARPS_KV = get_num_warps_kv(CTA_TILE_Q); - - if (padded_batch_size == 0) { - // No request, skip - // this won't happen in CUDAGraph mode because we fixed the - // padded_batch_size - return gpuSuccess; + } + + dim3 nblks(padded_batch_size, 1, num_kv_heads); + dim3 nthrs(WARP_SIZE, NUM_WARPS_Q, NUM_WARPS_KV); + constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM_QK / 16; + constexpr uint32_t NUM_MMA_D_VO = HEAD_DIM_VO / 16; + using DTypeQKAccum = + typename std::conditional, half, + float>::type; + + int dev_id = 0; + FI_GPU_CALL(gpuGetDevice(&dev_id)); + int max_smem_per_sm = 
getMaxSharedMemPerMultiprocessor(dev_id); + // we expect each sm execute two threadblocks + const int num_ctas_per_sm = + max_smem_per_sm >= 2 * (CTA_TILE_Q * HEAD_DIM_QK * sizeof(DTypeQ) + + (HEAD_DIM_QK + HEAD_DIM_VO) * 16 * NUM_WARPS_KV * sizeof(DTypeKV)) + ? 2 + : 1; + const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm; + + const uint32_t max_num_mma_kv_reg = + (HEAD_DIM_VO >= 128 && NUM_MMA_Q == 2 && POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && + !USE_FP16_QK_REDUCTION) + ? 2 + : (8 / NUM_MMA_Q); + const uint32_t max_num_mma_kv_smem = + (max_smem_per_threadblock - CTA_TILE_Q * HEAD_DIM_QK * sizeof(DTypeQ)) / + ((HEAD_DIM_QK + HEAD_DIM_VO) * 16 * NUM_WARPS_KV * sizeof(DTypeKV)); + + DISPATCH_NUM_MMA_KV(min(max_num_mma_kv_smem, max_num_mma_kv_reg), NUM_MMA_KV, { + using KTraits = + KernelTraits; + if constexpr (KTraits::IsInvalid()) { + // Invalid configuration, skip + std::ostringstream err_msg; + err_msg << "FlashInfer Internal Error: Invalid configuration : " + "NUM_MMA_Q=" + << NUM_MMA_Q << " NUM_MMA_D_QK=" << NUM_MMA_D_QK << " NUM_MMA_D_VO=" << NUM_MMA_D_VO + << " NUM_MMA_KV=" << NUM_MMA_KV << " NUM_WARPS_Q=" << NUM_WARPS_Q + << " NUM_WARPS_KV=" << NUM_WARPS_KV + << " please create an issue " + "(https://github.com/flashinfer-ai/flashinfer/issues)" + " and report the issue to the developers."; + FLASHINFER_ERROR(err_msg.str()); + } else { + size_t smem_size = sizeof(typename KTraits::SharedStorage); + auto kernel = BatchPrefillWithRaggedKVCacheKernel; + FI_GPU_CALL( + gpuFuncSetAttribute(kernel, gpuFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + if (tmp_v == nullptr) { + // do not partition kv + params.partition_kv = false; + void* args[] = {(void*)¶ms}; + FI_GPU_CALL(gpuLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + } else { + // partition kv + params.partition_kv = true; + auto o = params.o; + auto lse = params.lse; + params.o = tmp_v; + params.lse = tmp_s; + void* args[] = {(void*)¶ms}; + 
FI_GPU_CALL(gpuLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + if constexpr (AttentionVariant::use_softmax) { + FI_GPU_CALL(VariableLengthMergeStates(tmp_v, tmp_s, params.merge_indptr, o, lse, + params.max_total_num_rows, params.total_num_rows, + num_qo_heads, HEAD_DIM_VO, stream)); + } else { + FI_GPU_CALL(VariableLengthAttentionSum(tmp_v, params.merge_indptr, o, + params.max_total_num_rows, params.total_num_rows, + num_qo_heads, HEAD_DIM_VO, stream)); + } + } } + }); + return gpuSuccess; +} - dim3 nblks(padded_batch_size, 1, num_kv_heads); - dim3 nthrs(WARP_SIZE, NUM_WARPS_Q, NUM_WARPS_KV); - - constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM_QK / 16; - constexpr uint32_t NUM_MMA_D_VO = HEAD_DIM_VO / 16; - using DTypeQKAccum = - typename std::conditional, - half, float>::type; - - int dev_id = 0; - FI_GPU_CALL(gpuGetDevice(&dev_id)); - int max_smem_per_sm = getMaxSharedMemPerMultiprocessor(dev_id); - // we expect each sm execute two threadblocks - const int num_ctas_per_sm = - max_smem_per_sm >= 2 * (CTA_TILE_Q * HEAD_DIM_QK * sizeof(DTypeQ) + - (HEAD_DIM_QK + HEAD_DIM_VO) * 16 * - NUM_WARPS_KV * sizeof(DTypeKV)) - ? 2 - : 1; - const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm; - - const uint32_t max_num_mma_kv_reg = - (HEAD_DIM_VO >= 128 && NUM_MMA_Q == 2 && - POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && - !USE_FP16_QK_REDUCTION) - ? 
2 - : (8 / NUM_MMA_Q); - const uint32_t max_num_mma_kv_smem = - (max_smem_per_threadblock - CTA_TILE_Q * HEAD_DIM_QK * sizeof(DTypeQ)) / - ((HEAD_DIM_QK + HEAD_DIM_VO) * 16 * NUM_WARPS_KV * sizeof(DTypeKV)); - - DISPATCH_NUM_MMA_KV( - min(max_num_mma_kv_smem, max_num_mma_kv_reg), NUM_MMA_KV, { - using KTraits = - KernelTraits; - if constexpr (KTraits::IsInvalid()) { - // Invalid configuration, skip - std::ostringstream err_msg; - err_msg - << "FlashInfer Internal Error: Invalid configuration : " - "NUM_MMA_Q=" - << NUM_MMA_Q << " NUM_MMA_D_QK=" << NUM_MMA_D_QK - << " NUM_MMA_D_VO=" << NUM_MMA_D_VO - << " NUM_MMA_KV=" << NUM_MMA_KV - << " NUM_WARPS_Q=" << NUM_WARPS_Q - << " NUM_WARPS_KV=" << NUM_WARPS_KV - << " please create an issue " - "(https://github.com/flashinfer-ai/flashinfer/issues)" - " and report the issue to the developers."; - FLASHINFER_ERROR(err_msg.str()); - } - else { - size_t smem_size = sizeof(typename KTraits::SharedStorage); - auto kernel = - BatchPrefillWithPagedKVCacheKernel; - FI_GPU_CALL(gpuFuncSetAttribute( - kernel, gpuFuncAttributeMaxDynamicSharedMemorySize, - smem_size)); - if (tmp_v == nullptr) { - // do not partition kv - params.partition_kv = false; - void *args[] = {(void *)¶ms}; - FI_GPU_CALL(gpuLaunchKernel((void *)kernel, nblks, nthrs, - args, smem_size, stream)); - } - else { - params.partition_kv = true; - auto o = params.o; - auto lse = params.lse; - params.o = tmp_v; - params.lse = tmp_s; - void *args[] = {(void *)¶ms}; - FI_GPU_CALL(gpuLaunchKernel((void *)kernel, nblks, nthrs, - args, smem_size, stream)); - if constexpr (AttentionVariant::use_softmax) { - FI_GPU_CALL(VariableLengthMergeStates( - tmp_v, tmp_s, params.merge_indptr, o, lse, - params.max_total_num_rows, params.total_num_rows, - num_qo_heads, HEAD_DIM_VO, stream)); - } - else { - FI_GPU_CALL(VariableLengthAttentionSum( - tmp_v, params.merge_indptr, o, - params.max_total_num_rows, params.total_num_rows, - num_qo_heads, HEAD_DIM_VO, stream)); - } - } - } - }); 
+template +gpuError_t BatchPrefillWithPagedKVCacheDispatched(Params params, typename Params::DTypeO* tmp_v, + float* tmp_s, gpuStream_t stream) { + using DTypeQ = typename Params::DTypeQ; + using DTypeKV = typename Params::DTypeKV; + using DTypeO = typename Params::DTypeO; + const uint32_t padded_batch_size = params.padded_batch_size; + const uint32_t num_qo_heads = params.num_qo_heads; + const uint32_t num_kv_heads = params.paged_kv.num_heads; + constexpr uint32_t NUM_MMA_Q = get_num_mma_q(CTA_TILE_Q); + constexpr uint32_t NUM_WARPS_Q = get_num_warps_q(CTA_TILE_Q); + constexpr uint32_t NUM_WARPS_KV = get_num_warps_kv(CTA_TILE_Q); + + if (padded_batch_size == 0) { + // No request, skip + // this won't happen in CUDAGraph mode because we fixed the + // padded_batch_size return gpuSuccess; + } + + dim3 nblks(padded_batch_size, 1, num_kv_heads); + dim3 nthrs(WARP_SIZE, NUM_WARPS_Q, NUM_WARPS_KV); + + constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM_QK / 16; + constexpr uint32_t NUM_MMA_D_VO = HEAD_DIM_VO / 16; + using DTypeQKAccum = + typename std::conditional, half, + float>::type; + + int dev_id = 0; + FI_GPU_CALL(gpuGetDevice(&dev_id)); + int max_smem_per_sm = getMaxSharedMemPerMultiprocessor(dev_id); + // we expect each sm execute two threadblocks + const int num_ctas_per_sm = + max_smem_per_sm >= 2 * (CTA_TILE_Q * HEAD_DIM_QK * sizeof(DTypeQ) + + (HEAD_DIM_QK + HEAD_DIM_VO) * 16 * NUM_WARPS_KV * sizeof(DTypeKV)) + ? 2 + : 1; + const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm; + + const uint32_t max_num_mma_kv_reg = + (HEAD_DIM_VO >= 128 && NUM_MMA_Q == 2 && POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && + !USE_FP16_QK_REDUCTION) + ? 
2 + : (8 / NUM_MMA_Q); + const uint32_t max_num_mma_kv_smem = + (max_smem_per_threadblock - CTA_TILE_Q * HEAD_DIM_QK * sizeof(DTypeQ)) / + ((HEAD_DIM_QK + HEAD_DIM_VO) * 16 * NUM_WARPS_KV * sizeof(DTypeKV)); + + DISPATCH_NUM_MMA_KV(min(max_num_mma_kv_smem, max_num_mma_kv_reg), NUM_MMA_KV, { + using KTraits = + KernelTraits; + if constexpr (KTraits::IsInvalid()) { + // Invalid configuration, skip + std::ostringstream err_msg; + err_msg << "FlashInfer Internal Error: Invalid configuration : " + "NUM_MMA_Q=" + << NUM_MMA_Q << " NUM_MMA_D_QK=" << NUM_MMA_D_QK << " NUM_MMA_D_VO=" << NUM_MMA_D_VO + << " NUM_MMA_KV=" << NUM_MMA_KV << " NUM_WARPS_Q=" << NUM_WARPS_Q + << " NUM_WARPS_KV=" << NUM_WARPS_KV + << " please create an issue " + "(https://github.com/flashinfer-ai/flashinfer/issues)" + " and report the issue to the developers."; + FLASHINFER_ERROR(err_msg.str()); + } else { + size_t smem_size = sizeof(typename KTraits::SharedStorage); + auto kernel = BatchPrefillWithPagedKVCacheKernel; + FI_GPU_CALL( + gpuFuncSetAttribute(kernel, gpuFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + if (tmp_v == nullptr) { + // do not partition kv + params.partition_kv = false; + void* args[] = {(void*)¶ms}; + FI_GPU_CALL(gpuLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + } else { + params.partition_kv = true; + auto o = params.o; + auto lse = params.lse; + params.o = tmp_v; + params.lse = tmp_s; + void* args[] = {(void*)¶ms}; + FI_GPU_CALL(gpuLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + if constexpr (AttentionVariant::use_softmax) { + FI_GPU_CALL(VariableLengthMergeStates(tmp_v, tmp_s, params.merge_indptr, o, lse, + params.max_total_num_rows, params.total_num_rows, + num_qo_heads, HEAD_DIM_VO, stream)); + } else { + FI_GPU_CALL(VariableLengthAttentionSum(tmp_v, params.merge_indptr, o, + params.max_total_num_rows, params.total_num_rows, + num_qo_heads, HEAD_DIM_VO, stream)); + } + } + } + }); + return gpuSuccess; } -} // namespace 
flashinfer +} // namespace flashinfer -#endif // FLASHINFER_PREFILL_CUH_ +#endif // FLASHINFER_PREFILL_CUH_ From 3e0fa2a44e3f02466978dbfb7cb315e447bb7f17 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Tue, 16 Sep 2025 15:16:51 -0400 Subject: [PATCH 086/109] Silence warnings --- libflashinfer/include/flashinfer/attention/generic/page.cuh | 1 - .../include/flashinfer/attention/generic/prefill.cuh | 3 --- libflashinfer/utils/cpu_reference_hip.h | 4 +--- 3 files changed, 1 insertion(+), 7 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/page.cuh b/libflashinfer/include/flashinfer/attention/generic/page.cuh index 28ed38bebc..14df54be80 100644 --- a/libflashinfer/include/flashinfer/attention/generic/page.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/page.cuh @@ -262,7 +262,6 @@ __global__ void AppendPagedKVCacheKernel(paged_kv_t paged_kv, size_t append_k_stride_n, size_t append_k_stride_h, size_t append_v_stride_n, size_t append_v_stride_h) { uint32_t tx = threadIdx.x, ty = threadIdx.y; - uint32_t num_heads = paged_kv.num_heads; uint32_t head_idx = ty; uint32_t cta_id = blockIdx.x; uint32_t num_ctas = gridDim.x; diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index d9df683e80..86941c687c 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -384,7 +384,6 @@ __device__ __forceinline__ void produce_kv_impl_cdna3_( static_assert(KTraits::SWIZZLE_MODE_KV == SwizzleMode::kLinear); using DTypeKV = typename KTraits::DTypeKV; constexpr uint32_t KV_THR_LAYOUT_COL = KTraits::KV_THR_LAYOUT_COL; // 16 - constexpr uint32_t KV_THR_LAYOUT_ROW = KTraits::KV_THR_LAYOUT_ROW; // 4 constexpr uint32_t NUM_WARPS = KTraits::NUM_WARPS; constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q; @@ -392,7 +391,6 @@ 
__device__ __forceinline__ void produce_kv_impl_cdna3_( constexpr uint32_t UPCAST_STRIDE = produce_v ? KTraits::UPCAST_STRIDE_V : KTraits::UPCAST_STRIDE_K; constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH; - constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; // NOTE: NUM_MMA_KV*4/NUM_WARPS_Q = NUM_WARPS_KV*NUM_MMA_KV*4/num_warps static_assert(NUM_MMA_KV * 4 % NUM_WARPS_Q == 0); @@ -1594,7 +1592,6 @@ template __device__ __forceinline__ void debug_write_sfrag_to_scratch( typename KTraits::DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD], const dim3 tid = threadIdx, uint32_t debug_thread_id = 0, uint32_t debug_warp_id = 0) { - using DTypeQKAccum = typename KTraits::DTypeQKAccum; constexpr uint32_t NUM_MMA_Q = KTraits::NUM_MMA_Q; constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; const uint32_t warp_idx = get_warp_idx(tid.y, tid.z), lane_idx = tid.x; diff --git a/libflashinfer/utils/cpu_reference_hip.h b/libflashinfer/utils/cpu_reference_hip.h index cc335f1cf6..f502ded1f3 100644 --- a/libflashinfer/utils/cpu_reference_hip.h +++ b/libflashinfer/utils/cpu_reference_hip.h @@ -67,7 +67,7 @@ inline std::vector apply_llama_rope(const T* input, size_t D, size_t offs if (std::is_same_v) rst[k] = cos * fi::con::explicit_casting(input[k]) + sin * permuted_input[k]; } - return std::move(rst); + return rst; } template @@ -87,8 +87,6 @@ std::vector single_mha(const std::vector& q, const std::vect std::vector q_rotary_local(head_dim); std::vector k_rotary_local(head_dim); - float soft_cap_pre_tanh_scale = sm_scale / logits_soft_cap; - DISPATCH_head_dim(head_dim, HEAD_DIM, { tensor_info_t info(qo_len, kv_len, num_qo_heads, num_kv_heads, kv_layout, HEAD_DIM); #if Debug1 From 4c8e5745b83c04ebe8afadd91f4bbed80183b370 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Wed, 17 Sep 2025 02:20:02 -0400 Subject: [PATCH 087/109] Remove redundant files --- .../attention/generic/prefill_tester.cuh | 2234 ----------------- 
libflashinfer/utils/compute_qk_stub.cuh | 332 --- 2 files changed, 2566 deletions(-) delete mode 100644 libflashinfer/include/flashinfer/attention/generic/prefill_tester.cuh delete mode 100644 libflashinfer/utils/compute_qk_stub.cuh diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill_tester.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill_tester.cuh deleted file mode 100644 index 93903d0d4a..0000000000 --- a/libflashinfer/include/flashinfer/attention/generic/prefill_tester.cuh +++ /dev/null @@ -1,2234 +0,0 @@ -// SPDX - FileCopyrightText : 2023-2025 FlashInfer team. -// SPDX - FileCopyrightText : 2025 Advanced Micro Devices, Inc. -// -// SPDX - License - Identifier : Apache - 2.0 -#ifndef FLASHINFER_PREFILL_CUH_ -#define FLASHINFER_PREFILL_CUH_ - -#include "gpu_iface/cooperative_groups.h" -#include "gpu_iface/fastdiv.cuh" -#include "gpu_iface/math_ops.hpp" -#include "gpu_iface/memory_ops.hpp" -#include "gpu_iface/mma_ops.hpp" -#include "gpu_iface/platform.hpp" -#include "gpu_iface/utils.cuh" - -#ifdef FP16_QK_REDUCTION_SUPPORTED -#include "../../fp16.h" -#endif -#include "frag_layout_swizzle.cuh" - -#include "cascade.cuh" -#include "dispatch.cuh" -#include "page.cuh" -#include "permuted_smem.cuh" -#include "pos_enc.cuh" -#include "variants.cuh" - -namespace flashinfer -{ - -DEFINE_HAS_MEMBER(maybe_q_rope_offset) -DEFINE_HAS_MEMBER(maybe_k_rope_offset) - -namespace cg = flashinfer::gpu_iface::cg; -namespace memory = flashinfer::gpu_iface::memory; -namespace mma = gpu_iface::mma; - -using gpu_iface::vec_dtypes::vec_cast; -using mma::MMAMode; - -constexpr uint32_t WARP_SIZE = gpu_iface::kWarpSize; - -constexpr uint32_t get_num_warps_q(const uint32_t cta_tile_q) -{ - if (cta_tile_q > 16) { - return 4; - } - else { - return 1; - } -} - -constexpr uint32_t get_num_warps_kv(const uint32_t cta_tile_kv) -{ - return 4 / get_num_warps_q(cta_tile_kv); -} - -constexpr uint32_t get_num_mma_q(const uint32_t cta_tile_q) -{ - if (cta_tile_q > 64) 
// NOTE(review): the template parameter lists in this region were stripped by
// the patch mangling; they are reconstructed below to match the upstream
// FlashInfer layout. Verify against the original prefill.cuh before merging.

/*!
 * \brief Shared-memory layout for one CTA.
 *
 * The Q/K/V staging tiles, the cross-warp softmax-state exchange buffers and
 * the output staging tile are overlapped in a union because their lifetimes
 * do not intersect within the kernel.
 */
template <uint32_t NUM_WARPS_KV, uint32_t CTA_TILE_Q, uint32_t CTA_TILE_KV,
          uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_VO, typename DTypeQ,
          typename DTypeKV, typename DTypeO>
struct SharedStorageQKVO
{
    union
    {
        struct
        {
            alignas(16) DTypeQ q_smem[CTA_TILE_Q * HEAD_DIM_QK];
            alignas(16) DTypeKV k_smem[CTA_TILE_KV * HEAD_DIM_QK];
            alignas(16) DTypeKV v_smem[CTA_TILE_KV * HEAD_DIM_VO];
        };
        struct
        { // NOTE(Zihao): synchronize attention states across warps
            // Degenerate to 1-element arrays when NUM_WARPS_KV == 1 so this
            // union member does not inflate shared-memory usage for
            // configurations that never perform cross-warp synchronization.
            alignas(16) std::conditional_t<
                NUM_WARPS_KV == 1,
                float[1],
                float[NUM_WARPS_KV * CTA_TILE_Q * HEAD_DIM_VO]> cta_sync_o_smem;
            alignas(16) std::conditional_t<
                NUM_WARPS_KV == 1,
                float2[1],
                float2[NUM_WARPS_KV * CTA_TILE_Q]> cta_sync_md_smem;
        };
        alignas(16) DTypeO smem_o[CTA_TILE_Q * HEAD_DIM_VO];
    };
};

/*!
 * \brief Compile-time configuration bundle for the prefill kernel (CDNA3
 *        wavefront-64 variant: 16x16x16 MFMA tiles, 16x4 thread layout).
 */
template <MaskMode MASK_MODE_, uint32_t CTA_TILE_Q_, uint32_t NUM_MMA_Q_,
          uint32_t NUM_MMA_KV_, uint32_t NUM_MMA_D_QK_, uint32_t NUM_MMA_D_VO_,
          uint32_t NUM_WARPS_Q_, uint32_t NUM_WARPS_KV_,
          PosEncodingMode POS_ENCODING_MODE_, typename DTypeQ_,
          typename DTypeKV_, typename DTypeO_, typename DTypeQKAccum_,
          typename IdType_, typename AttentionVariant_>
struct KernelTraits
{
    static constexpr MaskMode MASK_MODE = MASK_MODE_;
    static constexpr uint32_t NUM_MMA_Q = NUM_MMA_Q_;
    static constexpr uint32_t NUM_MMA_KV = NUM_MMA_KV_;
    static constexpr uint32_t NUM_MMA_D_QK = NUM_MMA_D_QK_;
    static constexpr uint32_t NUM_MMA_D_VO = NUM_MMA_D_VO_;
    static constexpr uint32_t NUM_WARPS_Q = NUM_WARPS_Q_;
    static constexpr uint32_t NUM_WARPS_KV = NUM_WARPS_KV_;
    static constexpr uint32_t NUM_WARPS = NUM_WARPS_Q * NUM_WARPS_KV;
    // Each MMA tile covers 16 rows/columns.
    static constexpr uint32_t HEAD_DIM_QK = NUM_MMA_D_QK * 16;
    static constexpr uint32_t HEAD_DIM_VO = NUM_MMA_D_VO * 16;
    static constexpr uint32_t CTA_TILE_Q = CTA_TILE_Q_;
    static constexpr uint32_t CTA_TILE_KV = NUM_MMA_KV * NUM_WARPS_KV * 16;
    static constexpr PosEncodingMode POS_ENCODING_MODE = POS_ENCODING_MODE_;

    using DTypeQ = DTypeQ_;
    using DTypeKV = DTypeKV_;
    using DTypeO = DTypeO_;
    using DTypeQKAccum = DTypeQKAccum_;
    using IdType = IdType_;
    using AttentionVariant = AttentionVariant_;

    static_assert(sizeof(DTypeKV_) != 1, "8-bit types not supported for CDNA3");

    using SmemBasePtrTy = uint2;
    // CDNA3 wavefront size is 64 threads.
    static constexpr uint32_t NUM_THREADS = NUM_WARPS_Q * NUM_WARPS_KV * 64;
    static constexpr uint32_t WARP_THREAD_ROWS = 4;
    static constexpr uint32_t WARP_THREAD_COLS = 16;
    static constexpr uint32_t HALF_ELEMS_PER_THREAD = 4;
    static constexpr uint32_t INT32_ELEMS_PER_THREAD = 2;
    static constexpr uint32_t VECTOR_BIT_WIDTH = HALF_ELEMS_PER_THREAD * 16;
    // FIXME: Update with a proper swizzle pattern. Linear is used primarily
    // for initial testing.
    static constexpr SwizzleMode SWIZZLE_MODE_Q = SwizzleMode::kLinear;
    static constexpr SwizzleMode SWIZZLE_MODE_KV = SwizzleMode::kLinear;

    // Presently we use 16x4 thread layout for all cases.
    static constexpr uint32_t KV_THR_LAYOUT_ROW = WARP_THREAD_ROWS;
    static constexpr uint32_t KV_THR_LAYOUT_COL = WARP_THREAD_COLS;
    // The constant is defined based on the matrix layout of the "D/C"
    // accumulator matrix in a D = A*B+C computation. On CDNA3 the D/C matrices
    // are distributed as four 4x16 bands across the 64 threads. Each thread
    // owns one element from four different rows.
    static constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD = 4;
    // Number of threads that collaboratively handle the same set of matrix
    // rows in attention score computation and cross-warp synchronization.
    // CUDA: 4 threads (each thread handles 2 elements from same row group)
    // CDNA3: 16 threads (each thread handles 1 element from same row group)
    static constexpr uint32_t THREADS_PER_MATRIX_ROW_SET = 16;
    // Controls the indexing stride used in logits-related functions
    // (logits_transform, logits_mask, and LSE writing).
    static constexpr uint32_t LOGITS_INDEX_STRIDE = 4;

    // Row strides of the staging tiles expressed in "upcast" (vector) units.
    static constexpr uint32_t UPCAST_STRIDE_Q =
        HEAD_DIM_QK / upcast_size<DTypeQ>();
    static constexpr uint32_t UPCAST_STRIDE_K =
        HEAD_DIM_QK / upcast_size<DTypeKV>();
    static constexpr uint32_t UPCAST_STRIDE_V =
        HEAD_DIM_VO / upcast_size<DTypeKV>();
    static constexpr uint32_t UPCAST_STRIDE_O =
        HEAD_DIM_VO / upcast_size<DTypeO>();

    // Rejects tile configurations that exceed the register budget or that the
    // current implementation cannot schedule.
    static constexpr bool IsInvalid()
    {
        return ((NUM_MMA_D_VO < 4) ||
                (NUM_MMA_D_VO == 4 && NUM_MMA_KV % 2 == 1) ||
                (POS_ENCODING_MODE == PosEncodingMode::kRoPELlama &&
                 NUM_MMA_D_VO > 4 && NUM_MMA_D_VO % (2 * NUM_WARPS_Q) != 0) ||
                (NUM_MMA_Q * (8 * NUM_MMA_D_VO +
                              2 * sizeof(DTypeQKAccum) * NUM_MMA_KV) >=
                 256) ||
                (sizeof(DTypeKV) == 1 && NUM_MMA_KV * 2 % NUM_WARPS_Q != 0) ||
                (sizeof(DTypeKV) == 1 &&
                 POS_ENCODING_MODE == PosEncodingMode::kRoPELlama));
    }

    using SharedStorage =
        SharedStorageQKVO<NUM_WARPS_KV, CTA_TILE_Q, CTA_TILE_KV, HEAD_DIM_QK,
                          HEAD_DIM_VO, DTypeQ, DTypeKV, DTypeO>;
#ifdef FP16_QK_REDUCTION_SUPPORTED
    // Bit-exact -inf for the accumulator type; fp16 needs an explicit
    // round-trip through the soft-float helpers to stay constexpr.
    template <typename DT> static constexpr DT getNegInf()
    {
        if constexpr (std::is_same<DT, half>::value) {
            return std::bit_cast<DT>(
                fp16_ieee_from_fp32_value(-gpu_iface::math::inf));
        }
        else {
            return static_cast<DT>(-gpu_iface::math::inf);
        }
    }

    static constexpr DTypeQKAccum MaskFillValue =
        AttentionVariant::use_softmax ? getNegInf<DTypeQKAccum>()
                                      : DTypeQKAccum(0.f);
#else
    static_assert(!std::is_same<DTypeQKAccum_, half>::value,
                  "Set -DFP16_QK_REDUCTION_SUPPORTED and install boost_math "
                  "then recompile to support fp16 reduction");
    static constexpr DTypeQKAccum MaskFillValue =
        AttentionVariant::use_softmax ? DTypeQKAccum(-gpu_iface::math::inf)
                                      : DTypeQKAccum(0.f);
#endif
};

namespace
{

// Index of this warp along the Q tile dimension (0 when a single Q warp).
template <typename KTraits>
__device__ __forceinline__ uint32_t
get_warp_idx_q(const uint32_t tid_y = threadIdx.y)
{
    if constexpr (KTraits::NUM_WARPS_Q == 1) {
        return 0;
    }
    else {
        return tid_y;
    }
}

// Index of this warp along the KV tile dimension (0 when a single KV warp).
template <typename KTraits>
__device__ __forceinline__ uint32_t
get_warp_idx_kv(const uint32_t tid_z = threadIdx.z)
{
    if constexpr (KTraits::NUM_WARPS_KV == 1) {
        return 0;
    }
    else {
        return tid_z;
    }
}

// Flattened warp index in row-major (kv, q) order.
template <typename KTraits>
__device__ __forceinline__ uint32_t
get_warp_idx(const uint32_t tid_y = threadIdx.y,
             const uint32_t tid_z = threadIdx.z)
{
    return get_warp_idx_kv<KTraits>(tid_z) * KTraits::NUM_WARPS_Q +
           get_warp_idx_q<KTraits>(tid_y);
}
/*!
 * \brief Apply Llama-style rotary embedding to two 16x16 K fragments.
 * \tparam HALF_ELEMS_PER_THREAD Number of 16-bit elements each thread owns.
 * \tparam T The data type of the input fragments (must be 16-bit).
 * \param x_first_half First fragment x[offset:offset+16, j*16:(j+1)*16]
 * \param x_second_half Second fragment x[offset:offset+16,
 *        j*16+d/2:(j+1)*16+d/2]
 * \param rope_freq Rope frequency table for this thread's features.
 * \param kv_offset The offset of the first row in both fragments.
 * \note The sin/cos computation is slow; will optimize in the future.
 * NOTE(review): template parameter list reconstructed (stripped in source).
 */
template <uint32_t HALF_ELEMS_PER_THREAD, typename T>
__device__ __forceinline__ void
k_frag_apply_llama_rope(T *x_first_half,
                        T *x_second_half,
                        const float *rope_freq,
                        const uint32_t kv_offset)
{
    static_assert(sizeof(T) == 2);
#pragma unroll
    for (uint32_t reg_id = 0; reg_id < HALF_ELEMS_PER_THREAD; ++reg_id) {
        float cos, sin, tmp;
        // Register-to-fragment mapping:
        // 0 1 | 2 3
        // ---------
        // 4 5 | 6 7
        uint32_t i = reg_id / 2, j = reg_id % 2;
        __sincosf(float(kv_offset + 8 * i) * rope_freq[2 * j + reg_id % 2],
                  &sin, &cos);
        tmp = x_first_half[reg_id];
        x_first_half[reg_id] = (tmp * cos - (float)x_second_half[reg_id] * sin);
        x_second_half[reg_id] =
            ((float)x_second_half[reg_id] * cos + tmp * sin);
    }
}

/*!
 * \brief Apply Llama-style rotary embedding to two Q fragments whose
 *        position is derived from the packed query offset.
 * NOTE(review): template parameter list reconstructed (stripped in source).
 */
template <uint32_t HALF_ELEMS_PER_THREAD, typename T>
__device__ __forceinline__ void
q_frag_apply_llama_rope(T *x_first_half,
                        T *x_second_half,
                        const float *rope_freq,
                        const uint32_t qo_packed_offset,
                        const uint_fastdiv group_size)
{
#pragma unroll
    for (uint32_t reg_id = 0; reg_id < HALF_ELEMS_PER_THREAD; ++reg_id) {
        float cos, sin, tmp;
        // Register-to-fragment mapping:
        // 0 1 | 4 5
        // ---------
        // 2 3 | 6 7

        // Direct mapping to frequency array; same sequence position for all
        // four features owned by this thread.
        uint32_t freq_idx = reg_id;
        uint32_t position = qo_packed_offset;

        __sincosf(float(position / group_size) * rope_freq[freq_idx], &sin,
                  &cos);
        tmp = x_first_half[reg_id];
        x_first_half[reg_id] = (tmp * cos - (float)x_second_half[reg_id] * sin);
        x_second_half[reg_id] =
            ((float)x_second_half[reg_id] * cos + tmp * sin);
    }
}

/*!
 * \brief Apply Llama-style rotary embedding to two Q fragments using an
 *        explicit per-request position array (q_rope_offset).
 * NOTE(review): template parameter list reconstructed (stripped in source).
 */
template <uint32_t HALF_ELEMS_PER_THREAD, typename T, typename IdType>
__device__ __forceinline__ void
q_frag_apply_llama_rope_with_pos(T *x_first_half,
                                 T *x_second_half,
                                 const float *rope_freq,
                                 const uint32_t qo_packed_offset,
                                 const uint_fastdiv group_size,
                                 const IdType *q_rope_offset)
{
    // Two row groups (8 rows apart) handled by each thread.
    float pos[2] = {
        static_cast<float>(q_rope_offset[qo_packed_offset / group_size]),
        static_cast<float>(q_rope_offset[(qo_packed_offset + 8) / group_size])};
#pragma unroll
    for (uint32_t reg_id = 0; reg_id < HALF_ELEMS_PER_THREAD; ++reg_id) {
        float cos, sin, tmp;
        // Register-to-fragment mapping:
        // 0 1 | 4 5
        // ---------
        // 2 3 | 6 7

        const uint32_t i = reg_id / 2;
        const uint32_t j = reg_id % 2;

        __sincosf(pos[i] * rope_freq[2 * j + reg_id % 2], &sin, &cos);
        tmp = x_first_half[reg_id];
        x_first_half[reg_id] = (tmp * cos - (float)x_second_half[reg_id] * sin);
        x_second_half[reg_id] =
            ((float)x_second_half[reg_id] * cos + tmp * sin);
    }
}

/*!
 * \brief CUDA-layout (8-threads-per-row) K/V global->shared producer.
 * NOTE(review): template parameter list and smem_t/upcast_size template
 * arguments reconstructed (stripped in source).
 */
template <bool produce_v, SharedMemFillMode fill_mode, typename KTraits>
__device__ __forceinline__ void produce_kv_impl_cuda_(
    uint32_t warp_idx,
    uint32_t lane_idx,
    smem_t<KTraits::SWIZZLE_MODE_KV> smem,
    uint32_t *smem_offset,
    typename KTraits::DTypeKV **gptr,
    const uint32_t stride_n,
    const uint32_t kv_idx_base,
    const uint32_t kv_len)
{
    using DTypeKV = typename KTraits::DTypeKV;
    constexpr uint32_t NUM_WARPS = KTraits::NUM_WARPS;
    constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV;
    constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q;
    constexpr uint32_t NUM_MMA_D =
        produce_v ? KTraits::NUM_MMA_D_VO : KTraits::NUM_MMA_D_QK;
    constexpr uint32_t UPCAST_STRIDE =
        produce_v ? KTraits::UPCAST_STRIDE_V : KTraits::UPCAST_STRIDE_K;
    constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH;

    if constexpr (KTraits::SWIZZLE_MODE_KV == SwizzleMode::k128B) {
        // 8 threads per row: each warp covers 4 rows per iteration.
        uint32_t kv_idx = kv_idx_base + warp_idx * 4 + lane_idx / 8;
        // NOTE: NUM_MMA_KV * 4 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 4 /
        // num_warps
        static_assert(NUM_MMA_KV * 4 % NUM_WARPS_Q == 0);
#pragma unroll
        for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) {
#pragma unroll
            for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(DTypeKV)); ++j) {
                smem.template load_128b_async<fill_mode>(*smem_offset, *gptr,
                                                         kv_idx < kv_len);
                *smem_offset =
                    smem.template advance_offset_by_column<8>(*smem_offset, j);
                *gptr += 8 * upcast_size<DTypeKV>();
            }
            kv_idx += NUM_WARPS * 4;
            *smem_offset = smem.template advance_offset_by_row<
                               NUM_WARPS * 4, UPCAST_STRIDE>(*smem_offset) -
                           sizeof(DTypeKV) * NUM_MMA_D;
            *gptr += NUM_WARPS * 4 * stride_n -
                     sizeof(DTypeKV) * NUM_MMA_D * upcast_size<DTypeKV>();
        }
        *smem_offset -= KTraits::CTA_TILE_KV * UPCAST_STRIDE;
    }
    else {
        // 4 threads per row: each warp covers 8 rows per iteration.
        uint32_t kv_idx = kv_idx_base + warp_idx * 8 + lane_idx / 4;
        // NOTE: NUM_MMA_KV * 2 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 2 /
        // num_warps
        static_assert(NUM_MMA_KV * 2 % NUM_WARPS_Q == 0);
#pragma unroll
        for (uint32_t i = 0; i < NUM_MMA_KV * 2 / NUM_WARPS_Q; ++i) {
            smem.template load_128b_async<fill_mode>(*smem_offset, *gptr,
                                                     kv_idx < kv_len);
            *smem_offset = smem.template advance_offset_by_row<
                NUM_WARPS * 8, UPCAST_STRIDE>(*smem_offset);
            kv_idx += NUM_WARPS * 8;
            *gptr += NUM_WARPS * 8 * stride_n;
        }
        *smem_offset -= KTraits::CTA_TILE_KV * UPCAST_STRIDE;
    }
}
/*!
 * \brief CDNA3-layout (16x4 thread grid) K/V global->shared producer.
 *
 * Each i iteration stages one complete 16 x HEAD_DIM MMA tile; within the
 * tile, four groups of 4 sequences are loaded, each group covering the head
 * dimension in HALF_ELEMS_PER_THREAD-wide vector chunks.
 * NOTE(review): template parameter list and smem_t/advance template
 * arguments reconstructed (stripped in source).
 */
template <bool produce_v, SharedMemFillMode fill_mode, typename KTraits>
__device__ __forceinline__ void produce_kv_impl_cdna3_(
    uint32_t warp_idx,
    uint32_t lane_idx,
    smem_t<KTraits::SWIZZLE_MODE_KV> smem,
    uint32_t *smem_offset,
    typename KTraits::DTypeKV **gptr,
    const uint32_t stride_n,
    const uint32_t kv_idx_base,
    const uint32_t kv_len)
{
    static_assert(KTraits::SWIZZLE_MODE_KV == SwizzleMode::kLinear);
    using DTypeKV = typename KTraits::DTypeKV;
    constexpr uint32_t KV_THR_LAYOUT_COL = KTraits::KV_THR_LAYOUT_COL; // 16
    constexpr uint32_t KV_THR_LAYOUT_ROW = KTraits::KV_THR_LAYOUT_ROW; // 4
    constexpr uint32_t NUM_WARPS = KTraits::NUM_WARPS;
    constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV;
    constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q;
    constexpr uint32_t NUM_MMA_D =
        produce_v ? KTraits::NUM_MMA_D_VO : KTraits::NUM_MMA_D_QK;
    constexpr uint32_t UPCAST_STRIDE =
        produce_v ? KTraits::UPCAST_STRIDE_V : KTraits::UPCAST_STRIDE_K;
    constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH;
    constexpr uint32_t HALF_ELEMS_PER_THREAD =
        KTraits::HALF_ELEMS_PER_THREAD; // 4

    // CDNA3-specific constants
    constexpr uint32_t SEQUENCES_PER_MMA_TILE = 16;
    constexpr uint32_t SEQUENCES_PER_THREAD_GROUP = KV_THR_LAYOUT_ROW; // 4
    constexpr uint32_t THREAD_GROUPS_PER_MMA_TILE =
        SEQUENCES_PER_MMA_TILE / SEQUENCES_PER_THREAD_GROUP; // 4
    constexpr uint32_t FEATURE_CHUNKS_PER_THREAD_GROUP =
        NUM_MMA_D / HALF_ELEMS_PER_THREAD; // NUM_MMA_D/4
    constexpr uint32_t COLUMN_RESET_OFFSET =
        FEATURE_CHUNKS_PER_THREAD_GROUP * KV_THR_LAYOUT_COL;

    uint32_t row = lane_idx / KV_THR_LAYOUT_COL;
    uint32_t kv_idx = kv_idx_base + warp_idx * KV_THR_LAYOUT_ROW + row;

    // NOTE: NUM_MMA_KV*4/NUM_WARPS_Q = NUM_WARPS_KV*NUM_MMA_KV*4/num_warps
    static_assert(NUM_MMA_KV * 4 % NUM_WARPS_Q == 0);

#pragma unroll
    for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i)
    { // MMA tile iterations

        // CDNA3: Load complete 16 x HEAD_DIM tile per i iteration
#pragma unroll
        for (uint32_t k = 0; k < THREAD_GROUPS_PER_MMA_TILE; ++k)
        { // 4 sequence groups
#pragma unroll
            for (uint32_t j = 0; j < FEATURE_CHUNKS_PER_THREAD_GROUP; ++j)
            { // Feature chunks
                smem.template load_vector_async<fill_mode>(
                    *smem_offset, *gptr, kv_idx < kv_len);

                // Advance to next feature chunk (same sequence group)
                *smem_offset =
                    smem.template advance_offset_by_column<KV_THR_LAYOUT_COL>(
                        *smem_offset, j);
                *gptr += KV_THR_LAYOUT_COL * upcast_size<DTypeKV>();
            }

            // Advance to next sequence group within same MMA tile
            if (k < THREAD_GROUPS_PER_MMA_TILE - 1)
            { // Don't advance after last group
                kv_idx += NUM_WARPS * KV_THR_LAYOUT_ROW;
                *smem_offset =
                    smem.template advance_offset_by_row<
                        NUM_WARPS * KV_THR_LAYOUT_ROW, UPCAST_STRIDE>(
                        *smem_offset) -
                    COLUMN_RESET_OFFSET;
                *gptr += NUM_WARPS * KV_THR_LAYOUT_ROW * stride_n -
                         FEATURE_CHUNKS_PER_THREAD_GROUP * KV_THR_LAYOUT_COL *
                             upcast_size<DTypeKV>();
            }
        }

        // Final advance to next MMA tile
        kv_idx += NUM_WARPS * KV_THR_LAYOUT_ROW;
        *smem_offset =
            smem.template advance_offset_by_row<
                NUM_WARPS * KV_THR_LAYOUT_ROW, UPCAST_STRIDE>(*smem_offset) -
            COLUMN_RESET_OFFSET;
        *gptr += NUM_WARPS * KV_THR_LAYOUT_ROW * stride_n -
                 FEATURE_CHUNKS_PER_THREAD_GROUP * KV_THR_LAYOUT_COL *
                     upcast_size<DTypeKV>();
    }
    // Rewind to the tile origin for the next producer call.
    *smem_offset -= KTraits::CTA_TILE_KV * UPCAST_STRIDE;
}

/*!
 * \brief Produce k/v fragments from global memory to shared memory.
 * \param smem The shared memory to store kv fragments.
 * \param smem_offset In/out shared-memory write cursor.
 * \param gptr In/out global memory pointer.
 * \param stride_n Row stride (in elements) of the kv tensor.
 * \param kv_idx_base The base kv index.
 * \param kv_len The length of kv tensor.
 * NOTE(review): dispatches to the CDNA3 implementation only; template
 * parameter list reconstructed (stripped in source).
 */
template <bool produce_v, SharedMemFillMode fill_mode, typename KTraits>
__device__ __forceinline__ void produce_kv(
    smem_t<KTraits::SWIZZLE_MODE_KV> smem,
    uint32_t *smem_offset,
    typename KTraits::DTypeKV **gptr,
    const uint32_t stride_n,
    const uint32_t kv_idx_base,
    const uint32_t kv_len,
    const dim3 tid = threadIdx)
{
    // NOTE: for fp8, this function doesn't work for head_dim = 64 at the
    // moment
    const uint32_t warp_idx = get_warp_idx<KTraits>(tid.y, tid.z),
                   lane_idx = tid.x;

    produce_kv_impl_cdna3_<produce_v, fill_mode, KTraits>(
        warp_idx, lane_idx, smem, smem_offset, gptr, stride_n, kv_idx_base,
        kv_len);
}

/*!
 * \brief Produce k/v fragments from a paged KV cache into shared memory,
 *        using per-thread precomputed page offsets.
 * NOTE(review): template parameter list reconstructed (stripped in source).
 */
template <bool produce_v, typename KTraits>
__device__ __forceinline__ void page_produce_kv(
    smem_t<KTraits::SWIZZLE_MODE_KV> smem,
    uint32_t *smem_offset,
    const paged_kv_t<typename KTraits::DTypeKV, typename KTraits::IdType>
        &paged_kv,
    const uint32_t kv_idx_base,
    const size_t *thr_local_kv_offset,
    const uint32_t kv_len,
    const dim3 tid = threadIdx)
{
    // NOTE: for fp8, this function doesn't work for head_dim = 64 at the
    // moment
    using DType = typename KTraits::DTypeKV;
    // V tiles are zero-filled past kv_len so the row-sum stays correct.
    constexpr SharedMemFillMode fill_mode =
        produce_v ? SharedMemFillMode::kFillZero : SharedMemFillMode::kNoFill;
    constexpr uint32_t NUM_WARPS = KTraits::NUM_WARPS;
    constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q;
    constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV;
    constexpr uint32_t NUM_MMA_D =
        produce_v ? KTraits::NUM_MMA_D_VO : KTraits::NUM_MMA_D_QK;
    constexpr uint32_t UPCAST_STRIDE =
        produce_v ? KTraits::UPCAST_STRIDE_V : KTraits::UPCAST_STRIDE_K;
    constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH;

    const uint32_t warp_idx = get_warp_idx<KTraits>(tid.y, tid.z),
                   lane_idx = tid.x;
    if constexpr (KTraits::SWIZZLE_MODE_KV == SwizzleMode::k128B) {
        uint32_t kv_idx = kv_idx_base + warp_idx * 4 + lane_idx / 8;
        // NOTE: NUM_MMA_KV * 4 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 4 /
        // num_warps
        static_assert(NUM_MMA_KV * 4 % NUM_WARPS_Q == 0);
#pragma unroll
        for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) {
            DType *gptr = produce_v
                              ? paged_kv.v_data + thr_local_kv_offset[i]
                              : paged_kv.k_data + thr_local_kv_offset[i];
#pragma unroll
            for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(DType)); ++j) {
                smem.template load_vector_async<fill_mode>(*smem_offset, gptr,
                                                           kv_idx < kv_len);
                *smem_offset =
                    smem.template advance_offset_by_column<8>(*smem_offset, j);
                gptr += 8 * upcast_size<DType>();
            }
            kv_idx += NUM_WARPS * 4;
            *smem_offset = smem.template advance_offset_by_row<
                               NUM_WARPS * 4, UPCAST_STRIDE>(*smem_offset) -
                           sizeof(DType) * NUM_MMA_D;
        }
        *smem_offset -= KTraits::CTA_TILE_KV * UPCAST_STRIDE;
    }
    else {
        uint32_t kv_idx = kv_idx_base + warp_idx * 8 + lane_idx / 4;
        // NOTE: NUM_MMA_KV * 2 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 2 /
        // num_warps
        static_assert(NUM_MMA_KV * 2 % NUM_WARPS_Q == 0);
#pragma unroll
        for (uint32_t i = 0; i < NUM_MMA_KV * 2 / NUM_WARPS_Q; ++i) {
            DType *gptr = produce_v
                              ? paged_kv.v_data + thr_local_kv_offset[i]
                              : paged_kv.k_data + thr_local_kv_offset[i];
            smem.template load_vector_async<fill_mode>(*smem_offset, gptr,
                                                       kv_idx < kv_len);
            kv_idx += NUM_WARPS * 8;
            *smem_offset = smem.template advance_offset_by_row<
                NUM_WARPS * 8, UPCAST_STRIDE>(*smem_offset);
        }
        *smem_offset -= KTraits::CTA_TILE_KV * UPCAST_STRIDE;
    }
}

/*!
 * \brief Map an MMA tile index / lane / register to the RoPE feature index
 *        owned by this thread.
 *
 * CDNA3 A-matrix MMA tile to thread mapping for a 64-thread wavefront:
 * each group of 16 threads handles the same four consecutive features for
 * different sequences:
 *   T0-T15:  Features [0,1,2,3]   for sequences 0-15 respectively
 *   T16-T31: Features [4,5,6,7]   for sequences 0-15 respectively
 *   T32-T47: Features [8,9,10,11] for sequences 0-15 respectively
 *   T48-T63: Features [12,13,14,15] for sequences 0-15 respectively
 *
 * NOTE(review): the source version referenced mma_d/lane_idx/HEAD_DIM without
 * declaring them (garbled); they are reintroduced here as parameters and a
 * template argument — verify the call sites against upstream.
 */
template <uint32_t HEAD_DIM>
__device__ __forceinline__ uint32_t
get_feature_index(uint32_t mma_d, uint32_t lane_idx, uint32_t j)
{
    uint32_t feature_index =
        (mma_d * 16 + (lane_idx / 4) + j) % (HEAD_DIM / 2);
    return feature_index;
}
/*!
 * \brief Precompute per-thread RoPE frequencies.
 * \param rope_freq Output table, one row of 4 frequencies per mma_d pair.
 * \param rope_rcp_scale Reciprocal of the RoPE scaling factor.
 * \param rope_rcp_theta Reciprocal of the RoPE theta base.
 * NOTE(review): get_feature_index parameterization reconstructed — verify
 * against upstream.
 */
template <typename KTraits>
__device__ __forceinline__ void
init_rope_freq(float (*rope_freq)[4],
               const float rope_rcp_scale,
               const float rope_rcp_theta,
               const uint32_t tid_x = threadIdx.x)
{
    constexpr uint32_t HEAD_DIM = KTraits::NUM_MMA_D_QK * 16;
    const uint32_t lane_idx = tid_x;

#pragma unroll
    for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO / 2; ++mma_d) {
#pragma unroll
        for (uint32_t j = 0; j < 4; ++j) {
            rope_freq[mma_d][j] =
                rope_rcp_scale *
                __powf(rope_rcp_theta,
                       float(2 * get_feature_index<HEAD_DIM>(mma_d, lane_idx,
                                                             j)) /
                           float(HEAD_DIM));
        }
    }
}

/*!
 * \brief Initialize the output accumulator to zero and, when softmax is in
 *        use, the running max m to -inf and the running denominator d to 1.
 */
template <typename KTraits>
__device__ __forceinline__ void init_states(
    typename KTraits::AttentionVariant variant,
    float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD],
    typename KTraits::DTypeQKAccum (*m)[KTraits::NUM_ACCUM_ROWS_PER_THREAD],
    float (*d)[KTraits::NUM_ACCUM_ROWS_PER_THREAD])
{
    constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD =
        KTraits::NUM_ACCUM_ROWS_PER_THREAD;
#pragma unroll
    for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
#pragma unroll
        for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) {
#pragma unroll
            for (uint32_t reg_id = 0; reg_id < KTraits::HALF_ELEMS_PER_THREAD;
                 ++reg_id)
            {
                o_frag[mma_q][mma_d][reg_id] = 0.f;
            }
        }
    }

    if constexpr (variant.use_softmax) {
#pragma unroll
        for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
#pragma unroll
            for (uint32_t j = 0; j < NUM_ACCUM_ROWS_PER_THREAD; ++j) {
                m[mma_q][j] =
                    typename KTraits::DTypeQKAccum(-gpu_iface::math::inf);
                d[mma_q][j] = 1.f;
            }
        }
    }
}

/*!
 * \brief Load the Q tile from global memory into shared memory.
 *
 * Only the warps with kv-warp-index 0 participate. Each lane covers one
 * (row, column) position of a 4x16 thread layout; rows beyond qo_upper_bound
 * are not filled.
 * NOTE(review): smem_t/advance template arguments reconstructed (stripped in
 * source) — verify row/column step constants against upstream.
 */
template <typename KTraits>
__device__ __forceinline__ void load_q_global_smem(
    uint32_t packed_offset,
    const uint32_t qo_upper_bound,
    typename KTraits::DTypeQ *q_ptr_base,
    const uint32_t q_stride_n,
    const uint32_t q_stride_h,
    const uint_fastdiv group_size,
    smem_t<KTraits::SWIZZLE_MODE_Q> *q_smem,
    const dim3 tid = threadIdx)
{
    using DTypeQ = typename KTraits::DTypeQ;
    constexpr uint32_t WARP_THREAD_COLS = KTraits::WARP_THREAD_COLS;
    constexpr uint32_t WARP_THREAD_ROWS = KTraits::WARP_THREAD_ROWS;
    constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD;
    constexpr uint32_t NUM_MMA_D_QK = KTraits::NUM_MMA_D_QK;
    constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q;
    constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH;

    constexpr uint32_t COLUMN_RESET_OFFSET =
        (NUM_MMA_D_QK / 4) * WARP_THREAD_COLS;

    const uint32_t lane_idx = tid.x,
                   warp_idx_x = get_warp_idx_q<KTraits>(tid.y);
    uint32_t row = lane_idx / WARP_THREAD_COLS;
    uint32_t col = lane_idx % WARP_THREAD_COLS;

    if (get_warp_idx_kv<KTraits>(tid.z) == 0) {
        uint32_t q_smem_offset_w =
            q_smem->template get_permuted_offset<UPCAST_STRIDE_Q>(
                warp_idx_x * KTraits::NUM_MMA_Q * 16 + row, col);

#pragma unroll
        for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
#pragma unroll
            for (uint32_t j = 0; j < 2 * 2; ++j) {
                uint32_t q, r;
                // Split the packed index into (request row, head-in-group).
                group_size.divmod(packed_offset + row + mma_q * 16 + j * 4, q,
                                  r);
                const uint32_t q_idx = q;
                DTypeQ *q_ptr = q_ptr_base + q * q_stride_n + r * q_stride_h +
                                col * upcast_size<DTypeQ>();
#pragma unroll
                for (uint32_t mma_do = 0; mma_do < KTraits::NUM_MMA_D_QK / 4;
                     ++mma_do)
                {
                    // load q fragment from gmem to smem
                    q_smem->template load_vector_async<
                        SharedMemFillMode::kNoFill>(q_smem_offset_w, q_ptr,
                                                    q_idx < qo_upper_bound);
                    q_smem_offset_w = q_smem->template advance_offset_by_column<
                        WARP_THREAD_COLS>(q_smem_offset_w, mma_do);
                    q_ptr += HALF_ELEMS_PER_THREAD * upcast_size<DTypeQ>();
                }
                q_smem_offset_w =
                    q_smem->template advance_offset_by_row<
                        WARP_THREAD_ROWS, UPCAST_STRIDE_Q>(q_smem_offset_w) -
                    COLUMN_RESET_OFFSET;
            }
        }
    }
}

/*!
 * \brief Apply Llama RoPE to the Q tile in place in shared memory.
 *
 * Rotates each first-half/second-half column pair of the head dimension for
 * every Q fragment; positions are offset by (kv_len - qo_len) * group_size so
 * the last query aligns with the last key.
 * NOTE(review): column-advance template arguments reconstructed (stripped in
 * source) — verify offsets against upstream.
 */
template <typename KTraits>
__device__ __forceinline__ void q_smem_inplace_apply_rotary(
    const uint32_t q_packed_idx,
    const uint32_t qo_len,
    const uint32_t kv_len,
    const uint_fastdiv group_size,
    smem_t<KTraits::SWIZZLE_MODE_Q> *q_smem,
    uint32_t *q_smem_offset_r,
    float (*rope_freq)[4],
    const dim3 tid = threadIdx)
{
    if (get_warp_idx_kv<KTraits>(tid.z) == 0) {
        constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q;
        const uint32_t lane_idx = tid.x;
        uint32_t q_frag_local[2][KTraits::INT32_ELEMS_PER_THREAD];
        static_assert(KTraits::NUM_MMA_D_QK % 4 == 0,
                      "NUM_MMA_D_QK must be a multiple of 4");
        constexpr uint32_t LAST_HALF_OFFSET = KTraits::NUM_MMA_D_QK * 2;
        constexpr uint32_t FIRST_HALF_OFFSET = KTraits::NUM_MMA_D_QK;
        const uint32_t SEQ_ID = lane_idx % 16;

#pragma unroll
        for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
            uint32_t q_smem_offset_r_first_half = *q_smem_offset_r;
#pragma unroll
            for (uint32_t mma_di = 0; mma_di < FIRST_HALF_OFFSET; ++mma_di) {
                q_smem->load_fragment(q_smem_offset_r_first_half,
                                      q_frag_local[0]);
                // Companion column half of the head dimension.
                uint32_t q_smem_offset_r_last_half =
                    q_smem->template advance_offset_by_column<
                        FIRST_HALF_OFFSET>(q_smem_offset_r_first_half, 0);
                q_smem->load_fragment(q_smem_offset_r_last_half,
                                      q_frag_local[1]);
                q_frag_apply_llama_rope<KTraits::HALF_ELEMS_PER_THREAD>(
                    (typename KTraits::DTypeQ *)q_frag_local[0],
                    (typename KTraits::DTypeQ *)q_frag_local[1],
                    rope_freq[mma_di],
                    q_packed_idx + kv_len * group_size - qo_len * group_size +
                        mma_q * 16 + SEQ_ID,
                    group_size);
                q_smem->store_fragment(q_smem_offset_r_last_half,
                                       q_frag_local[1]);
                q_smem->store_fragment(q_smem_offset_r_first_half,
                                       q_frag_local[0]);
                q_smem_offset_r_first_half =
                    q_smem->template advance_offset_by_column<1>(
                        q_smem_offset_r_first_half, mma_di);
            }
            *q_smem_offset_r += 16 * UPCAST_STRIDE_Q;
        }
        *q_smem_offset_r -= KTraits::NUM_MMA_Q * 16 * UPCAST_STRIDE_Q;
    }
}
/*!
 * \brief Compute S = Q * K^T for the staged tiles using 16x16x16 MFMA ops.
 *
 * Iterates the contraction over NUM_MMA_D_QK head-dimension tiles; the first
 * tile initializes the accumulator (MMAMode::kInit), subsequent tiles
 * accumulate. Cursors are rewound to their entry values on exit.
 * NOTE(review): column-advance template arguments reconstructed (stripped in
 * source).
 */
template <typename KTraits>
__device__ __forceinline__ void compute_qk(
    smem_t<KTraits::SWIZZLE_MODE_Q> *q_smem,
    uint32_t *q_smem_offset_r,
    smem_t<KTraits::SWIZZLE_MODE_KV> *k_smem,
    uint32_t *k_smem_offset_r,
    typename KTraits::DTypeQKAccum (
        *s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD])
{
    constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q;
    constexpr uint32_t UPCAST_STRIDE_K = KTraits::UPCAST_STRIDE_K;
    constexpr uint32_t QK_SMEM_COLUMN_ADVANCE =
        16 / KTraits::HALF_ELEMS_PER_THREAD;

    uint32_t a_frag[KTraits::NUM_MMA_Q][KTraits::INT32_ELEMS_PER_THREAD],
        b_frag[KTraits::INT32_ELEMS_PER_THREAD];
    // compute q*k^T
#pragma unroll
    for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_QK; ++mma_d) {
#pragma unroll
        for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
            q_smem->load_fragment(*q_smem_offset_r, a_frag[mma_q]);
            *q_smem_offset_r =
                q_smem->template advance_offset_by_row<16, UPCAST_STRIDE_Q>(
                    *q_smem_offset_r);
        }

        // Next head-dim tile; rewind the row cursor.
        *q_smem_offset_r =
            q_smem->template advance_offset_by_column<QK_SMEM_COLUMN_ADVANCE>(
                *q_smem_offset_r, mma_d) -
            KTraits::NUM_MMA_Q * 16 * UPCAST_STRIDE_Q;

#pragma unroll
        for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) {

            k_smem->load_fragment(*k_smem_offset_r, b_frag);
            *k_smem_offset_r =
                k_smem->template advance_offset_by_row<16, UPCAST_STRIDE_K>(
                    *k_smem_offset_r);

#pragma unroll
            for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
                if constexpr (std::is_same_v<typename KTraits::DTypeQKAccum,
                                             float>)
                {
                    if (mma_d == 0) {
                        mma::mma_sync_m16n16k16_row_col_f16f16f32<
                            typename KTraits::DTypeQ, MMAMode::kInit>(
                            s_frag[mma_q][mma_kv], a_frag[mma_q], b_frag);
                    }
                    else {
                        mma::mma_sync_m16n16k16_row_col_f16f16f32<
                            typename KTraits::DTypeQ>(s_frag[mma_q][mma_kv],
                                                      a_frag[mma_q], b_frag);
                    }
                }
                else if constexpr (std::is_same_v<
                                       typename KTraits::DTypeQKAccum, half>)
                {
                    // Dependent-false so the assert only fires when this
                    // branch is actually instantiated (a bare
                    // static_assert(false, ...) would fail unconditionally).
                    static_assert(
                        sizeof(typename KTraits::DTypeQKAccum) == 0,
                        "FP16 DTypeQKAccum not yet implemented for CDNA3");
                }
            }
        }
        if constexpr (sizeof(typename KTraits::DTypeKV) == 1) {
            if (mma_d % 2 == 1) {
                *k_smem_offset_r = k_smem->template advance_offset_by_column<
                    QK_SMEM_COLUMN_ADVANCE>(*k_smem_offset_r, mma_d / 2);
            }
            *k_smem_offset_r -= KTraits::NUM_MMA_KV * 16 * UPCAST_STRIDE_K;
        }
        else {
            *k_smem_offset_r =
                k_smem->template advance_offset_by_column<
                    QK_SMEM_COLUMN_ADVANCE>(*k_smem_offset_r, mma_d) -
                KTraits::NUM_MMA_KV * 16 * UPCAST_STRIDE_K;
        }
    }
    *q_smem_offset_r -= KTraits::NUM_MMA_D_QK * QK_SMEM_COLUMN_ADVANCE;
    *k_smem_offset_r -=
        KTraits::NUM_MMA_D_QK * sizeof(typename KTraits::DTypeKV);
}

/*!
 * \brief Apply the variant's LogitsTransform to every attention score held by
 *        this thread, mapping registers back to (q_idx, kv_idx, head) first.
 */
template <typename KTraits, typename Params>
__device__ __forceinline__ void logits_transform(
    const Params &params,
    typename KTraits::AttentionVariant variant,
    const uint32_t batch_idx,
    const uint32_t qo_packed_idx_base,
    const uint32_t kv_idx_base,
    const uint32_t qo_len,
    const uint32_t kv_len,
    const uint_fastdiv group_size,
    typename KTraits::DTypeQKAccum (
        *s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD],
    const dim3 tid = threadIdx,
    const uint32_t kv_head_idx = blockIdx.z)
{
    using DTypeQKAccum = typename KTraits::DTypeQKAccum;
    constexpr uint32_t TPR = KTraits::THREADS_PER_MATRIX_ROW_SET;
    constexpr uint32_t NAPTR = KTraits::NUM_ACCUM_ROWS_PER_THREAD;
    constexpr uint32_t LIS = KTraits::LOGITS_INDEX_STRIDE;

    const uint32_t lane_idx = tid.x;
    uint32_t q[KTraits::NUM_MMA_Q][NAPTR], r[KTraits::NUM_MMA_Q][NAPTR];
    float logits = 0., logitsTransformed = 0.;

    // Precompute (request row, head-in-group) for each accumulator row.
#pragma unroll
    for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
#pragma unroll
        for (uint32_t j = 0; j < NAPTR; ++j) {
            group_size.divmod(qo_packed_idx_base + mma_q * 16 +
                                  lane_idx / TPR + LIS * j,
                              q[mma_q][j], r[mma_q][j]);
        }
    }

#pragma unroll
    for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
#pragma unroll
        for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) {
#pragma unroll
            for (uint32_t reg_id = 0; reg_id < KTraits::HALF_ELEMS_PER_THREAD;
                 ++reg_id)
            {
                const uint32_t q_idx = q[mma_q][reg_id % NAPTR];
                const uint32_t qo_head_idx =
                    kv_head_idx * group_size + r[mma_q][reg_id % NAPTR];
                const uint32_t kv_idx = kv_idx_base + mma_kv * 16 +
                                        2 * (lane_idx % TPR) +
                                        8 * (reg_id / 2) + reg_id % 2;

#ifdef FP16_QK_REDUCTION_SUPPORTED
                if constexpr (std::is_same<DTypeQKAccum, half>::value) {
                    // NOTE(review): fp16 <-> fp32 conversion reconstructed —
                    // verify the bit_cast direction against upstream.
                    logits = fp16_ieee_to_fp32_value(
                        std::bit_cast<uint16_t>(s_frag[mma_q][mma_kv][reg_id]));
                }
                else if constexpr (!std::is_same<DTypeQKAccum, half>::value) {
                    logits = s_frag[mma_q][mma_kv][reg_id];
                }
#else
                static_assert(
                    !std::is_same<DTypeQKAccum, half>::value,
                    "Set -DFP16_QK_REDUCTION_SUPPORTED and install boost_math "
                    "then recompile to support fp16 reduction");
                logits = s_frag[mma_q][mma_kv][reg_id];
#endif
                logitsTransformed =
                    variant.LogitsTransform(params, logits, batch_idx, q_idx,
                                            kv_idx, qo_head_idx, kv_head_idx);
#ifdef FP16_QK_REDUCTION_SUPPORTED
                if constexpr (std::is_same<DTypeQKAccum, half>::value) {
                    s_frag[mma_q][mma_kv][reg_id] = std::bit_cast<DTypeQKAccum>(
                        fp16_ieee_from_fp32_value(logitsTransformed));
                }
                else if constexpr (!std::is_same<DTypeQKAccum, half>::value) {
                    s_frag[mma_q][mma_kv][reg_id] = logitsTransformed;
                }
#else
                s_frag[mma_q][mma_kv][reg_id] = logitsTransformed;
#endif
            }
        }
    }
}

/*!
 * \brief Mask out-of-range / causally-disallowed / variant-masked scores,
 *        replacing them with KTraits::MaskFillValue.
 */
template <typename KTraits, typename Params>
__device__ __forceinline__ void
logits_mask(const Params &params,
            typename KTraits::AttentionVariant variant,
            const uint32_t batch_idx,
            const uint32_t qo_packed_idx_base,
            const uint32_t kv_idx_base,
            const uint32_t qo_len,
            const uint32_t kv_len,
            const uint32_t chunk_end,
            const uint_fastdiv group_size,
            typename KTraits::DTypeQKAccum (
                *s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD],
            const dim3 tid = threadIdx,
            const uint32_t kv_head_idx = blockIdx.z)
{
    const uint32_t lane_idx = tid.x;
    constexpr uint32_t NUM_MMA_Q = KTraits::NUM_MMA_Q;
    constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV;
    constexpr MaskMode MASK_MODE = KTraits::MASK_MODE;
    constexpr uint32_t TPR = KTraits::THREADS_PER_MATRIX_ROW_SET;
    constexpr uint32_t NAPTR = KTraits::NUM_ACCUM_ROWS_PER_THREAD;
    constexpr uint32_t LIS = KTraits::LOGITS_INDEX_STRIDE;

    uint32_t q[NUM_MMA_Q][NAPTR], r[NUM_MMA_Q][NAPTR];
#pragma unroll
    for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) {
#pragma unroll
        for (uint32_t j = 0; j < NAPTR; ++j) {
            group_size.divmod(qo_packed_idx_base + mma_q * 16 +
                                  lane_idx / TPR + LIS * j,
                              q[mma_q][j], r[mma_q][j]);
        }
    }

#pragma unroll
    for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) {
#pragma unroll
        for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) {
#pragma unroll
            for (uint32_t reg_id = 0; reg_id < KTraits::HALF_ELEMS_PER_THREAD;
                 ++reg_id)
            {
                const uint32_t q_idx = q[mma_q][(reg_id % NAPTR)],
                               kv_idx = kv_idx_base + mma_kv * 16 +
                                        2 * (lane_idx % TPR) +
                                        8 * (reg_id / 2) + reg_id % 2;
                const uint32_t qo_head_idx =
                    kv_head_idx * group_size + r[mma_q][(reg_id % NAPTR)];
                // Keep the score only if it is inside the chunk, satisfies
                // causality (when enabled) and the variant does not mask it.
                const bool mask =
                    (!(MASK_MODE == MaskMode::kCausal
                           ? (kv_idx + qo_len > kv_len + q_idx ||
                              (kv_idx >= chunk_end))
                           : kv_idx >= chunk_end)) &&
                    variant.LogitsMask(params, batch_idx, q_idx, kv_idx,
                                       qo_head_idx, kv_head_idx);
                s_frag[mma_q][mma_kv][reg_id] =
                    (mask) ? s_frag[mma_q][mma_kv][reg_id]
                           : (KTraits::MaskFillValue);
            }
        }
    }
}
/*!
 * \brief Online-softmax state update: refresh the running row max m, rescale
 *        the denominator d and output accumulator, then exponentiate the
 *        scores in place (base-2, pre-scaled by sm_scale_log2).
 *
 * The row max is reduced across the 16 threads sharing each accumulator row
 * via a butterfly of shfl_xor operations.
 */
template <typename KTraits>
__device__ __forceinline__ void update_mdo_states(
    typename KTraits::AttentionVariant variant,
    typename KTraits::DTypeQKAccum (
        *s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD],
    float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD],
    typename KTraits::DTypeQKAccum (*m)[KTraits::NUM_ACCUM_ROWS_PER_THREAD],
    float (*d)[KTraits::NUM_ACCUM_ROWS_PER_THREAD])
{
    using DTypeQKAccum = typename KTraits::DTypeQKAccum;
    using AttentionVariant = typename KTraits::AttentionVariant;
    constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD =
        KTraits::NUM_ACCUM_ROWS_PER_THREAD;
    constexpr bool use_softmax = AttentionVariant::use_softmax;

    if constexpr (use_softmax) {
        const float sm_scale = variant.sm_scale_log2;
        if constexpr (std::is_same_v<DTypeQKAccum, float>) {
#pragma unroll
            for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
#pragma unroll
                for (uint32_t j = 0; j < NUM_ACCUM_ROWS_PER_THREAD; ++j) {
                    float m_prev = m[mma_q][j];
#pragma unroll
                    for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV;
                         ++mma_kv)
                    {
                        m[mma_q][j] =
                            max(m[mma_q][j], s_frag[mma_q][mma_kv][j]);
                    }
                    // Butterfly reduction across all threads in the band (16
                    // threads) for CDNA3's 64-thread wavefront
                    m[mma_q][j] =
                        max(m[mma_q][j], gpu_iface::math::shfl_xor_sync(
                                             m[mma_q][j], 0x8)); // 16 apart
                    m[mma_q][j] =
                        max(m[mma_q][j], gpu_iface::math::shfl_xor_sync(
                                             m[mma_q][j], 0x4)); // 8 apart
                    m[mma_q][j] =
                        max(m[mma_q][j], gpu_iface::math::shfl_xor_sync(
                                             m[mma_q][j], 0x2)); // 4 apart
                    m[mma_q][j] =
                        max(m[mma_q][j], gpu_iface::math::shfl_xor_sync(
                                             m[mma_q][j], 0x1)); // 2 apart

                    float o_scale = gpu_iface::math::ptx_exp2(
                        m_prev * sm_scale - m[mma_q][j] * sm_scale);
                    d[mma_q][j] *= o_scale;

                    // Scale output fragments for this specific row
#pragma unroll
                    for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO;
                         ++mma_d)
                    {
                        o_frag[mma_q][mma_d][j] *= o_scale; // Direct indexing
                    }

                    // Convert logits to probabilities for this row
#pragma unroll
                    for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV;
                         ++mma_kv)
                    {
                        s_frag[mma_q][mma_kv][j] = gpu_iface::math::ptx_exp2(
                            s_frag[mma_q][mma_kv][j] * sm_scale -
                            m[mma_q][j] * sm_scale);
                    }
                }
            }
        }
    }
}

/*!
 * \brief Accumulate O += softmax(S) * V and fold the probability row-sums
 *        into the denominator d.
 * NOTE(review): vec_cast template arguments reconstructed (stripped in
 * source).
 */
template <typename KTraits>
__device__ __forceinline__ void compute_sfm_v(
    smem_t<KTraits::SWIZZLE_MODE_KV> *v_smem,
    uint32_t *v_smem_offset_r,
    typename KTraits::DTypeQKAccum (
        *s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD],
    float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD],
    float (*d)[KTraits::NUM_ACCUM_ROWS_PER_THREAD])
{
    constexpr uint32_t UPCAST_STRIDE_V = KTraits::UPCAST_STRIDE_V;
    constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD;
    constexpr uint32_t INT32_ELEMS_PER_THREAD =
        KTraits::INT32_ELEMS_PER_THREAD;

    constexpr uint32_t V_SMEM_COLUMN_ADVANCE =
        16 / KTraits::HALF_ELEMS_PER_THREAD;

    // float accumulators must be downcast to the MMA input dtype first.
    typename KTraits::DTypeQ s_frag_f16[KTraits::NUM_MMA_Q]
                                       [KTraits::NUM_MMA_KV]
                                       [HALF_ELEMS_PER_THREAD];
    if constexpr (std::is_same_v<typename KTraits::DTypeQKAccum, float>) {
#pragma unroll
        for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
#pragma unroll
            for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) {
                vec_cast<typename KTraits::DTypeQ, float>::template cast<
                    HALF_ELEMS_PER_THREAD>(s_frag_f16[mma_q][mma_kv],
                                           s_frag[mma_q][mma_kv]);
            }
        }
    }

    if constexpr (KTraits::AttentionVariant::use_softmax) {
#pragma unroll
        for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
#pragma unroll
            for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) {
                if constexpr (std::is_same_v<typename KTraits::DTypeQKAccum,
                                             float>)
                {
                    mma::m16k16_rowsum_f16f16f32(d[mma_q],
                                                 s_frag_f16[mma_q][mma_kv]);
                }
                else {
                    // Dependent-false: only fails when this branch is
                    // instantiated.
                    static_assert(
                        sizeof(typename KTraits::DTypeQKAccum) == 0,
                        "FP16 reduction path not implemented for CDNA3");
                }
            }
        }
    }

#pragma unroll
    for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) {
#pragma unroll
        for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) {
            uint32_t b_frag[INT32_ELEMS_PER_THREAD];

            v_smem->load_fragment_4x4_transposed(*v_smem_offset_r, b_frag);

#pragma unroll
            for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
                if constexpr (std::is_same_v<typename KTraits::DTypeQKAccum,
                                             float>)
                {
                    mma::mma_sync_m16n16k16_row_col_f16f16f32<
                        typename KTraits::DTypeQ>(
                        o_frag[mma_q][mma_d],
                        (uint32_t *)s_frag_f16[mma_q][mma_kv], b_frag);
                }
                else {
                    mma::mma_sync_m16n16k16_row_col_f16f16f32<
                        typename KTraits::DTypeQ>(
                        o_frag[mma_q][mma_d],
                        (uint32_t *)s_frag[mma_q][mma_kv], b_frag);
                }
            }
            if constexpr (sizeof(typename KTraits::DTypeKV) == 1) {
                if (mma_d % 2 == 1) {
                    *v_smem_offset_r =
                        v_smem->template advance_offset_by_column<
                            V_SMEM_COLUMN_ADVANCE>(*v_smem_offset_r,
                                                   mma_d / 2);
                }
            }
            else {
                *v_smem_offset_r = v_smem->template advance_offset_by_column<
                    V_SMEM_COLUMN_ADVANCE>(*v_smem_offset_r, mma_d);
            }
        }
        *v_smem_offset_r =
            v_smem->template advance_offset_by_row<16, UPCAST_STRIDE_V>(
                *v_smem_offset_r) -
            sizeof(typename KTraits::DTypeKV) * KTraits::NUM_MMA_D_VO;
    }
    *v_smem_offset_r -= 16 * KTraits::NUM_MMA_KV * UPCAST_STRIDE_V;
}

/*!
 * \brief Divide the output accumulator by the softmax denominator d.
 *        Rows whose max never left -inf (fully masked) produce 0.
 */
template <typename KTraits>
__device__ __forceinline__ void normalize_d(
    float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD],
    typename KTraits::DTypeQKAccum (*m)[KTraits::NUM_ACCUM_ROWS_PER_THREAD],
    float (*d)[KTraits::NUM_ACCUM_ROWS_PER_THREAD])
{
    using AttentionVariant = typename KTraits::AttentionVariant;
    constexpr uint32_t NAPTR = KTraits::NUM_ACCUM_ROWS_PER_THREAD;

    if constexpr (AttentionVariant::use_softmax) {
        float d_rcp[KTraits::NUM_MMA_Q][KTraits::NUM_ACCUM_ROWS_PER_THREAD];
        // compute reciprocal of d
#pragma unroll
        for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
#pragma unroll
            for (uint32_t j = 0; j < KTraits::NUM_ACCUM_ROWS_PER_THREAD; ++j) {
                d_rcp[mma_q][j] =
                    (m[mma_q][j] !=
                     typename KTraits::DTypeQKAccum(-gpu_iface::math::inf))
                        ? gpu_iface::math::ptx_rcp(d[mma_q][j])
                        : 0.f;
            }
        }

#pragma unroll
        for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
#pragma unroll
            for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) {
#pragma unroll
                for (uint32_t reg_id = 0;
                     reg_id < KTraits::HALF_ELEMS_PER_THREAD; ++reg_id)
                {
                    o_frag[mma_q][mma_d][reg_id] =
                        o_frag[mma_q][mma_d][reg_id] *
                        d_rcp[mma_q][reg_id % NAPTR];
                }
            }
        }
    }
}

/*!
 * \brief Convert the running max m back to natural-log scale (multiply by
 *        sm_scale_log2) for LSE output; fully-masked rows stay -inf.
 */
template <typename KTraits>
__device__ __forceinline__ void finalize_m(
    typename KTraits::AttentionVariant variant,
    typename KTraits::DTypeQKAccum (*m)[KTraits::NUM_ACCUM_ROWS_PER_THREAD])
{
    if constexpr (variant.use_softmax) {
#pragma unroll
        for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) {
#pragma unroll
            for (uint32_t j = 0; j < KTraits::NUM_ACCUM_ROWS_PER_THREAD; ++j) {
                if (m[mma_q][j] !=
                    typename KTraits::DTypeQKAccum(-gpu_iface::math::inf))
                {
                    m[mma_q][j] *= variant.sm_scale_log2;
                }
            }
        }
    }
}
- */ -template -__device__ __forceinline__ void threadblock_sync_mdo_states( - float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD], - typename KTraits::SharedStorage *smem_storage, - typename KTraits::DTypeQKAccum (*m)[KTraits::NUM_ACCUM_ROWS_PER_THREAD], - float (*d)[KTraits::NUM_ACCUM_ROWS_PER_THREAD], - const uint32_t warp_idx, - const uint32_t lane_idx, - const dim3 tid = threadIdx) -{ - constexpr uint32_t TPR = KTraits::THREADS_PER_MATRIX_ROW_SET; - constexpr uint32_t NARPT = KTraits::NUM_ACCUM_ROWS_PER_THREAD; - - static_assert(WARP_SIZE % TPR == 0, - "THREADS_PER_MATRIX_ROW_SET must divide WARP_SIZE"); - constexpr uint32_t GROUPS_PER_WARP = WARP_SIZE / TPR; - const uint32_t lane_group_idx = lane_idx / TPR; - - // only necessary when blockDim.z > 1 - if constexpr (KTraits::NUM_WARPS_KV > 1) { - float *smem_o = smem_storage->cta_sync_o_smem; - float2 *smem_md = smem_storage->cta_sync_md_smem; - // o: [num_warps, - // NUM_MMA_Q, - // NUM_MMA_D_VO, - // WARP_SIZE, - // HALF_ELEMS_PER_THREAD] - // md: [num_warps, NUM_MMA_Q, 16, 2 (m/d)] -#pragma unroll - for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { -#pragma unroll - for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { - vec_t::memcpy( - smem_o + (((warp_idx * KTraits::NUM_MMA_Q + mma_q) * - KTraits::NUM_MMA_D_VO + - mma_d) * - WARP_SIZE + - lane_idx) * - KTraits::HALF_ELEMS_PER_THREAD, - o_frag[mma_q][mma_d]); - } - } - - if constexpr (KTraits::AttentionVariant::use_softmax) { -#pragma unroll - for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { -#pragma unroll - for (uint32_t j = 0; j < NARPT; ++j) { - smem_md[((warp_idx * KTraits::NUM_MMA_Q + mma_q) * NARPT + - j) * - GROUPS_PER_WARP + - lane_group_idx] = - make_float2(float(m[mma_q][j]), d[mma_q][j]); - } - } - - // synchronize m,d first - __syncthreads(); -#pragma unroll - for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { - float o_scale[NARPT][KTraits::NUM_WARPS_KV]; -#pragma 
unroll - for (uint32_t j = 0; j < NARPT; ++j) { - float m_new = -gpu_iface::math::inf, d_new = 1.f; -#pragma unroll - for (uint32_t i = 0; i < KTraits::NUM_WARPS_KV; ++i) { - float2 md = smem_md[(((i * KTraits::NUM_WARPS_Q + - get_warp_idx_q(tid.y)) * - KTraits::NUM_MMA_Q + - mma_q) * - NARPT + - j) * - GROUPS_PER_WARP + - lane_group_idx]; - float m_prev = m_new, d_prev = d_new; - m_new = max(m_new, md.x); - d_new = - d_prev * gpu_iface::math::ptx_exp2(m_prev - m_new) + - md.y * gpu_iface::math::ptx_exp2(md.x - m_new); - } - -#pragma unroll - for (uint32_t i = 0; i < KTraits::NUM_WARPS_KV; ++i) { - float2 md = smem_md[(((i * KTraits::NUM_WARPS_Q + - get_warp_idx_q(tid.y)) * - KTraits::NUM_MMA_Q + - mma_q) * - NARPT + - j) * - GROUPS_PER_WARP + - lane_group_idx]; - float mi = md.x; - o_scale[j][i] = - gpu_iface::math::ptx_exp2(float(mi - m_new)); - } - m[mma_q][j] = typename KTraits::DTypeQKAccum(m_new); - d[mma_q][j] = d_new; - } - -#pragma unroll - for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) - { - vec_t o_new; - o_new.fill(0.f); -#pragma unroll - for (uint32_t i = 0; i < KTraits::NUM_WARPS_KV; ++i) { - vec_t oi; - oi.load(smem_o + ((((i * KTraits::NUM_WARPS_Q + - get_warp_idx_q(tid.y)) * - KTraits::NUM_MMA_Q + - mma_q) * - KTraits::NUM_MMA_D_VO + - mma_d) * - WARP_SIZE + - lane_idx) * - KTraits::HALF_ELEMS_PER_THREAD); - -#pragma unroll - for (uint32_t reg_id = 0; - reg_id < KTraits::HALF_ELEMS_PER_THREAD; ++reg_id) - { - // CDNA3: Direct mapping - each reg_id corresponds - // to one accumulator row - o_new[reg_id] += oi[reg_id] * o_scale[reg_id][i]; - } - } - o_new.store(o_frag[mma_q][mma_d]); - } - } - } - else { - // synchronize m,d first - __syncthreads(); -#pragma unroll - for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { -#pragma unroll - for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) - { - vec_t o_new; - o_new.fill(0.f); -#pragma unroll - for (uint32_t i = 0; i < KTraits::NUM_WARPS_KV; ++i) { - vec_t oi; 
- oi.load(smem_o + ((((i * KTraits::NUM_WARPS_Q + - get_warp_idx_q(tid.y)) * - KTraits::NUM_MMA_Q + - mma_q) * - KTraits::NUM_MMA_D_VO + - mma_d) * - WARP_SIZE + - lane_idx) * - KTraits::HALF_ELEMS_PER_THREAD); -#pragma unroll - for (uint32_t reg_id = 0; - reg_id < KTraits::HALF_ELEMS_PER_THREAD; ++reg_id) - { - o_new[reg_id] += oi[reg_id]; - } - } - o_new.store(o_frag[mma_q][mma_d]); - } - } - } - } -} - -template -__device__ __forceinline__ void write_o_reg_gmem( - float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD], - smem_t *o_smem, - typename KTraits::DTypeO *o_ptr_base, - const uint32_t o_packed_idx_base, - const uint32_t qo_upper_bound, - const uint32_t o_stride_n, - const uint32_t o_stride_h, - const uint_fastdiv group_size, - const dim3 tid = threadIdx) -{ - using DTypeO = typename KTraits::DTypeO; - constexpr uint32_t UPCAST_STRIDE_O = KTraits::UPCAST_STRIDE_O; - constexpr uint32_t TPR = KTraits::THREADS_PER_MATRIX_ROW_SET; - constexpr uint32_t NAPTR = KTraits::NUM_ACCUM_ROWS_PER_THREAD; - constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; - constexpr uint32_t WARP_THREAD_COLS = KTraits::WARP_THREAD_COLS; - constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH; - - const uint32_t warp_idx_x = get_warp_idx_q(tid.y); - const uint32_t lane_idx = tid.x; - - if constexpr (sizeof(DTypeO) == 4) { -#pragma unroll - for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { -#pragma unroll - for (uint32_t j = 0; j < NAPTR; ++j) { - uint32_t q, r; - group_size.divmod(o_packed_idx_base + lane_idx / TPR + - mma_q * 16 + j * 8, - q, r); - const uint32_t o_idx = q; -#pragma unroll - for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) - { - if (o_idx < qo_upper_bound) { - auto base_addr = o_ptr_base + q * o_stride_n + - r * o_stride_h + mma_d * 16; - auto col_offset = lane_idx % 16; - *(base_addr + col_offset) = o_frag[mma_q][mma_d][j]; - } - } - } - } - } - else { - if (get_warp_idx_kv(tid.z) == 0) 
{ -#pragma unroll - for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { -#pragma unroll - for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) - { - uint32_t o_frag_f16[HALF_ELEMS_PER_THREAD / 2]; - vec_cast::template cast< - HALF_ELEMS_PER_THREAD>((DTypeO *)o_frag_f16, - o_frag[mma_q][mma_d]); - -#ifdef FLASHINFER_STMATRIX_M8N8X4_ENABLED - uint32_t o_smem_offset_w = - o_smem->template get_permuted_offset( - (warp_idx_x * KTraits::NUM_MMA_Q + mma_q) * 16 + - lane_idx % 16, - mma_d * 2 + lane_idx / 16); - o_smem->stmatrix_m8n8x4(o_smem_offset_w, o_frag_f16); -#else - uint32_t o_smem_offset_w = - o_smem->template get_permuted_offset( - (warp_idx_x * KTraits::NUM_MMA_Q + mma_q) * 16 + - lane_idx / TPR, - mma_d * 2); - ((uint32_t *)(o_smem->base + - o_smem_offset_w))[lane_idx % TPR] = - o_frag_f16[0]; - // Move 2 elements forward in the same row - uint32_t offset_2 = o_smem_offset_w + 2; - ((uint32_t *)(o_smem->base + offset_2))[lane_idx % 16] = - o_frag_f16[1]; - -#endif - } - } - - uint32_t o_smem_offset_w = - o_smem->template get_permuted_offset( - warp_idx_x * KTraits::NUM_MMA_Q * 16 + - lane_idx / WARP_THREAD_COLS, - lane_idx % WARP_THREAD_COLS); - -#pragma unroll - for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { -#pragma unroll - for (uint32_t j = 0; j < 2 * 2; ++j) { - uint32_t q, r; - group_size.divmod(o_packed_idx_base + - lane_idx / WARP_THREAD_COLS + - mma_q * 16 + j * 4, - q, r); - const uint32_t o_idx = q; - DTypeO *o_ptr = o_ptr_base + q * o_stride_n + - r * o_stride_h + - (lane_idx % WARP_THREAD_COLS) * - upcast_size(); -#pragma unroll - for (uint32_t mma_do = 0; - mma_do < KTraits::NUM_MMA_D_VO / 4; ++mma_do) - { - if (o_idx < qo_upper_bound) { - o_smem->store_vector(o_smem_offset_w, o_ptr); - } - o_ptr += WARP_THREAD_COLS * - upcast_size(); - o_smem_offset_w = - o_smem->template advance_offset_by_column< - WARP_THREAD_COLS>(o_smem_offset_w, mma_do); - } - o_smem_offset_w = o_smem->template advance_offset_by_row< 
- 4, UPCAST_STRIDE_O>(o_smem_offset_w) - - 2 * KTraits::NUM_MMA_D_VO; - } - } - } - } -} - -} // namespace - -/*! - * \brief FlashAttention prefill CUDA kernel for a single request. - * \tparam partition_kv Whether to split kv_len into chunks. - * \tparam mask_mode The mask mode used in the attention operation. - * \tparam POS_ENCODING_MODE The positional encoding mode. - * \tparam NUM_MMA_Q The number of fragments in x dimension. - * \tparam NUM_MMA_D_VO The number of fragments in y dimension. - * \tparam NUM_MMA_KV The number of fragments in z dimension. - * \tparam num_warps The number of warps in the threadblock. - * \tparam DTypeQ The data type of the query tensor. - * \tparam DTypeKV The data type of the key/value tensor. - * \tparam DTypeO The data type of the output tensor. - * \param q The query tensor. - * \param k The key tensor. - * \param v The value tensor. - * \param o The output tensor. - * \param tmp The temporary buffer (used when partition_kv is true). - * \param lse The logsumexp value. - * \param rope_rcp_scale 1/(rope_scale), where rope_scale is the scaling - * factor used in RoPE interpolation. - * \param rope_rcp_theta 1/(rope_theta), where rope_theta is the theta - * used in RoPE. 
- */ -template -__device__ __forceinline__ void -SinglePrefillWithKVCacheDevice(const Params params, - typename KTraits::SharedStorage &smem_storage, - const dim3 tid = threadIdx, - const uint32_t bx = blockIdx.x, - const uint32_t chunk_idx = blockIdx.y, - const uint32_t kv_head_idx = blockIdx.z, - const uint32_t num_chunks = gridDim.y, - const uint32_t num_kv_heads = gridDim.z) -{ - using DTypeQ = typename Params::DTypeQ; - using DTypeKV = typename Params::DTypeKV; - using DTypeO = typename Params::DTypeO; - using DTypeQKAccum = typename KTraits::DTypeQKAccum; - using AttentionVariant = typename KTraits::AttentionVariant; - [[maybe_unused]] constexpr uint32_t NUM_MMA_Q = KTraits::NUM_MMA_Q; - [[maybe_unused]] constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; - [[maybe_unused]] constexpr uint32_t NUM_MMA_D_QK = KTraits::NUM_MMA_D_QK; - [[maybe_unused]] constexpr uint32_t NUM_MMA_D_VO = KTraits::NUM_MMA_D_VO; - [[maybe_unused]] constexpr uint32_t HEAD_DIM_QK = KTraits::HEAD_DIM_QK; - [[maybe_unused]] constexpr uint32_t HEAD_DIM_VO = KTraits::HEAD_DIM_VO; - [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_Q = - KTraits::UPCAST_STRIDE_Q; - [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_K = - KTraits::UPCAST_STRIDE_K; - [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_V = - KTraits::UPCAST_STRIDE_V; - [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_O = - KTraits::UPCAST_STRIDE_O; - [[maybe_unused]] constexpr uint32_t CTA_TILE_Q = KTraits::CTA_TILE_Q; - [[maybe_unused]] constexpr uint32_t CTA_TILE_KV = KTraits::CTA_TILE_KV; - [[maybe_unused]] constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q; - [[maybe_unused]] constexpr uint32_t NUM_WARPS_KV = KTraits::NUM_WARPS_KV; - [[maybe_unused]] constexpr SwizzleMode SWIZZLE_MODE_Q = - KTraits::SWIZZLE_MODE_Q; - [[maybe_unused]] constexpr SwizzleMode SWIZZLE_MODE_KV = - KTraits::SWIZZLE_MODE_KV; - [[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_ROW = - KTraits::KV_THR_LAYOUT_ROW; - [[maybe_unused]] constexpr 
uint32_t KV_THR_LAYOUT_COL = - KTraits::KV_THR_LAYOUT_COL; - [[maybe_unused]] constexpr MaskMode MASK_MODE = KTraits::MASK_MODE; - [[maybe_unused]] constexpr uint32_t HALF_ELEMS_PER_THREAD = - KTraits::HALF_ELEMS_PER_THREAD; - [[maybe_unused]] constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD = - KTraits::NUM_ACCUM_ROWS_PER_THREAD; - [[maybe_unused]] constexpr uint32_t LOGITS_INDEX_STRIDE = - KTraits::LOGITS_INDEX_STRIDE; - [[maybe_unused]] constexpr uint32_t THREADS_PER_MATRIX_ROW_SET = - KTraits::THREADS_PER_MATRIX_ROW_SET; - [[maybe_unused]] constexpr uint32_t VECTOR_BIT_WIDTH = - KTraits::VECTOR_BIT_WIDTH; - - DTypeQ *q = params.q; - DTypeKV *k = params.k; - DTypeKV *v = params.v; - DTypeO *o = params.o; - float *lse = params.lse; - const uint32_t qo_len = params.qo_len; - const uint32_t kv_len = params.kv_len; - const bool partition_kv = params.partition_kv; - const uint32_t q_stride_n = params.q_stride_n; - const uint32_t q_stride_h = params.q_stride_h; - const uint32_t k_stride_n = params.k_stride_n; - const uint32_t k_stride_h = params.k_stride_h; - const uint32_t v_stride_n = params.v_stride_n; - const uint32_t v_stride_h = params.v_stride_h; - const uint_fastdiv &group_size = params.group_size; - - static_assert(sizeof(DTypeQ) == 2); - const uint32_t lane_idx = tid.x, - warp_idx = get_warp_idx(tid.y, tid.z); - const uint32_t num_qo_heads = num_kv_heads * group_size; - - const uint32_t max_chunk_size = - partition_kv ? ceil_div(kv_len, num_chunks) : kv_len; - const uint32_t chunk_start = partition_kv ? chunk_idx * max_chunk_size : 0; - const uint32_t chunk_end = - partition_kv ? 
min((chunk_idx + 1) * max_chunk_size, kv_len) : kv_len; - const uint32_t chunk_size = chunk_end - chunk_start; - - auto block = cg::this_thread_block(); - auto smem = reinterpret_cast(&smem_storage); - AttentionVariant variant(params, /*batch_idx=*/0, smem); - const uint32_t window_left = variant.window_left; - - DTypeQKAccum s_frag[NUM_MMA_Q][NUM_MMA_KV][HALF_ELEMS_PER_THREAD]; - alignas(16) float o_frag[NUM_MMA_Q][NUM_MMA_D_VO][HALF_ELEMS_PER_THREAD]; - DTypeQKAccum m[NUM_MMA_Q][NUM_ACCUM_ROWS_PER_THREAD]; - float d[NUM_MMA_Q][NUM_ACCUM_ROWS_PER_THREAD]; - float rope_freq[NUM_MMA_D_QK / 2][4]; - - init_states(variant, o_frag, m, d); - - // cooperative fetch q fragment from gmem to reg - const uint32_t qo_packed_idx_base = - (bx * NUM_WARPS_Q + get_warp_idx_q(tid.y)) * NUM_MMA_Q * 16; - smem_t qo_smem( - smem_storage.q_smem); - const uint32_t o_stride_n = num_qo_heads * HEAD_DIM_VO, - o_stride_h = HEAD_DIM_VO; - DTypeQ *q_ptr_base = q + (kv_head_idx * group_size) * q_stride_h; - DTypeO *o_ptr_base = partition_kv - ? o + chunk_idx * o_stride_n + - (kv_head_idx * group_size) * o_stride_h - : o + (kv_head_idx * group_size) * o_stride_h; - - uint32_t q_smem_offset_r = - qo_smem.template get_permuted_offset( - get_warp_idx_q(tid.y) * NUM_MMA_Q * 16 + lane_idx % 16, - lane_idx / 16); - - load_q_global_smem(qo_packed_idx_base, qo_len, q_ptr_base, - q_stride_n, q_stride_h, group_size, &qo_smem, - tid); - - memory::commit_group(); - - smem_t k_smem( - smem_storage.k_smem); - smem_t v_smem( - smem_storage.v_smem); - - const uint32_t num_iterations = ceil_div( - MASK_MODE == MaskMode::kCausal - ? 
min(chunk_size, - sub_if_greater_or_zero( - kv_len - qo_len + ((bx + 1) * CTA_TILE_Q) / group_size, - chunk_start)) - : chunk_size, - CTA_TILE_KV); - - const uint32_t window_iteration = ceil_div( - sub_if_greater_or_zero(kv_len + (bx + 1) * CTA_TILE_Q / group_size, - qo_len + window_left + chunk_start), - CTA_TILE_KV); - - const uint32_t mask_iteration = - (MASK_MODE == MaskMode::kCausal - ? min(chunk_size, - sub_if_greater_or_zero( - kv_len + (bx * CTA_TILE_Q) / group_size - qo_len, - chunk_start)) - : chunk_size) / - CTA_TILE_KV; - - DTypeKV *k_ptr = k + - (chunk_start + warp_idx * KV_THR_LAYOUT_ROW + - lane_idx / KV_THR_LAYOUT_COL) * - k_stride_n + - kv_head_idx * k_stride_h + - (lane_idx % KV_THR_LAYOUT_COL) * - upcast_size(); - DTypeKV *v_ptr = v + - (chunk_start + warp_idx * KV_THR_LAYOUT_ROW + - lane_idx / KV_THR_LAYOUT_COL) * - v_stride_n + - kv_head_idx * v_stride_h + - (lane_idx % KV_THR_LAYOUT_COL) * - upcast_size(); - uint32_t k_smem_offset_r = - k_smem.template get_permuted_offset( - get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + lane_idx % 16, - (lane_idx / 16)); - - uint32_t - v_smem_offset_r = v_smem.template get_permuted_offset( - get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + lane_idx % 16, - lane_idx / 16), - k_smem_offset_w = k_smem.template get_permuted_offset( - warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL, - lane_idx % KV_THR_LAYOUT_COL), - v_smem_offset_w = v_smem.template get_permuted_offset( - warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL, - lane_idx % KV_THR_LAYOUT_COL); - produce_kv( - k_smem, &k_smem_offset_w, &k_ptr, k_stride_n, 0, chunk_size, tid); - memory::commit_group(); - produce_kv( - v_smem, &v_smem_offset_w, &v_ptr, v_stride_n, 0, chunk_size, tid); - memory::commit_group(); - -#if Debug - int global_idx = (blockIdx.z * gridDim.y * gridDim.x + - blockIdx.y * gridDim.x + blockIdx.x) * - (blockDim.z * blockDim.y * blockDim.x) + - (threadIdx.z * blockDim.y * blockDim.x + - threadIdx.y * blockDim.x + 
threadIdx.x); - - if (global_idx == 0) { - printf("partition_kv : %d\n", partition_kv); - printf("kv_len : %d\n", kv_len); - printf("max_chunk_size : %d\n", max_chunk_size); - printf("chunk_end : %d\n", chunk_end); - printf("chunk_start : %d\n", chunk_start); - } - // Test Q - // if (global_idx == 0) { - // uint32_t q_smem_offset_r_debug; - // //for (auto i = 0; i < 4; ++i) { - // for (auto j = 0; j < 16; ++j) { - // uint32_t q_smem_offset_r_debug = - // qo_smem.template - // get_permuted_offset( - // get_warp_idx_q(tid.y) * NUM_MMA_Q * 16 - // + (j) % 16, (j) / 16); - // uint32_t a_frag[KTraits::INT32_ELEMS_PER_THREAD]; - // k_smem.load_fragment(q_smem_offset_r_debug, a_frag); - // auto frag_T = reinterpret_cast<__half *>(a_frag); - // for (auto i = 0ul; i < 4; ++i) { - // printf("%f ", (float)(*(frag_T + i))); - // } - // printf("\n"); - // } - // // q_smem_offset_r_debug = qo_smem.template - // advance_offset_by_column<4>( - // // q_smem_offset_r_debug, 0); - // // } - // } - - // for (auto mma_q = 0ul; mma_q < 4; ++mma_q) { - // uint32_t a_frag[KTraits::INT32_ELEMS_PER_THREAD]; - // qo_smem.load_fragment(q_smem_offset_r, a_frag); - // if (global_idx == 0) { - // auto frag_T = reinterpret_cast<__half *>(a_frag); - // printf("DEBUG: Q Frag in permuted_smem for mma_q %lu \n", - // mma_q); for (auto i = 0ul; i < 4; ++i) { - // printf("%f ", (float)(*(frag_T + i))); - // } - // printf("\n"); - // } - - // q_smem_offset_r = qo_smem.template advance_offset_by_column<4>( - // q_smem_offset_r, 0); - // } - - uint32_t a_frag[KTraits::INT32_ELEMS_PER_THREAD]; - qo_smem.load_fragment(q_smem_offset_r, a_frag); - if (global_idx == 0) { - auto frag_T = reinterpret_cast<__half *>(a_frag); - printf("DEBUG: Q Frag \n"); - for (auto i = 0ul; i < 4; ++i) { - printf("%f ", (float)(*(frag_T + i))); - } - printf("\n"); - } - - memory::wait_group<0>(); - block.sync(); - q_smem_inplace_apply_rotary(qo_packed_idx_base, qo_len, kv_len, - group_size, &qo_smem, &q_smem_offset_r, - 
rope_freq, tid); - block.sync(); - - qo_smem.load_fragment(q_smem_offset_r, a_frag); - if (global_idx == 0) { - auto frag_T = reinterpret_cast<__half *>(a_frag); - printf("DEBUG: LLAMA Rope transformed Q Frag \n"); - for (auto i = 0ul; i < 4; ++i) { - printf("%f ", (float)(*(frag_T + i))); - } - printf("\n"); - } - - // // Test K loads - // if (global_idx == 0) { - - // for (auto j = 0; j < 64; ++j) { - // uint32_t k_smem_offset_r_test = - // k_smem.template get_permuted_offset( - // get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + - // j % 16, - // (j / 16)); - // uint32_t b_frag[KTraits::INT32_ELEMS_PER_THREAD]; - // k_smem.load_fragment(k_smem_offset_r_test, b_frag); - // auto frag_T = reinterpret_cast<__half *>(b_frag); - // // printf("DEBUG: K Frag in permuted_smem for mma_kv %lu \n", - // // mma_kv); - // for (auto i = 0ul; i < 4; ++i) { - // printf("%f ", (float)(*(frag_T + i))); - // } - // printf("\n"); - // } - // } - - // if (global_idx == 0) { - // printf("DEBUG Q ORIGINAL (HIP):\n"); - - // for (uint32_t seq_idx = 0; seq_idx < 16; ++seq_idx) { - // printf("Q[%u] original: ", seq_idx); - - // // Load all feature groups for this sequence - // for (uint32_t feat_group = 0; feat_group < NUM_MMA_D_QK; - // ++feat_group) { - // uint32_t feat_offset = qo_smem.template - // get_permuted_offset( - // seq_idx, feat_group * HALF_ELEMS_PER_THREAD); - - // uint32_t q_frag[KTraits::INT32_ELEMS_PER_THREAD]; - // qo_smem.load_fragment(feat_offset, q_frag); - // auto frag_T = reinterpret_cast<__half *>(q_frag); - - // // Print 4 features from this group - // for (auto feat = 0ul; feat < HALF_ELEMS_PER_THREAD; - // ++feat) { - // printf("%f ", (float)(*(frag_T + feat))); - // } - // } - // printf("\n"); - // } - // } - - // memory::wait_group<0>(); - // block.sync(); - // q_smem_inplace_apply_rotary( - // qo_packed_idx_base, qo_len, kv_len, group_size, &qo_smem, - // &q_smem_offset_r, rope_freq, tid); - // block.sync(); - - // // Debug: Print Q fragments after RoPE - // if 
(global_idx == 0) { - // printf("DEBUG Q LLAMA ROPE (HIP):\n"); - - // // Reset q_smem_offset_r to start - // uint32_t q_smem_offset_r_debug = - // qo_smem.template get_permuted_offset( - // get_warp_idx_q(tid.y) * NUM_MMA_Q * 16 + - // lane_idx % 16, lane_idx / 16); - - // for (uint32_t seq_idx = 0; seq_idx < 16; ++seq_idx) { - // // Calculate offset for this sequence - // uint32_t seq_offset = qo_smem.template - // get_permuted_offset( - // seq_idx, 0); - - // printf("Q[%u] after RoPE: ", seq_idx); - - // // Load all feature groups for this sequence - // for (uint32_t feat_group = 0; feat_group < NUM_MMA_D_QK; - // ++feat_group) { - // uint32_t feat_offset = qo_smem.template - // get_permuted_offset( - // seq_idx, feat_group * HALF_ELEMS_PER_THREAD); - - // uint32_t q_frag[KTraits::INT32_ELEMS_PER_THREAD]; - // qo_smem.load_fragment(feat_offset, q_frag); - // auto frag_T = reinterpret_cast<__half *>(q_frag); - - // // Print 4 features from this group - // for (auto feat = 0ul; feat < HALF_ELEMS_PER_THREAD; - // ++feat) { - // printf("%f ", (float)(*(frag_T + feat))); - // } - // } - // printf("\n"); - // } - // } -#endif - -#pragma unroll 1 - for (uint32_t iter = 0; iter < num_iterations; ++iter) { - memory::wait_group<1>(); - block.sync(); - - // compute attention score - compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, - &k_smem_offset_r, s_frag); - - logits_transform( - params, variant, /*batch_idx=*/0, qo_packed_idx_base, - chunk_start + - (iter * NUM_WARPS_KV + get_warp_idx_kv(tid.z)) * - NUM_MMA_KV * 16, - qo_len, kv_len, group_size, s_frag, tid, kv_head_idx); - - // // apply mask - // if (MASK_MODE == MaskMode::kCustom || - // (iter >= mask_iteration || iter < window_iteration)) - // { - // logits_mask( - // params, variant, /*batch_idx=*/0, qo_packed_idx_base, - // chunk_start + (iter * NUM_WARPS_KV + - // get_warp_idx_kv(tid.z)) * - // NUM_MMA_KV * 16, - // qo_len, kv_len, chunk_end, group_size, s_frag, tid, - // kv_head_idx); - // } - - // compute m,d 
states in online softmax - update_mdo_states(variant, s_frag, o_frag, m, d); - - block.sync(); - - produce_kv( - k_smem, &k_smem_offset_w, &k_ptr, k_stride_n, - (iter + 1) * CTA_TILE_KV, chunk_size, tid); - memory::commit_group(); - memory::wait_group<1>(); - block.sync(); - - // compute sfm*v - compute_sfm_v(&v_smem, &v_smem_offset_r, s_frag, o_frag, d); - - block.sync(); - produce_kv( - v_smem, &v_smem_offset_w, &v_ptr, v_stride_n, - (iter + 1) * CTA_TILE_KV, chunk_size, tid); - memory::commit_group(); - } - memory::wait_group<0>(); - block.sync(); - - finalize_m(variant, m); - - // threadblock synchronization - threadblock_sync_mdo_states(o_frag, &smem_storage, m, d, warp_idx, - lane_idx, tid); - - // normalize d - normalize_d(o_frag, m, d); - - // write back - write_o_reg_gmem( - o_frag, &qo_smem, o_ptr_base, qo_packed_idx_base, qo_len, - /*o_stride_n=*/ - partition_kv ? num_chunks * o_stride_n : o_stride_n, - /*o_stride_h=*/o_stride_h, group_size, tid); - - // write lse - if constexpr (variant.use_softmax) { - if (lse != nullptr || partition_kv) { - if (get_warp_idx_kv(tid.z) == 0) { -#pragma unroll - for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { -#pragma unroll - for (uint32_t j = 0; j < NUM_ACCUM_ROWS_PER_THREAD; ++j) { - uint32_t q, r; - group_size.divmod( - qo_packed_idx_base + - lane_idx / THREADS_PER_MATRIX_ROW_SET + - j * LOGITS_INDEX_STRIDE + mma_q * 16, - q, r); - const uint32_t qo_head_idx = - kv_head_idx * group_size + r; - const uint32_t qo_idx = q; - if (qo_idx < qo_len) { - if (partition_kv) { - lse[(qo_idx * num_chunks + chunk_idx) * - num_qo_heads + - qo_head_idx] = - gpu_iface::math::ptx_log2(d[mma_q][j]) + - float(m[mma_q][j]); - } - else { - lse[qo_idx * num_qo_heads + qo_head_idx] = - gpu_iface::math::ptx_log2(d[mma_q][j]) + - float(m[mma_q][j]); - } - } - } - } - } - } - } -} - -template -__global__ -__launch_bounds__(KTraits::NUM_THREADS) void SinglePrefillWithKVCacheKernel( - const __grid_constant__ Params params) -{ - extern 
__shared__ uint8_t smem[];
    auto &smem_storage =
        reinterpret_cast<typename KTraits::SharedStorage &>(smem);
    SinglePrefillWithKVCacheDevice<KTraits>(params, smem_storage);
}

/*!
 * \brief Host-side dispatcher for the single-request prefill kernel.
 *
 * Chooses the CTA tile shape from the (packed) query length, resolves the
 * kernel template instantiation, decides whether to split the KV dimension
 * across thread blocks for occupancy, launches the kernel, and merges the
 * per-chunk partial results when split-KV is used.
 *
 * NOTE(review): the template parameter/argument lists below were stripped by
 * text extraction and reconstructed from usage — verify against the original
 * file (in particular the KernelTraits argument order).
 */
template <uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_VO,
          PosEncodingMode POS_ENCODING_MODE, bool USE_FP16_QK_REDUCTION,
          MaskMode MASK_MODE, typename AttentionVariant, typename Params>
gpuError_t SinglePrefillWithKVCacheDispatched(Params params,
                                              typename Params::DTypeO *tmp,
                                              gpuStream_t stream)
{
    using DTypeQ = typename Params::DTypeQ;
    using DTypeKV = typename Params::DTypeKV;
    using DTypeO = typename Params::DTypeO;
    const uint32_t num_qo_heads = params.num_qo_heads;
    const uint32_t num_kv_heads = params.num_kv_heads;
    const uint32_t qo_len = params.qo_len;
    const uint32_t kv_len = params.kv_len;
    if (kv_len < qo_len && MASK_MODE == MaskMode::kCausal) {
        std::ostringstream err_msg;
        // Fixed: the original message lacked a space after "kv_len", printing
        // e.g. "got kv_len2048".
        err_msg << "When mask_mode is set to MaskMode::kCausal, kv_len must be "
                   "greater than or equal to qo_len, got kv_len "
                << kv_len << " and qo_len " << qo_len;
        FLASHINFER_ERROR(err_msg.str());
    }

    const uint32_t group_size = num_qo_heads / num_kv_heads;
    constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM_QK / 16;
    constexpr uint32_t NUM_MMA_D_VO = HEAD_DIM_VO / 16;
    int64_t packed_qo_len = qo_len * group_size;
    uint32_t cta_tile_q = FA2DetermineCtaTileQ(packed_qo_len, HEAD_DIM_VO);

    DISPATCH_CTA_TILE_Q(cta_tile_q, CTA_TILE_Q, {
        constexpr uint32_t NUM_WARPS_Q = get_num_warps_q(CTA_TILE_Q);
        constexpr uint32_t NUM_WARPS_KV = get_num_warps_kv(CTA_TILE_Q);
        constexpr uint32_t NUM_MMA_Q = get_num_mma_q(CTA_TILE_Q);

        using DTypeQKAccum = typename std::conditional<
            USE_FP16_QK_REDUCTION && std::is_same_v<DTypeQ, half>, half,
            float>::type;

        int dev_id = 0;
        FI_GPU_CALL(gpuGetDevice(&dev_id));
        int max_smem_per_sm = getMaxSharedMemPerMultiprocessor(dev_id);
        // we expect each sm execute two threadblocks
        const int num_ctas_per_sm =
            max_smem_per_sm >= 2 * (CTA_TILE_Q * HEAD_DIM_QK * sizeof(DTypeQ) +
                                    (HEAD_DIM_QK + HEAD_DIM_VO) * 16 *
                                        NUM_WARPS_KV * sizeof(DTypeKV))
                ? 2
                : 1;
        const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm;

        const uint32_t max_num_mma_kv_reg =
            (HEAD_DIM_VO >= 128 && NUM_MMA_Q == 2 &&
             POS_ENCODING_MODE == PosEncodingMode::kRoPELlama &&
             !USE_FP16_QK_REDUCTION)
                ? 2
                : (8 / NUM_MMA_Q);
        const uint32_t max_num_mma_kv_smem =
            (max_smem_per_threadblock -
             CTA_TILE_Q * HEAD_DIM_QK * sizeof(DTypeQ)) /
            ((HEAD_DIM_QK + HEAD_DIM_VO) * 16 * NUM_WARPS_KV * sizeof(DTypeKV));

        // control NUM_MMA_KV for maximum warp occupancy
        DISPATCH_NUM_MMA_KV(
            min(max_num_mma_kv_smem, max_num_mma_kv_reg), NUM_MMA_KV, {
                // NOTE(review): argument order reconstructed; verify against
                // the KernelTraits primary template.
                using KTraits = KernelTraits<
                    MASK_MODE, CTA_TILE_Q, NUM_MMA_Q, NUM_MMA_KV, NUM_MMA_D_QK,
                    NUM_MMA_D_VO, NUM_WARPS_Q, NUM_WARPS_KV, POS_ENCODING_MODE,
                    DTypeQ, DTypeKV, DTypeO, DTypeQKAccum, AttentionVariant>;
                if constexpr (KTraits::IsInvalid()) {
                    // Invalid configuration, skip
                    std::ostringstream err_msg;
                    err_msg << "FlashInfer Internal Error: Invalid "
                               "configuration : NUM_MMA_Q="
                            << NUM_MMA_Q << " NUM_MMA_D_QK=" << NUM_MMA_D_QK
                            << " NUM_MMA_D_VO=" << NUM_MMA_D_VO
                            << " NUM_MMA_KV=" << NUM_MMA_KV
                            << " NUM_WARPS_Q=" << NUM_WARPS_Q
                            << " NUM_WARPS_KV=" << NUM_WARPS_KV
                            << " please create an issue "
                               "(https://github.com/flashinfer-ai/flashinfer/"
                               "issues)"
                               " and report the issue to the developers.";
                    FLASHINFER_ERROR(err_msg.str());
                }
                else {
                    constexpr uint32_t num_threads =
                        (NUM_WARPS_Q * NUM_WARPS_KV) * WARP_SIZE;
                    auto kernel =
                        SinglePrefillWithKVCacheKernel<KTraits, Params>;
                    size_t smem_size = sizeof(typename KTraits::SharedStorage);
                    FI_GPU_CALL(gpuFuncSetAttribute(
                        kernel, gpuFuncAttributeMaxDynamicSharedMemorySize,
                        smem_size));
                    int num_blocks_per_sm = 0;
                    int num_sm = 0;
                    FI_GPU_CALL(gpuDeviceGetAttribute(
                        &num_sm, gpuDevAttrMultiProcessorCount, dev_id));
                    FI_GPU_CALL(gpuOccupancyMaxActiveBlocksPerMultiprocessor(
                        &num_blocks_per_sm, kernel, num_threads, smem_size));
                    uint32_t max_num_kv_chunks =
                        (num_blocks_per_sm * num_sm) /
                        (num_kv_heads *
                         ceil_div(qo_len * group_size, CTA_TILE_Q));
                    uint32_t num_chunks;
                    if (max_num_kv_chunks > 0) {
                        uint32_t chunk_size =
                            max(ceil_div(kv_len, max_num_kv_chunks), 256);
                        num_chunks = ceil_div(kv_len, chunk_size);
                    }
                    else {
                        num_chunks = 0;
                    }

                    if (num_chunks <= 1 || tmp == nullptr) {
                        // Enough parallelism, do not split-kv
                        params.partition_kv = false;
                        void *args[] = {(void *)&params};
                        dim3 nblks(ceil_div(qo_len * group_size, CTA_TILE_Q), 1,
                                   num_kv_heads);
                        dim3 nthrs(WARP_SIZE, NUM_WARPS_Q, NUM_WARPS_KV);
                        FI_GPU_CALL(gpuLaunchKernel((void *)kernel, nblks,
                                                    nthrs, args, smem_size,
                                                    stream));
                    }
                    else {
                        // Use cooperative groups to increase occupancy
                        params.partition_kv = true;
                        float *tmp_lse =
                            (float *)(tmp + num_chunks * qo_len * num_qo_heads *
                                                HEAD_DIM_VO);
                        auto o = params.o;
                        auto lse = params.lse;
                        params.o = tmp;
                        params.lse = tmp_lse;
                        void *args[] = {(void *)&params};
                        dim3 nblks(ceil_div(qo_len * group_size, CTA_TILE_Q),
                                   num_chunks, num_kv_heads);
                        dim3 nthrs(WARP_SIZE, NUM_WARPS_Q, NUM_WARPS_KV);
                        FI_GPU_CALL(gpuLaunchKernel((void *)kernel, nblks,
                                                    nthrs, args, smem_size,
                                                    stream));
                        if constexpr (AttentionVariant::use_softmax) {
                            FI_GPU_CALL(MergeStates(
                                tmp, tmp_lse, o, lse, num_chunks, qo_len,
                                num_qo_heads, HEAD_DIM_VO, stream));
                        }
                        else {
                            FI_GPU_CALL(AttentionSum(tmp, o, num_chunks, qo_len,
                                                     num_qo_heads, HEAD_DIM_VO,
                                                     stream));
                        }
                    }
                }
            })
    });
    return gpuSuccess;
}

} // namespace flashinfer

#endif // FLASHINFER_PREFILL_CUH_
diff --git a/libflashinfer/utils/compute_qk_stub.cuh b/libflashinfer/utils/compute_qk_stub.cuh
deleted file mode 100644
index 6c8da0f75e..0000000000
--- a/libflashinfer/utils/compute_qk_stub.cuh
+++ /dev/null
@@ -1,332 +0,0 @@
// SPDX-FileCopyrightText: 2025 Advanced Micro Devices, Inc.
// SPDX-License-Identifier: Apache-2.0

#include "flashinfer/attention/generic/default_prefill_params.cuh"
#include "flashinfer/attention/generic/prefill.cuh"
#include "gpu_iface/gpu_runtime_compat.hpp"

using namespace flashinfer;

/*!
 * \brief Test stub: executes only the Q*K^T stage of the prefill kernel and
 *   writes the raw attention scores to global memory (laid out as
 *   [qo_head_idx][q_idx][kv_idx]) for host-side comparison.
 *
 * NOTE(review): template-argument lists in this file were stripped by text
 * extraction; they are reconstructed from usage — verify against the
 * original file.
 */
template <typename KTraits, typename Params>
__device__ __forceinline__ void
ComputeQKStubKernelDevice(const Params params,
                          typename KTraits::SharedStorage &smem_storage,
                          float *qk_scores_output,
                          const dim3 tid = threadIdx,
                          const uint32_t bx = blockIdx.x,
                          const uint32_t chunk_idx = blockIdx.y,
                          const uint32_t kv_head_idx = blockIdx.z,
                          const uint32_t num_chunks = gridDim.y,
                          const uint32_t num_kv_heads = gridDim.z)
{
    using DTypeKV = typename Params::DTypeKV;
    using DTypeQ = typename Params::DTypeQ;
    using DTypeQKAccum = typename KTraits::DTypeQKAccum;

    [[maybe_unused]] constexpr uint32_t NUM_MMA_Q = KTraits::NUM_MMA_Q;
    [[maybe_unused]] constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV;
    [[maybe_unused]] constexpr uint32_t NUM_MMA_D_QK = KTraits::NUM_MMA_D_QK;
    [[maybe_unused]] constexpr uint32_t HEAD_DIM_QK = KTraits::HEAD_DIM_QK;
    [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_Q =
        KTraits::UPCAST_STRIDE_Q;
    [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_K =
        KTraits::UPCAST_STRIDE_K;
    [[maybe_unused]] constexpr uint32_t CTA_TILE_Q = KTraits::CTA_TILE_Q;
    [[maybe_unused]] constexpr uint32_t CTA_TILE_KV = KTraits::CTA_TILE_KV;
    [[maybe_unused]] constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q;
    [[maybe_unused]] constexpr uint32_t NUM_WARPS_KV = KTraits::NUM_WARPS_KV;
    [[maybe_unused]] constexpr SwizzleMode SWIZZLE_MODE_Q =
        KTraits::SWIZZLE_MODE_Q;
    [[maybe_unused]] constexpr SwizzleMode SWIZZLE_MODE_KV =
        KTraits::SWIZZLE_MODE_KV;
    [[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_ROW =
        KTraits::KV_THR_LAYOUT_ROW;
    [[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_COL =
        KTraits::KV_THR_LAYOUT_COL;
    [[maybe_unused]] constexpr uint32_t HALF_ELEMS_PER_THREAD =
        KTraits::HALF_ELEMS_PER_THREAD;
    [[maybe_unused]] constexpr uint32_t VECTOR_BIT_WIDTH =
        KTraits::VECTOR_BIT_WIDTH;

    DTypeQ *q = params.q;
    DTypeKV *k = params.k;

    const uint32_t qo_len = params.qo_len;
    const uint32_t kv_len = params.kv_len;

    const uint32_t q_stride_n = params.q_stride_n;
    const uint32_t q_stride_h = params.q_stride_h;
    const uint32_t k_stride_n = params.k_stride_n;
    const uint32_t k_stride_h = params.k_stride_h;
    const uint_fastdiv &group_size = params.group_size;

    static_assert(sizeof(DTypeQ) == 2);
    const uint32_t lane_idx = tid.x,
                   warp_idx = get_warp_idx<KTraits>(tid.y, tid.z);
    // The stub never splits KV: process the whole sequence in one chunk.
    const uint32_t chunk_start = 0;
    const uint32_t chunk_size = kv_len;

    auto block = cg::this_thread_block();
    DTypeQKAccum s_frag[NUM_MMA_Q][NUM_MMA_KV][HALF_ELEMS_PER_THREAD];

    // cooperative fetch q fragment from gmem to reg
    const uint32_t qo_packed_idx_base =
        (bx * NUM_WARPS_Q + get_warp_idx_q<KTraits>(tid.y)) * NUM_MMA_Q * 16;
    smem_t<SWIZZLE_MODE_Q> qo_smem(smem_storage.q_smem);
    DTypeQ *q_ptr_base = q + (kv_head_idx * group_size) * q_stride_h;

    uint32_t q_smem_offset_r =
        qo_smem.template get_permuted_offset<UPCAST_STRIDE_Q>(
            get_warp_idx_q<KTraits>(tid.y) * NUM_MMA_Q * 16 + lane_idx % 16,
            lane_idx / 16);

    load_q_global_smem<KTraits>(qo_packed_idx_base, qo_len, q_ptr_base,
                                q_stride_n, q_stride_h, group_size, &qo_smem,
                                tid);

    memory::commit_group();
    smem_t<SWIZZLE_MODE_KV> k_smem(smem_storage.k_smem);
    DTypeKV *k_ptr = k +
                     (chunk_start + warp_idx * KV_THR_LAYOUT_ROW +
                      lane_idx / KV_THR_LAYOUT_COL) *
                         k_stride_n +
                     kv_head_idx * k_stride_h +
                     (lane_idx % KV_THR_LAYOUT_COL) * upcast_size<DTypeKV>();

    uint32_t k_smem_offset_r =
        k_smem.template get_permuted_offset<UPCAST_STRIDE_K>(
            get_warp_idx_kv<KTraits>(tid.z) * NUM_MMA_KV * 16 +
                HALF_ELEMS_PER_THREAD * (lane_idx / 16) +
                lane_idx % HALF_ELEMS_PER_THREAD,
            (lane_idx % 16) / HALF_ELEMS_PER_THREAD),
        k_smem_offset_w =
            k_smem.template get_permuted_offset<UPCAST_STRIDE_K>(
                warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL,
                lane_idx % KV_THR_LAYOUT_COL);
    produce_kv<SharedMemFillMode::kNoFill, KTraits>(
        k_smem, &k_smem_offset_w, &k_ptr, k_stride_n, 0, chunk_size, tid);
    memory::commit_group();

    memory::wait_group<1>();
    block.sync();
    // compute attention score
    compute_qk<KTraits>(&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r,
                        s_frag);
    memory::wait_group<0>();
    block.sync();

    // Extract QK scores from s_frag to global memory
    if (get_warp_idx_q<KTraits>(tid.y) == 0 &&
        get_warp_idx_kv<KTraits>(tid.z) == 0)
    {
        for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) {
            for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) {
                for (uint32_t reg_id = 0; reg_id < HALF_ELEMS_PER_THREAD;
                     ++reg_id)
                {
                    // Map from MMA fragment layout to sequence indices -
                    // CDNA3 mapping
                    uint32_t q_idx =
                        mma_q * 16 +
                        reg_id % KTraits::NUM_ACCUM_ROWS_PER_THREAD;
                    uint32_t kv_idx =
                        mma_kv * 16 +
                        2 * (lane_idx % KTraits::THREADS_PER_MATRIX_ROW_SET) +
                        8 * (reg_id / 2) + reg_id % 2;

                    if (q_idx < qo_len && kv_idx < kv_len) {
                        // Match CPU layout: [qo_head_idx][q_idx][kv_idx]
                        uint32_t qo_head_idx =
                            kv_head_idx *
                            group_size; // Simple for single head case
                        uint32_t output_idx = qo_head_idx * qo_len * kv_len +
                                              q_idx * kv_len + kv_idx;
                        qk_scores_output[output_idx] =
                            float(s_frag[mma_q][mma_kv][reg_id]);
                    }
                }
            }
        }
    }
}

/*! \brief __global__ wrapper for the QK stub: binds dynamic shared memory
 *  and forwards to ComputeQKStubKernelDevice. */
template <typename KTraits, typename Params>
__global__ __launch_bounds__(KTraits::NUM_THREADS) void ComputeQKStubKernel(
    const __grid_constant__ Params params,
    float *qk_scores_output)
{
    extern __shared__ uint8_t smem[];
    auto &smem_storage =
        reinterpret_cast<typename KTraits::SharedStorage &>(smem);
    ComputeQKStubKernelDevice<KTraits>(params, smem_storage, qk_scores_output);
}

// NOTE(review): the definition below continues beyond this chunk; only the
// visible head is reproduced here, and its template parameter list is
// reconstructed — verify against the original file.
template <uint32_t HEAD_DIM_QK, uint32_t HEAD_DIM_VO,
          PosEncodingMode POS_ENCODING_MODE, bool USE_FP16_QK_REDUCTION,
          MaskMode MASK_MODE, typename AttentionVariant, typename Params>
gpuError_t ComputeQKStubDispatched(Params params,
                                   typename Params::DTypeO *tmp,
                                   float *qk_scores_output,
                                   gpuStream_t stream)
{
    using DTypeQ = typename Params::DTypeQ;
    using DTypeKV = typename Params::DTypeKV;
    using DTypeO = typename Params::DTypeO;
    const uint32_t num_qo_heads = params.num_qo_heads;
    const uint32_t num_kv_heads = params.num_kv_heads;
    const uint32_t qo_len = params.qo_len;
    const
uint32_t kv_len = params.kv_len; - if (kv_len < qo_len && MASK_MODE == MaskMode::kCausal) { - std::ostringstream err_msg; - err_msg << "When mask_mode is set to MaskMode::kCausal, kv_len must be " - "greater than or equal to qo_len, got kv_len" - << kv_len << " and qo_len " << qo_len; - FLASHINFER_ERROR(err_msg.str()); - } - - const uint32_t group_size = num_qo_heads / num_kv_heads; - constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM_QK / 16; - constexpr uint32_t NUM_MMA_D_VO = HEAD_DIM_VO / 16; - int64_t packed_qo_len = qo_len * group_size; - uint32_t cta_tile_q = FA2DetermineCtaTileQ(packed_qo_len, HEAD_DIM_VO); - - DISPATCH_CTA_TILE_Q(cta_tile_q, CTA_TILE_Q, { - constexpr uint32_t NUM_WARPS_Q = get_num_warps_q(CTA_TILE_Q); - constexpr uint32_t NUM_WARPS_KV = get_num_warps_kv(CTA_TILE_Q); - constexpr uint32_t NUM_MMA_Q = get_num_mma_q(CTA_TILE_Q); - - using DTypeQKAccum = - typename std::conditional, - half, float>::type; - - int dev_id = 0; - FI_GPU_CALL(gpuGetDevice(&dev_id)); - int max_smem_per_sm = getMaxSharedMemPerMultiprocessor(dev_id); - // we expect each sm execute two threadblocks - const int num_ctas_per_sm = - max_smem_per_sm >= 2 * (CTA_TILE_Q * HEAD_DIM_QK * sizeof(DTypeQ) + - (HEAD_DIM_QK + HEAD_DIM_VO) * 16 * - NUM_WARPS_KV * sizeof(DTypeKV)) - ? 2 - : 1; - const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm; - - const uint32_t max_num_mma_kv_reg = - (HEAD_DIM_VO >= 128 && NUM_MMA_Q == 2 && - POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && - !USE_FP16_QK_REDUCTION) - ? 
2 - : (8 / NUM_MMA_Q); - const uint32_t max_num_mma_kv_smem = - (max_smem_per_threadblock - - CTA_TILE_Q * HEAD_DIM_QK * sizeof(DTypeQ)) / - ((HEAD_DIM_QK + HEAD_DIM_VO) * 16 * NUM_WARPS_KV * sizeof(DTypeKV)); - - // control NUM_MMA_KV for maximum warp occupancy - DISPATCH_NUM_MMA_KV( - min(max_num_mma_kv_smem, max_num_mma_kv_reg), NUM_MMA_KV, { - using KTraits = - KernelTraits; - if constexpr (KTraits::IsInvalid()) { - // Invalid configuration, skip - std::ostringstream err_msg; - err_msg << "FlashInfer Internal Error: Invalid " - "configuration : NUM_MMA_Q=" - << NUM_MMA_Q << " NUM_MMA_D_QK=" << NUM_MMA_D_QK - << " NUM_MMA_D_VO=" << NUM_MMA_D_VO - << " NUM_MMA_KV=" << NUM_MMA_KV - << " NUM_WARPS_Q=" << NUM_WARPS_Q - << " NUM_WARPS_KV=" << NUM_WARPS_KV - << " please create an issue " - "(https://github.com/flashinfer-ai/flashinfer/" - "issues)" - " and report the issue to the developers."; - FLASHINFER_ERROR(err_msg.str()); - } - else { - constexpr uint32_t num_threads = - (NUM_WARPS_Q * NUM_WARPS_KV) * WARP_SIZE; - auto kernel = ComputeQKStubKernel; - size_t smem_size = sizeof(typename KTraits::SharedStorage); - FI_GPU_CALL(gpuFuncSetAttribute( - kernel, gpuFuncAttributeMaxDynamicSharedMemorySize, - smem_size)); - int num_blocks_per_sm = 0; - int num_sm = 0; - FI_GPU_CALL(gpuDeviceGetAttribute( - &num_sm, gpuDevAttrMultiProcessorCount, dev_id)); - FI_GPU_CALL(gpuOccupancyMaxActiveBlocksPerMultiprocessor( - &num_blocks_per_sm, kernel, num_threads, smem_size)); - uint32_t max_num_kv_chunks = - (num_blocks_per_sm * num_sm) / - (num_kv_heads * - ceil_div(qo_len * group_size, CTA_TILE_Q)); - uint32_t num_chunks; - if (max_num_kv_chunks > 0) { - uint32_t chunk_size = - max(ceil_div(kv_len, max_num_kv_chunks), 256); - num_chunks = ceil_div(kv_len, chunk_size); - } - else { - num_chunks = 0; - } - - if (num_chunks <= 1 || tmp == nullptr) { - // Enough parallelism, do not split-kv - params.partition_kv = false; - void *args[] = {(void *)¶ms, - (void *)&qk_scores_output}; 
- dim3 nblks(ceil_div(qo_len * group_size, CTA_TILE_Q), 1, - num_kv_heads); - dim3 nthrs(WARP_SIZE, NUM_WARPS_Q, NUM_WARPS_KV); - FI_GPU_CALL(gpuLaunchKernel((void *)kernel, nblks, - nthrs, args, smem_size, - stream)); - } - else { - // Use cooperative groups to increase occupancy - params.partition_kv = true; - float *tmp_lse = - (float *)(tmp + num_chunks * qo_len * num_qo_heads * - HEAD_DIM_VO); - auto o = params.o; - auto lse = params.lse; - params.o = tmp; - params.lse = tmp_lse; - void *args[] = {(void *)¶ms}; - dim3 nblks(ceil_div(qo_len * group_size, CTA_TILE_Q), - num_chunks, num_kv_heads); - dim3 nthrs(WARP_SIZE, NUM_WARPS_Q, NUM_WARPS_KV); - FI_GPU_CALL(gpuLaunchKernel((void *)kernel, nblks, - nthrs, args, smem_size, - stream)); - if constexpr (AttentionVariant::use_softmax) { - FI_GPU_CALL(MergeStates( - tmp, tmp_lse, o, lse, num_chunks, qo_len, - num_qo_heads, HEAD_DIM_VO, stream)); - } - else { - FI_GPU_CALL(AttentionSum(tmp, o, num_chunks, qo_len, - num_qo_heads, HEAD_DIM_VO, - stream)); - } - } - } - }) - }); - return gpuSuccess; -} From e0c8dc055b106d2c0b10b5a6ac1e6f333c9ebbc7 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Wed, 17 Sep 2025 12:18:57 -0400 Subject: [PATCH 088/109] Formatting --- .../include/flashinfer/attention/generic/prefill.cuh | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index 86941c687c..ae88167dc1 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -1061,14 +1061,10 @@ __device__ __forceinline__ void update_mdo_states( m[mma_q][j] = max(m[mma_q][j], s_frag[mma_q][mma_kv][j]); } // Butterfly reduction across all threads in the band - m[mma_q][j] = - max(m[mma_q][j], gpu_iface::math::shfl_xor_sync(m[mma_q][j], 0x8)); // 16 apart - m[mma_q][j] = - max(m[mma_q][j], 
gpu_iface::math::shfl_xor_sync(m[mma_q][j], 0x4)); // 8 apart - m[mma_q][j] = - max(m[mma_q][j], gpu_iface::math::shfl_xor_sync(m[mma_q][j], 0x2)); // 4 apart - m[mma_q][j] = - max(m[mma_q][j], gpu_iface::math::shfl_xor_sync(m[mma_q][j], 0x1)); // 2 apart + m[mma_q][j] = max(m[mma_q][j], gpu_iface::math::shfl_xor_sync(m[mma_q][j], 0x8)); + m[mma_q][j] = max(m[mma_q][j], gpu_iface::math::shfl_xor_sync(m[mma_q][j], 0x4)); + m[mma_q][j] = max(m[mma_q][j], gpu_iface::math::shfl_xor_sync(m[mma_q][j], 0x2)); + m[mma_q][j] = max(m[mma_q][j], gpu_iface::math::shfl_xor_sync(m[mma_q][j], 0x1)); float o_scale = gpu_iface::math::ptx_exp2(m_prev * sm_scale - m[mma_q][j] * sm_scale); // Scale output fragments for this specific row From 57dac701c6685966637c0d4b7582c488393a8865 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Fri, 26 Sep 2025 11:17:44 -0400 Subject: [PATCH 089/109] Fix fragment loading --- .../flashinfer/attention/generic/prefill.cuh | 107 ++++++--- .../include/gpu_iface/backend/hip/mma_hip.h | 4 +- .../hip/test_inplace_transpose_loads.cpp | 211 ++++++++++++++++++ .../tests/hip/test_layout_transform.cpp | 105 +++++++++ .../tests/hip/test_single_prefill.cpp | 24 +- 5 files changed, 399 insertions(+), 52 deletions(-) create mode 100644 libflashinfer/tests/hip/test_inplace_transpose_loads.cpp create mode 100644 libflashinfer/tests/hip/test_layout_transform.cpp diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index ae88167dc1..4dabe61fb8 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -394,10 +394,8 @@ __device__ __forceinline__ void produce_kv_impl_cdna3_( // NOTE: NUM_MMA_KV*4/NUM_WARPS_Q = NUM_WARPS_KV*NUM_MMA_KV*4/num_warps static_assert(NUM_MMA_KV * 4 % NUM_WARPS_Q == 0); - uint32_t kv_idx = kv_idx_base + warp_idx * 4 + lane_idx / KV_THR_LAYOUT_COL; - // NOTE: NUM_MMA_KV * 
4 / NUM_WARPS_Q = NUM_WARPS_KV*NUM_MMA_KV*4/num_warps - static_assert(NUM_MMA_KV * 4 % NUM_WARPS_Q == 0); + #pragma unroll for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) { #pragma unroll @@ -565,7 +563,7 @@ __device__ __forceinline__ void init_states( for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { #pragma unroll for (uint32_t reg_id = 0; reg_id < KTraits::HALF_ELEMS_PER_THREAD; ++reg_id) { - o_frag[mma_q][mma_d][reg_id] = 1.f; + o_frag[mma_q][mma_d][reg_id] = 0.f; } } } @@ -576,7 +574,7 @@ __device__ __forceinline__ void init_states( #pragma unroll for (uint32_t j = 0; j < NUM_ACCUM_ROWS_PER_THREAD; ++j) { m[mma_q][j] = typename KTraits::DTypeQKAccum(-gpu_iface::math::inf); - d[mma_q][j] = 0.f; + d[mma_q][j] = 1.f; } } } @@ -1065,15 +1063,21 @@ __device__ __forceinline__ void update_mdo_states( m[mma_q][j] = max(m[mma_q][j], gpu_iface::math::shfl_xor_sync(m[mma_q][j], 0x4)); m[mma_q][j] = max(m[mma_q][j], gpu_iface::math::shfl_xor_sync(m[mma_q][j], 0x2)); m[mma_q][j] = max(m[mma_q][j], gpu_iface::math::shfl_xor_sync(m[mma_q][j], 0x1)); + float o_scale = gpu_iface::math::ptx_exp2(m_prev * sm_scale - m[mma_q][j] * sm_scale); + d[mma_q][j] *= o_scale; - // Scale output fragments for this specific row +#if Debug + if (warp_idx == 0 && lane_idx == 0) { + printf("Max value %f, m_prev %f, o_scale %f, d %f\n", m[mma_q][j], m_prev, o_scale, + float(d[mma_q][j])); + printf("-------------\n"); + } +#endif #pragma unroll for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { o_frag[mma_q][mma_d][j] *= o_scale; } - - // Convert logits to probabilities for this row #pragma unroll for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { s_frag[mma_q][mma_kv][j] = gpu_iface::math::ptx_exp2( @@ -1082,13 +1086,14 @@ __device__ __forceinline__ void update_mdo_states( #elif (PLATFORM_CUDA_DEVICE) #pragma unroll for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { - auto m_local = max(max(s_frag[mma_q][mma_kv][0], 
s_frag[mma_q][mma_kv][1]), - max(s_frag[mma_q][mma_kv][2], s_frag[mma_q][mma_kv][3])); + float m_local = + max(max(s_frag[mma_q][mma_kv][j * 2 + 0], s_frag[mma_q][mma_kv][j * 2 + 1]), + max(s_frag[mma_q][mma_kv][j * 2 + 4], s_frag[mma_q][mma_kv][j * 2 + 5])); m[mma_q][j] = max(m[mma_q][j], m_local); } - m[mma_q][j] = max(m[mma_q][j], gpu_iface::math::shfl_xor_sync(m[mma_q][j], 0x2)); m[mma_q][j] = max(m[mma_q][j], gpu_iface::math::shfl_xor_sync(m[mma_q][j], 0x1)); + float o_scale = gpu_iface::math::ptx_exp2(m_prev * sm_scale - m[mma_q][j] * sm_scale); d[mma_q][j] *= o_scale; #pragma unroll @@ -1111,6 +1116,13 @@ __device__ __forceinline__ void update_mdo_states( } #endif } +#if Debug1 + if (warp_idx == 0 && lane_idx == 0) { + printf("d[0] %f d[1] %f d[2] %f d[3]%f\n", float(d[mma_q][0]), float(d[mma_q][0]), + float(d[mma_q][0]), float(d[mma_q][0])); + printf("-------------\n"); + } +#endif } } else if constexpr (std::is_same_v) { #if defined(PLATFORM_HIP_DEVICE) @@ -1167,14 +1179,15 @@ __device__ __forceinline__ void compute_sfm_v( uint32_t* v_smem_offset_r, typename KTraits::DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD], float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD], - float (*d)[KTraits::NUM_ACCUM_ROWS_PER_THREAD]) { + float (*d)[KTraits::NUM_ACCUM_ROWS_PER_THREAD], const dim3 tid = threadIdx, + uint32_t debug_warp_idx = 0, uint32_t debug_lane_idx = 0) { constexpr uint32_t UPCAST_STRIDE_V = KTraits::UPCAST_STRIDE_V; constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; constexpr uint32_t INT32_ELEMS_PER_THREAD = KTraits::INT32_ELEMS_PER_THREAD; constexpr uint32_t V_SMEM_COLUMN_ADVANCE = 16 / KTraits::HALF_ELEMS_PER_THREAD; - typename KTraits::DTypeQ s_frag_f16[KTraits::NUM_MMA_Q][KTraits::NUM_MMA_KV] [HALF_ELEMS_PER_THREAD]; + if constexpr (std::is_same_v) { #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { @@ -1186,6 +1199,25 @@ __device__ __forceinline__ 
void compute_sfm_v( } } +#if Debug1 + // Debug the state of attention score matrix before rowsum to compute denom + constexpr uint32_t NUM_MMA_Q = KTraits::NUM_MMA_Q; + constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; + const uint32_t warp_idx = get_warp_idx(tid.y, tid.z), lane_idx = tid.x; + + // Write all thread's fragments to shared memory + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { + for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) { + if (lane_idx == debug_lane_idx && warp_idx == debug_warp_idx) { + printf("%.6f %.6f %.6f %.6f\n", s_frag[mma_q][mma_kv][0], s_frag[mma_q][mma_kv][1], + s_frag[mma_q][mma_kv][2], s_frag[mma_q][mma_kv][3]); + } + } + } + __syncthreads(); + +#endif + if constexpr (KTraits::AttentionVariant::use_softmax) { #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { @@ -1202,6 +1234,13 @@ __device__ __forceinline__ void compute_sfm_v( #endif } } +#if Debug + if (debug_warp_idx == 0 && debug_lane_idx == 0) { + printf("After rowsum: d[0] %f d[1] %f d[2] %f d[3] %f\n", float(d[mma_q][0]), + float(d[mma_q][0]), float(d[mma_q][0]), float(d[mma_q][0])); + printf("-------------\n"); + } +#endif } } @@ -1318,8 +1357,7 @@ __device__ __forceinline__ void finalize_m( } /*! - * \brief Synchronize the states of the MDO kernel across the threadblock along - * threadIdx.z. + * \brief Synchronize the states of the MDO kernel across the threadblock along threadIdx.z. 
*/ template __device__ __forceinline__ void threadblock_sync_mdo_states( @@ -1426,12 +1464,10 @@ __device__ __forceinline__ void threadblock_sync_mdo_states( #pragma unroll for (uint32_t reg_id = 0; reg_id < KTraits::HALF_ELEMS_PER_THREAD; ++reg_id) { #if defined(PLATFORM_HIP_DEVICE) - // CDNA3: Direct mapping - each reg_id corresponds - // to one accumulator row + // CDNA3: Direct mapping - each reg_id corresponds to one accumulator row o_new[reg_id] += oi[reg_id] * o_scale[reg_id][i]; #else - // CUDA: Grouped mapping - 2 elements per - // accumulator row + // CUDA: Grouped mapping - 2 elements per accumulator row o_new[reg_id] += oi[reg_id] * o_scale[(reg_id % 4) / 2][i]; #endif } @@ -1781,7 +1817,9 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice( (lane_idx % 16) / 8); #endif uint32_t v_smem_offset_r = v_smem.template get_permuted_offset( - get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + lane_idx % 16, lane_idx / 16), + get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + (lane_idx % 4) + + 4 * (lane_idx / 16), + lane_idx / 4), k_smem_offset_w = k_smem.template get_permuted_offset( warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL, lane_idx % KV_THR_LAYOUT_COL), @@ -1871,7 +1909,6 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice( #pragma unroll 1 for (uint32_t iter = 0; iter < num_iterations; ++iter) { - // for (uint32_t iter = 0; iter < 1; ++iter) { memory::wait_group<1>(); block.sync(); @@ -1931,6 +1968,9 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice( // compute attention score compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); #if Debug1 + if (params.debug_thread_id == lane_idx && params.debug_warp_id == warp_idx) { + printf("After compute_qk\n"); + } debug_write_sfrag_to_scratch(s_frag, tid, params.debug_thread_id, params.debug_warp_id); #endif @@ -1940,13 +1980,15 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice( qo_len, kv_len, group_size, s_frag, tid, kv_head_idx); 
#if Debug1 + if (params.debug_thread_id == lane_idx && params.debug_warp_id == warp_idx) { + printf("params.sm_scale %f, params.logits_soft_cap %f\n", params.sm_scale, + params.logits_soft_cap); + printf("After logits_transform\n"); + } debug_write_sfrag_to_scratch(s_frag, tid, params.debug_thread_id, params.debug_warp_id); #endif -#if Debug1 - debug_write_sfrag_to_scratch(s_frag, &scratch, tid); -#endif // apply mask if (MASK_MODE == MaskMode::kCustom || (iter >= mask_iteration || iter < window_iteration)) { logits_mask( @@ -1956,6 +1998,9 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice( } #if Debug1 + // if(params.debug_thread_id == lane_idx && params.debug_warp_id == warp_idx) { + // printf("Before update_mdo_states\n"); + // } debug_write_sfrag_to_scratch(s_frag, tid, params.debug_thread_id, params.debug_warp_id); #endif @@ -1975,18 +2020,8 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice( block.sync(); // compute sfm*v - compute_sfm_v(&v_smem, &v_smem_offset_r, s_frag, o_frag, d); -#if Debug - if (lane_idx == params.debug_thread_id && warp_idx == params.debug_warp_id) { - for (auto mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { - printf("%f\n", d[mma_q][0]); - printf("%f\n", d[mma_q][1]); - printf("%f\n", d[mma_q][2]); - printf("%f\n", d[mma_q][3]); - } - } -#endif - + compute_sfm_v(&v_smem, &v_smem_offset_r, s_frag, o_frag, d, tid, + params.debug_warp_id, params.debug_thread_id); block.sync(); produce_kv( v_smem, &v_smem_offset_w, &v_ptr, v_stride_n, (iter + 1) * CTA_TILE_KV, chunk_size, tid); diff --git a/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h b/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h index 52229f70fe..6aa6935d14 100644 --- a/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h +++ b/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h @@ -77,8 +77,8 @@ __device__ __forceinline__ void load_fragment(uint32_t* R, const T* smem_ptr) { const uint16_t* v2 = reinterpret_cast(++smem_ptr); const uint16_t* 
v3 = reinterpret_cast(++smem_ptr); - R[0] = (static_cast(*v0) << 16) | static_cast(*v1); - R[1] = (static_cast(*v2) << 16) | static_cast(*v3); + R[0] = (static_cast(*v1) << 16) | static_cast(*v0); + R[1] = (static_cast(*v3) << 16) | static_cast(*v2); } template diff --git a/libflashinfer/tests/hip/test_inplace_transpose_loads.cpp b/libflashinfer/tests/hip/test_inplace_transpose_loads.cpp new file mode 100644 index 0000000000..98ee9aedf0 --- /dev/null +++ b/libflashinfer/tests/hip/test_inplace_transpose_loads.cpp @@ -0,0 +1,211 @@ +/// 1. Allocate a 128x64 memory on CPU and init lexicographically to represent a +/// 128X64 matrix. +/// 2. Copy CPU array to global memory. +// 3 Copy global memory into LDS using produce_kv function. The LDS should +// also of be 128x64 elements +/// 4. Call transpose kernel that inplace transposes the LDS 128x64 matrix into +// a 64x128 matrix. Each warp handles multiple blocks of 16x16 chunks +/// 5. Post transposition copy back the 128x64 LDS linear memory to global and +/// then back to CPU. +/// 6. Evaluate the output is same as the transpose of the original array. 
+ +#include +#include + +#include +#include + +#include "flashinfer/attention/generic/permuted_smem.cuh" +#include "flashinfer/attention/generic/prefill.cuh" +#include "gpu_iface/backend/hip/mma_hip.h" +#include "gpu_iface/gpu_runtime_compat.hpp" + +using namespace flashinfer; + +namespace { + +// Define matrix dimensions for the test +constexpr int MATRIX_ROWS = 128; +constexpr int MATRIX_COLS = 64; +constexpr uint32_t KV_THR_LAYOUT_ROW = 4; +constexpr uint32_t KV_THR_LAYOUT_COL = 16; +constexpr uint32_t NUM_WARPS = 4; +constexpr uint32_t NUM_MMA_KV = MATRIX_ROWS / 16; +constexpr uint32_t NUM_WARPS_Q = MATRIX_COLS / 16; +constexpr uint32_t NUM_MMA_D = 4; +constexpr uint32_t UPCAST_STRIDE = 64; +constexpr uint32_t VECTOR_BIT_WIDTH = 64; +constexpr uint32_t CTA_TILE_KV = NUM_MMA_KV * 4 * 16; + +using DTypeKV = __half; + +template +__device__ __forceinline__ void load_matrix_global_to_smem(uint32_t warp_idx, uint32_t lane_idx, + smem_t smem, + uint32_t* smem_offset, DTypeKV** gptr, + const uint32_t stride_n, + const uint32_t kv_idx_base, + const uint32_t kv_len) { + static_assert(NUM_MMA_KV * 4 % NUM_WARPS_Q == 0); + + uint32_t kv_idx = kv_idx_base + warp_idx * 4 + lane_idx / KV_THR_LAYOUT_ROW; + +#pragma unroll + for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) { +#pragma unroll + for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(DTypeKV)); ++j) { + smem.template load_vector_async(*smem_offset, *gptr, kv_idx < kv_len); + *smem_offset = smem.template advance_offset_by_column<16>(*smem_offset, j); + *gptr += 16 * upcast_size(); + } + kv_idx += NUM_WARPS * 4; + *smem_offset = smem.template advance_offset_by_row(*smem_offset) - + (sizeof(DTypeKV) * NUM_MMA_D * 2); + *gptr += NUM_WARPS * 4 * stride_n - + sizeof(DTypeKV) * NUM_MMA_D * 2 * upcast_size(); + } + *smem_offset -= CTA_TILE_KV * UPCAST_STRIDE; +} + +} // namespace + +// Helper to initialize matrix with lexicographic values +void initMatrixLexicographic(half* matrix, int rows, int cols) { + for (int i = 
0; i < rows; ++i) { + for (int j = 0; j < cols; ++j) { + matrix[i * cols + j] = static_cast(i * cols + j); + } + } +} + +// Helper to transpose a matrix on CPU (for verification) +void transposeMatrixCPU(half* input, half* output, int rows, int cols) { + for (int i = 0; i < rows; ++i) { + for (int j = 0; j < cols; ++j) { + output[j * rows + i] = input[i * cols + j]; + } + } +} + +// Helper to print a matrix section (for debugging) +void printMatrixSection(half* matrix, int rows, int cols, const char* name) { + std::cout << "Matrix " << name << " (" << rows << "x" << cols << "):" << std::endl; + for (int i = 0; i < std::min(rows, 8); ++i) { + for (int j = 0; j < std::min(cols, 8); ++j) { + std::cout << static_cast(matrix[i * cols + j]) << " "; + } + std::cout << (cols > 8 ? "..." : "") << std::endl; + } + if (rows > 8) std::cout << "..." << std::endl; +} + +// Kernel to load the matrix from global to shared memory using produce_kv +__device__ __forceinline__ void loadGlobalToSharedKernel(__half* input, + smem_t v_smem, + int rows, int cols) { + const uint32_t tid = threadIdx.x; + const uint32_t lane_idx = tid % 64; + const uint32_t warp_idx = tid / 64; + + uint32_t smem_offset = + v_smem.template get_permuted_offset<64>(warp_idx * 4 + lane_idx / 16, lane_idx % 16); + + DTypeKV* input_ptr = input + (warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL) * 64 + + +(lane_idx % KV_THR_LAYOUT_COL) * upcast_size(); + + // Load global memory to shared memory collaboratively + load_matrix_global_to_smem(warp_idx, lane_idx, v_smem, &smem_offset, + &input_ptr, cols, 0, rows); + + __syncthreads(); + + if (tid == 0) { + printf("\n DEBUG LDS loaded from global\n"); + auto hMem = reinterpret_cast<__half*>(v_smem.base); + uint32_t offset_r_debug; + // for (auto i = 0; i < rows; ++i) { + for (auto j = 0; j < 256; ++j) { + printf("%f ", float(hMem[j])); + } + printf("\n"); + //} + } + + // TODO: Store shared memory back to global memory for verification +} + +// Kernel to 
transpose shared memory in-place +__global__ void transposeSharedMemoryKernel(half* input, half* output, int rows, int cols) { + // Define shared memory for the matrix + extern __shared__ half shared_mem[]; + smem_t v_smem(shared_mem); + + // TODO: Load data from global to shared memory + loadGlobalToSharedKernel(input, v_smem, rows, cols); + + __syncthreads(); + + // TODO: Call transpose_4x4_half_registers to transpose in-place + + __syncthreads(); + + // TODO: Copy transposed data back to global memory +} + +TEST(InplaceTransposeTest, TestTransposeLDS) { + // 1. Allocate a 128x64 memory on CPU and init lexicographically + std::vector h_input(MATRIX_ROWS * MATRIX_COLS); + std::vector h_output(MATRIX_COLS * MATRIX_ROWS); + std::vector h_expected(MATRIX_COLS * MATRIX_ROWS); + + initMatrixLexicographic(h_input.data(), MATRIX_ROWS, MATRIX_COLS); + + for (auto i = 0; i < 32; ++i) { + std::cout << float(h_input[i]) << " "; + } + std::cout << std::endl; + + transposeMatrixCPU(h_input.data(), h_expected.data(), MATRIX_ROWS, MATRIX_COLS); + + // 2. Copy CPU array to global memory + half *d_input, *d_output; + FI_GPU_CALL(hipMalloc(&d_input, h_input.size() * sizeof(half))); + FI_GPU_CALL(hipMalloc(&d_output, h_output.size() * sizeof(half))); + FI_GPU_CALL( + hipMemcpy(d_input, h_input.data(), h_input.size() * sizeof(half), hipMemcpyHostToDevice)); + + // 3 & 4. Load into shared memory and transpose in-place + const int blockSize = 256; + const int gridSize = 1; + size_t sharedMemSize = MATRIX_ROWS * MATRIX_COLS * sizeof(half); + + // Single wave of four wavefronts + transposeSharedMemoryKernel<<>>(d_input, d_output, + MATRIX_ROWS, MATRIX_COLS); + + // 5. Copy back to CPU + FI_GPU_CALL( + hipMemcpy(h_output.data(), d_output, h_output.size() * sizeof(half), hipMemcpyDeviceToHost)); + + // 6. 
Verify the output matches the transpose of the original array + bool all_match = true; + for (int i = 0; i < MATRIX_COLS * MATRIX_ROWS; ++i) { + if (static_cast(h_output[i]) != static_cast(h_expected[i])) { + std::cout << "Mismatch at index " << i << ": " << static_cast(h_output[i]) << " vs " + << static_cast(h_expected[i]) << std::endl; + all_match = false; + if (i > 10) break; // Limit output + } + } + + EXPECT_TRUE(all_match) << "Transposed matrix doesn't match expected result"; + + // Clean up + FI_GPU_CALL(hipFree(d_input)); + FI_GPU_CALL(hipFree(d_output)); +} + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/libflashinfer/tests/hip/test_layout_transform.cpp b/libflashinfer/tests/hip/test_layout_transform.cpp new file mode 100644 index 0000000000..ae87562715 --- /dev/null +++ b/libflashinfer/tests/hip/test_layout_transform.cpp @@ -0,0 +1,105 @@ +#include +#include + +#include "gpu_iface/backend/hip/mma_hip.h" + +// Define test dimensions +constexpr int MATRIX_SIZE_X = 16; +constexpr int MATRIX_SIZE_Y = 16; + +namespace { +// Print register values for debugging +__device__ void print_register(uint32_t* R) { + auto values = reinterpret_cast<__half*>(R); + printf("[%5.1f %5.1f %5.1f %5.1f]\n", __half2float(values[0]), __half2float(values[1]), + __half2float(values[2]), __half2float(values[3])); +} + +// Initialize LDS array with lexicographic values +__device__ void init_lds_array(half* lds_array) { + const int tid = threadIdx.x; + if (tid == 0) { + for (int y = 0; y < MATRIX_SIZE_Y; ++y) { + for (int x = 0; x < MATRIX_SIZE_X; ++x) { + lds_array[y * MATRIX_SIZE_X + x] = __half(y * MATRIX_SIZE_X + x); + } + } + } + __syncthreads(); +} + +// Each thread loads 4 elements in A-matrix layout +__device__ void load_amatrix_layout(half* lds_array, uint32_t* R) { + const int tid = threadIdx.x; + const int lane_id = tid % 64; + const int row = lane_id % 16; + const int col_start = (lane_id / 16) * 4; + 
+ auto offset = lds_array + row * MATRIX_SIZE_X + col_start; + + if (tid == 0) { + printf("DEBUG:::: %f %f %f %f\n", __half2float(*offset), __half2float(*(offset + 1)), + __half2float(*(offset + 2)), __half2float(*(offset + 3))); + } + + flashinfer::gpu_iface::mma_impl::hip::load_fragment<__half>(R, offset); + + if (tid == 0) { + print_register(R); + } +} + +// Print LDS array using one thread +__device__ void print_lds_array(half* lds_array) { + if (threadIdx.x == 0) { + printf("LDS Array (%dx%d):\n", MATRIX_SIZE_X, MATRIX_SIZE_Y); + for (int y = 0; y < MATRIX_SIZE_Y; ++y) { + for (int x = 0; x < MATRIX_SIZE_X; ++x) { + printf("%5.1f ", __half2float(lds_array[y * MATRIX_SIZE_X + x])); + } + printf("\n"); + } + printf("\n"); + } + __syncthreads(); +} + +} // namespace + +__global__ void test_mini_tile_transpose_kernel() { + // Allocate shared memory for the 16x16 matrix + __shared__ half lds_array[MATRIX_SIZE_X * MATRIX_SIZE_Y]; + + // Step 1: Initialize the LDS array with lexicographic values + init_lds_array(lds_array); + + // Step 2: Print the LDS array (for debugging) + print_lds_array(lds_array); + + // Step 3: Load data from LDS to registers in A-matrix layout + uint32_t registers[2]; + load_amatrix_layout(lds_array, registers); + + // Step 4: Print initial register values for verification + __syncthreads(); + // print_registers(registers, "Before transpose"); + + // Step 5: Apply transpose to convert from A-matrix to B/C-matrix layout + // flashinfer::gpu_iface::mma_impl::hip::transpose_4x4_half_registers(registers); + + // Step 6: Print transposed register values + __syncthreads(); + // print_registers(registers, "After transpose"); +} + +// Host code to launch the kernel +void test_mini_tile_transpose() { + // Launch with 1 block of 64 threads (full warp for CDNA3) + test_mini_tile_transpose_kernel<<<1, 64>>>(); + hipDeviceSynchronize(); +} + +int main() { + test_mini_tile_transpose(); + return 0; +} diff --git 
a/libflashinfer/tests/hip/test_single_prefill.cpp b/libflashinfer/tests/hip/test_single_prefill.cpp index ce747f98ab..d91d4c7ca4 100644 --- a/libflashinfer/tests/hip/test_single_prefill.cpp +++ b/libflashinfer/tests/hip/test_single_prefill.cpp @@ -264,20 +264,16 @@ void _TestSinglePrefillKernelCorrectness(size_t qo_len, size_t kv_len, size_t nu // for(auto i: att_out) { // std::cout << i << "\n"; // } -#if 0 - float result_accuracy = - 1. - float(num_results_error_atol) / float(o_ref.size()); - std::cout << "num_qo_heads=" << num_qo_heads - << ", num_kv_heads=" << num_kv_heads << ", qo_len=" << qo_len - << ", kv_len=" << kv_len << ", head_dim=" << head_dim - << ", causal=" << causal - << ", kv_layout=" << QKVLayoutToString(kv_layout) - << ", pos_encoding_mode=" - << PosEncodingModeToString(pos_encoding_mode) - << ", result_accuracy=" << result_accuracy << std::endl; - - EXPECT_GT(result_accuracy, 0.90) << "Result correctness test failed."; - EXPECT_FALSE(nan_detected) << "Nan detected in the result."; +#if 1 + float result_accuracy = 1. 
- float(num_results_error_atol) / float(o_ref.size()); + std::cout << "num_qo_heads=" << num_qo_heads << ", num_kv_heads=" << num_kv_heads + << ", qo_len=" << qo_len << ", kv_len=" << kv_len << ", head_dim=" << head_dim + << ", causal=" << causal << ", kv_layout=" << QKVLayoutToString(kv_layout) + << ", pos_encoding_mode=" << PosEncodingModeToString(pos_encoding_mode) + << ", result_accuracy=" << result_accuracy << std::endl; + + EXPECT_GT(result_accuracy, 0.90) << "Result correctness test failed."; + EXPECT_FALSE(nan_detected) << "Nan detected in the result."; #endif FI_GPU_CALL(hipFree(q_d)); FI_GPU_CALL(hipFree(k_d)); From c9b2d8364fdb3311a67f2fa51a33ffecba5228ba Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Fri, 26 Sep 2025 17:56:10 -0400 Subject: [PATCH 090/109] Fixes --- .../include/gpu_iface/backend/hip/mma_hip.h | 32 +++++++++---------- .../tests/hip/test_layout_transform.cpp | 10 +++--- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h b/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h index 6aa6935d14..1b58697084 100644 --- a/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h +++ b/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h @@ -37,34 +37,32 @@ __device__ __forceinline__ void transpose_4x4_half_registers(uint32_t* R) { uint32_t lane_in_group = lane_id % 4; // === ROUND 1: Exchange with neighbor (XOR with 1) === - // T0↔T1, T2↔T3 partial exchange - uint32_t reg_idx = (lane_in_group >> 1) & 0x1; - uint32_t exchanged_val = __shfl_xor(R[reg_idx], 0x1); + // T0 <-> T1, T2 <-> T3 partial exchange + uint32_t regid = (lane_in_group >> 1) & 0x1; + uint32_t exchanged_val = __shfl_xor(R[regid], 0x1); uint32_t shift = (lane_in_group & 1) * 16; - uint32_t keep_mask = 0xFFFF0000 >> shift; - int right_shift_amount = 16 * (1 - (lane_in_group & 1)); - int left_shift_amount = 16 * (lane_in_group & 1); - R[reg_idx] = - (R[reg_idx] & keep_mask) | ((exchanged_val >> right_shift_amount) 
<< left_shift_amount); + uint32_t keep_mask = 0x0000FFFF << shift; + int left_shift_amount = 16 * (1 - (lane_in_group & 1)); + int right_shift_amount = 16 * (lane_in_group & 1); + R[regid] = (R[regid] & keep_mask) | ((exchanged_val >> right_shift_amount) << left_shift_amount); // === ROUND 2: Exchange with one hop (XOR with 2) === - // T0↔T2, T1↔T3 exchange R[0] and R[1] + // T0 <-> T2, T1 <-> T3 exchange R[0] and R[1] // Swap entire registers based on thread position - uint32_t is_top = 1 - reg_idx; + uint32_t is_top = 1 - regid; uint32_t temp0 = __shfl_xor(R[0], 0x2); uint32_t temp1 = __shfl_xor(R[1], 0x2); // Compute both possibilities and select - R[0] = R[0] * is_top + temp1 * reg_idx; - R[1] = temp0 * is_top + R[1] * reg_idx; + R[0] = R[0] * is_top + temp1 * regid; + R[1] = temp0 * is_top + R[1] * regid; // === ROUND 3: Exchange with neighbor again (XOR with 1) === - // T0↔T1, T2↔T3 exchange remaining parts + // T0 <-> T1, T2 <-> T3 exchange remaining parts - reg_idx = 1 - reg_idx; - exchanged_val = __shfl_xor(R[reg_idx], 0x1); - R[reg_idx] = - (R[reg_idx] & keep_mask) | ((exchanged_val >> right_shift_amount) << left_shift_amount); + regid = 1 - regid; + exchanged_val = __shfl_xor(R[regid], 0x1); + R[regid] = (R[regid] & keep_mask) | ((exchanged_val >> right_shift_amount) << left_shift_amount); } // Single unified load function for all fragment types diff --git a/libflashinfer/tests/hip/test_layout_transform.cpp b/libflashinfer/tests/hip/test_layout_transform.cpp index ae87562715..1646fbf7da 100644 --- a/libflashinfer/tests/hip/test_layout_transform.cpp +++ b/libflashinfer/tests/hip/test_layout_transform.cpp @@ -42,7 +42,7 @@ __device__ void load_amatrix_layout(half* lds_array, uint32_t* R) { __half2float(*(offset + 2)), __half2float(*(offset + 3))); } - flashinfer::gpu_iface::mma_impl::hip::load_fragment<__half>(R, offset); + flashinfer::gpu_iface::mma_impl::hip::load_fragment(R, offset); if (tid == 0) { print_register(R); @@ -82,14 +82,16 @@ __global__ void 
test_mini_tile_transpose_kernel() { // Step 4: Print initial register values for verification __syncthreads(); - // print_registers(registers, "Before transpose"); // Step 5: Apply transpose to convert from A-matrix to B/C-matrix layout - // flashinfer::gpu_iface::mma_impl::hip::transpose_4x4_half_registers(registers); + flashinfer::gpu_iface::mma_impl::hip::transpose_4x4_half_registers(registers); // Step 6: Print transposed register values __syncthreads(); - // print_registers(registers, "After transpose"); + if (threadIdx.x == 0) { + printf("After Transpose\n"); + print_lds_array(lds_array); + } } // Host code to launch the kernel From 25f40d991ba73f311e64a71cf3223fac696978c1 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Sun, 28 Sep 2025 09:52:51 -0400 Subject: [PATCH 091/109] WIP --- .../include/gpu_iface/backend/hip/mma_hip.h | 40 ++++++++----------- .../tests/hip/test_layout_transform.cpp | 15 ++++++- .../tests/hip/test_mfma_fp32_16x16x16fp16.cpp | 32 +++++++-------- 3 files changed, 46 insertions(+), 41 deletions(-) diff --git a/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h b/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h index 1b58697084..1dddb31eff 100644 --- a/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h +++ b/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h @@ -70,13 +70,8 @@ __device__ __forceinline__ void transpose_4x4_half_registers(uint32_t* R) { /// @param smem_ptr [in] pointer to the shared memory to load the fragment from template __device__ __forceinline__ void load_fragment(uint32_t* R, const T* smem_ptr) { - const uint16_t* v0 = reinterpret_cast(smem_ptr) + 0; - const uint16_t* v1 = reinterpret_cast(++smem_ptr); - const uint16_t* v2 = reinterpret_cast(++smem_ptr); - const uint16_t* v3 = reinterpret_cast(++smem_ptr); - - R[0] = (static_cast(*v1) << 16) | static_cast(*v0); - R[1] = (static_cast(*v3) << 16) | static_cast(*v2); + R[0] = reinterpret_cast(smem_ptr)[0]; + R[1] = reinterpret_cast(smem_ptr)[1]; } 
template @@ -119,27 +114,26 @@ __device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f32(float* C, u C[3] = C_fp32[3]; } -/// Loads a fragment from LDS to two 32bit registers and then transposes +/// @brief Loads a fragment from LDS to two 32bit registers and then transposes /// the registers for a group of four consecuitive threads. +/// +/// transposes the values in four adjacent threads. The function does the +/// following layout transformation: +/// Original data in registers for Threads 0-3 after fragment load +/// T0 : a b c d +/// T1 : e f g h +/// T2 : i j k l +/// T3 : m n o p +/// +/// After transposition: +/// T0 : a e i m +/// T1 : b f j n +/// T2 : c g k o +/// T3 : d h l p template __device__ __forceinline__ void load_fragment_4x4_half_registers(uint32_t* R, const T* smem_ptr) { static_assert(std::is_same_v, "Only half type is supported"); - // Each thread loads 4 __half values in two 32b registers. load_fragment(R, smem_ptr); - // transposes the values in four adjacent threads. 
The function does the - // following layout transformation: - // Original data in registers for Threads 0-3 after fragment load - // T0 : a b c d - // T1 : e f g h - // T2 : i j k l - // T3 : m n o p - // - // After transposition: - // T0 : a e i m - // T1 : b f j n - // T2 : c g k o - // T3 : d h l p - transpose_4x4_half_registers(R); } diff --git a/libflashinfer/tests/hip/test_layout_transform.cpp b/libflashinfer/tests/hip/test_layout_transform.cpp index 1646fbf7da..077233f8c4 100644 --- a/libflashinfer/tests/hip/test_layout_transform.cpp +++ b/libflashinfer/tests/hip/test_layout_transform.cpp @@ -89,8 +89,19 @@ __global__ void test_mini_tile_transpose_kernel() { // Step 6: Print transposed register values __syncthreads(); if (threadIdx.x == 0) { - printf("After Transpose\n"); - print_lds_array(lds_array); + print_register(registers); + } + __syncthreads(); + if (threadIdx.x == 1) { + print_register(registers); + } + __syncthreads(); + if (threadIdx.x == 2) { + print_register(registers); + } + __syncthreads(); + if (threadIdx.x == 3) { + print_register(registers); } } diff --git a/libflashinfer/tests/hip/test_mfma_fp32_16x16x16fp16.cpp b/libflashinfer/tests/hip/test_mfma_fp32_16x16x16fp16.cpp index 24eb5219af..4128c7f6cb 100644 --- a/libflashinfer/tests/hip/test_mfma_fp32_16x16x16fp16.cpp +++ b/libflashinfer/tests/hip/test_mfma_fp32_16x16x16fp16.cpp @@ -24,6 +24,18 @@ } \ } +namespace { + +__device__ void print_register(uint32_t* R) { + auto values = reinterpret_cast<__half*>(R); + printf("[%f %f %f %f]\n", __half2float(values[0]), __half2float(values[1]), + __half2float(values[2]), __half2float(values[3])); +} + +__device__ void print_register(float* R) { printf("[%f %f %f %f]\n", R[0], R[1], R[3], R[4]); } + +} // namespace + // Dimensions for our test matrices constexpr int M = 16; constexpr int N = 16; @@ -41,7 +53,6 @@ void gemm_reference(const __half* A, const __half* B, float* C, int M, int N, in for (int j = 0; j < N; ++j) { float acc = 0.0f; for (int k = 
0; k < K; ++k) { - // Use __half_as_float to properly convert __half to float acc += __half2float(A[i * K + k]) * __half2float(B[k * N + j]); } C[i * N + j] = acc; @@ -54,26 +65,15 @@ __global__ void test_mfma_kernel(const __half* A, const __half* B, float* C) { uint32_t b_reg[2]; float c_reg[4] = {0.0f, 0.0f, 0.0f, 0.0f}; - // A Matrix is read row wise. Threads T0...T15 read Col 0...3 of Row 0...15 - // Threads T16...T31 read Col 4...7 of Row 0...15 - // Threads T32...T47 read Col 8...11 of Row 0...15 - // Threads T48...T63 read Col 12...15 of Row 0...15 - - // B Matrix is read column wise. Threads T0...T15 read Row 0...3 of Col - // 0...15 (Each thread reads 1 column per 4 rows) Threads T16...T31 read - // Row 4...7 of Col 0...15 Threads T32...T47 read Row 8...11 of Col 0...15 - // Threads T48...T63 read Row 12...15 of Col 0...15 - int a_idx = (threadIdx.x / 16) * 4 + threadIdx.x % 16 * LDA; - int b_idx = (threadIdx.x / 16) * LDB * 4 + threadIdx.x % 16; + int a_idx = (threadIdx.x % 16) * LDA + (threadIdx.x / 16) * 4; + int b_idx = ((threadIdx.x % 4) + 4 * (threadIdx.x / 16)) * LDB + ((threadIdx.x % 16) / 4) * 4; flashinfer::gpu_iface::mma::load_fragment<__half>(a_reg, &A[a_idx]); - flashinfer::gpu_iface::mma::load_fragment_transpose<__half>(b_reg, &B[b_idx], LDB); - + flashinfer::gpu_iface::mma::load_fragment_transpose_4x4_half_registers<__half>(b_reg, &B[b_idx]); flashinfer::gpu_iface::mma::mma_sync_m16n16k16_row_col_f16f16f32<__half>(c_reg, a_reg, b_reg); for (int i = 0; i < 4; ++i) { - const int d_idx = threadIdx.x % 16 + i * LDC + (threadIdx.x / 16) * 4 * LDC; - + int d_idx = ((threadIdx.x / 16) * 4 + i) * LDC + (threadIdx.x % 16); C[d_idx] = c_reg[i]; } } From 98bdf4bc220935f7ca8155497e55f3c7dc3220dc Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Sun, 28 Sep 2025 10:00:15 -0400 Subject: [PATCH 092/109] Remove redundant fuction --- .../include/gpu_iface/backend/hip/mma_hip.h | 12 ------------ 1 file changed, 12 deletions(-) diff --git 
a/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h b/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h index 1dddb31eff..7429e40966 100644 --- a/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h +++ b/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h @@ -74,18 +74,6 @@ __device__ __forceinline__ void load_fragment(uint32_t* R, const T* smem_ptr) { R[1] = reinterpret_cast(smem_ptr)[1]; } -template -__device__ __forceinline__ void load_fragment_transpose(uint32_t* R, const T* smem_ptr, - uint32_t stride) { - const uint16_t* v0 = reinterpret_cast(smem_ptr) + 0; - const uint16_t* v1 = reinterpret_cast(smem_ptr + 1 * stride); - const uint16_t* v2 = reinterpret_cast(smem_ptr + 2 * stride); - const uint16_t* v3 = reinterpret_cast(smem_ptr + 3 * stride); - - R[0] = (static_cast(*v0) << 16) | static_cast(*v1); - R[1] = (static_cast(*v2) << 16) | static_cast(*v3); -} - // MMA operation for FP16 inputs with FP32 accumulator template __device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f32(float* C, uint32_t* A, From 55e048175afa15ecf25dd1786ceedd39df743428 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Wed, 1 Oct 2025 15:05:15 -0400 Subject: [PATCH 093/109] Precommit fixes --- aot_build_utils/generate.py | 11 +- aot_build_utils/generate_dispatch_inc.py | 6 +- .../generate_single_prefill_inst.py | 6 +- .../generate_single_prefill_sm90_inst.py | 6 +- cmake/utils/ConfigurePrebuitUris.cmake | 26 +- examples/cpp/standalone_single_prefill.cu | 1242 ++++++++--------- examples/test_batch_decode_example.py | 33 +- .../generic/default_prefill_params.cuh | 654 +++++---- .../flashinfer/attention/generic/dispatch.cuh | 432 +++--- scripts/run_hip_tests.sh | 1 - tests/test_batch_decode_kernels_hip.py | 7 +- tests/test_logits_cap_hip.py | 5 +- tests/test_non_contiguous_decode_hip.py | 4 +- tests/test_norm_hip.py | 1 - tests/test_rope.py | 1 - tests/test_sliding_window_hip.py | 4 +- 16 files changed, 1169 insertions(+), 1270 deletions(-) diff --git 
a/aot_build_utils/generate.py b/aot_build_utils/generate.py index 29ce92f56a..431ba48081 100644 --- a/aot_build_utils/generate.py +++ b/aot_build_utils/generate.py @@ -252,11 +252,14 @@ def write_if_different(path: Path, content: str) -> None: f"use_logits_cap_{logits_soft_cap}_" f"f16qk_{bool(use_fp16_qk_reduction)}" ) - final_list = single_decode_uris + batch_decode_uris + single_prefill_uris + batch_prefill_uris - print(final_list) - return ( - final_list + final_list = ( + single_decode_uris + + batch_decode_uris + + single_prefill_uris + + batch_prefill_uris ) + print(final_list) + return final_list if __name__ == "__main__": diff --git a/aot_build_utils/generate_dispatch_inc.py b/aot_build_utils/generate_dispatch_inc.py index 8a0cfa1432..e73a345c5a 100644 --- a/aot_build_utils/generate_dispatch_inc.py +++ b/aot_build_utils/generate_dispatch_inc.py @@ -17,7 +17,11 @@ import argparse from pathlib import Path -from .literal_map import bool_literal, mask_mode_literal, pos_encoding_mode_literal +from .literal_map import ( + bool_literal, + mask_mode_literal, + pos_encoding_mode_literal, +) def get_dispatch_inc_str(args: argparse.Namespace) -> str: diff --git a/aot_build_utils/generate_single_prefill_inst.py b/aot_build_utils/generate_single_prefill_inst.py index 14535c04e6..35da3a1818 100644 --- a/aot_build_utils/generate_single_prefill_inst.py +++ b/aot_build_utils/generate_single_prefill_inst.py @@ -18,7 +18,11 @@ import sys from pathlib import Path -from .literal_map import dtype_literal, mask_mode_literal, pos_encoding_mode_literal +from .literal_map import ( + dtype_literal, + mask_mode_literal, + pos_encoding_mode_literal, +) def get_cu_file_str( diff --git a/aot_build_utils/generate_single_prefill_sm90_inst.py b/aot_build_utils/generate_single_prefill_sm90_inst.py index 291aad8edd..7a531059ca 100644 --- a/aot_build_utils/generate_single_prefill_sm90_inst.py +++ b/aot_build_utils/generate_single_prefill_sm90_inst.py @@ -18,7 +18,11 @@ import sys from pathlib 
import Path -from .literal_map import dtype_literal, mask_mode_literal, pos_encoding_mode_literal +from .literal_map import ( + dtype_literal, + mask_mode_literal, + pos_encoding_mode_literal, +) def get_cu_file_str( diff --git a/cmake/utils/ConfigurePrebuitUris.cmake b/cmake/utils/ConfigurePrebuitUris.cmake index b6789b2cdb..4b31476356 100644 --- a/cmake/utils/ConfigurePrebuitUris.cmake +++ b/cmake/utils/ConfigurePrebuitUris.cmake @@ -1,19 +1,19 @@ function(flashinfer_configure_prebuilt_uris) - message(STATUS "Configuring prebuilt URIs") - get_property(PREBUILT_URI_LIST GLOBAL PROPERTY FLASHINFER_PREBUILT_URIS) - set(PYTHON_URI "") - message(STATUS "PREBUILT_URI_LIST: ${PREBUILT_URI_LIST}") + message(STATUS "Configuring prebuilt URIs") + get_property(PREBUILT_URI_LIST GLOBAL PROPERTY FLASHINFER_PREBUILT_URIS) + set(PYTHON_URI "") + message(STATUS "PREBUILT_URI_LIST: ${PREBUILT_URI_LIST}") - string(REPLACE ";" "\", \"" list_items "${PREBUILT_URI_LIST}") - set(PYTHON_URI "${list_items}") + string(REPLACE ";" "\", \"" list_items "${PREBUILT_URI_LIST}") + set(PYTHON_URI "${list_items}") - message(STATUS "PYTHON_URI: ${PYTHON_URI}") - - set(TEMPLATE_FILE "${CMAKE_SOURCE_DIR}/templates/__aot_prebuilt_uris__.py.in") - set(OUTPUT_FILE "${CMAKE_BINARY_DIR}/flashinfer/__aot_prebuilt_uris__.py") - set(INSTALL_DIR "flashinfer") + message(STATUS "PYTHON_URI: ${PYTHON_URI}") - configure_file("${TEMPLATE_FILE}" "${OUTPUT_FILE}" @ONLY) + set(TEMPLATE_FILE "${CMAKE_SOURCE_DIR}/templates/__aot_prebuilt_uris__.py.in") + set(OUTPUT_FILE "${CMAKE_BINARY_DIR}/flashinfer/__aot_prebuilt_uris__.py") + set(INSTALL_DIR "flashinfer") - install(FILES "${OUTPUT_FILE}" DESTINATION "${INSTALL_DIR}") + configure_file("${TEMPLATE_FILE}" "${OUTPUT_FILE}" @ONLY) + + install(FILES "${OUTPUT_FILE}" DESTINATION "${INSTALL_DIR}") endfunction() diff --git a/examples/cpp/standalone_single_prefill.cu b/examples/cpp/standalone_single_prefill.cu index 345f87f401..53181273af 100644 --- 
a/examples/cpp/standalone_single_prefill.cu +++ b/examples/cpp/standalone_single_prefill.cu @@ -20,725 +20,617 @@ #include #include -namespace flashinfer -{ +namespace flashinfer { // Parameter struct for SinglePrefill -template struct SinglePrefillParams -{ - using DTypeQ = half; - using DTypeKV = half; - using DTypeO = DTypeOs; - using IdType = IdTypes; - - half *q; - half *k; - half *v; - DTypeO *o; - float *lse; - uint_fastdiv group_size; - - uint8_t *maybe_custom_mask; - float *maybe_alibi_slopes; - double logits_soft_cap; - double sm_scale; - double rope_rcp_scale; - double rope_rcp_theta; - - uint32_t qo_len; - uint32_t kv_len; - uint32_t num_qo_heads; - uint32_t num_kv_heads; - uint32_t q_stride_n; - uint32_t q_stride_h; - uint32_t k_stride_n; - uint32_t k_stride_h; - uint32_t v_stride_n; - uint32_t v_stride_h; - uint32_t head_dim; - int32_t window_left; - - bool partition_kv; - - __host__ __device__ __forceinline__ uint32_t - get_qo_len(uint32_t batch_idx) const - { - return qo_len; - } - - __host__ __device__ __forceinline__ uint32_t - get_kv_len(uint32_t batch_idx) const - { - return kv_len; - } +template +struct SinglePrefillParams { + using DTypeQ = half; + using DTypeKV = half; + using DTypeO = DTypeOs; + using IdType = IdTypes; + + half* q; + half* k; + half* v; + DTypeO* o; + float* lse; + uint_fastdiv group_size; + + uint8_t* maybe_custom_mask; + float* maybe_alibi_slopes; + double logits_soft_cap; + double sm_scale; + double rope_rcp_scale; + double rope_rcp_theta; + + uint32_t qo_len; + uint32_t kv_len; + uint32_t num_qo_heads; + uint32_t num_kv_heads; + uint32_t q_stride_n; + uint32_t q_stride_h; + uint32_t k_stride_n; + uint32_t k_stride_h; + uint32_t v_stride_n; + uint32_t v_stride_h; + uint32_t head_dim; + int32_t window_left; + + bool partition_kv; + + __host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t batch_idx) const { + return qo_len; + } + + __host__ __device__ __forceinline__ uint32_t get_kv_len(uint32_t batch_idx) const { 
+ return kv_len; + } }; -} // namespace flashinfer +} // namespace flashinfer // CPU reference implementation for validation -namespace reference -{ +namespace reference { template -std::vector single_mha(const std::vector &q, - const std::vector &k, - const std::vector &v, - size_t qo_len, - size_t kv_len, - size_t num_qo_heads, - size_t num_kv_heads, - size_t head_dim, - bool causal, - flashinfer::QKVLayout kv_layout, - flashinfer::PosEncodingMode pos_encoding_mode, - float rope_scale = 1.0f, - float rope_theta = 10000.0f) -{ - float sm_scale = 1.0f / std::sqrt(static_cast(head_dim)); - std::vector o(qo_len * num_qo_heads * head_dim, static_cast(0.0f)); - std::vector att(kv_len); - size_t group_size = num_qo_heads / num_kv_heads; - - for (size_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) { - size_t kv_head_idx = qo_head_idx / group_size; - - for (size_t q_idx = 0; q_idx < qo_len; ++q_idx) { - // 1. Compute attention scores - float max_val = -5e4f; - - for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { - if (causal && kv_idx > kv_len + q_idx - qo_len) { - att[kv_idx] = -5e4f; - continue; - } - - // Compute dot product between Q and K - float score = 0.0f; - for (size_t d = 0; d < head_dim; ++d) { - float q_val = 0.0f; - float k_val = 0.0f; - - // Get Q value - always NHD layout - size_t q_offset = q_idx * num_qo_heads * head_dim + - qo_head_idx * head_dim + d; - q_val = static_cast(q[q_offset]); - - // Get K value - depends on layout - if (kv_layout == flashinfer::QKVLayout::kNHD) { - size_t k_offset = kv_idx * num_kv_heads * head_dim + - kv_head_idx * head_dim + d; - k_val = static_cast(k[k_offset]); - } - else { - size_t k_offset = kv_head_idx * kv_len * head_dim + - kv_idx * head_dim + d; - k_val = static_cast(k[k_offset]); - } - - score += q_val * k_val; - } - score *= sm_scale; - - att[kv_idx] = score; - max_val = std::max(max_val, score); - } - - // 2. 
Apply softmax - float sum_exp = 0.0f; - for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { - if (causal && kv_idx > kv_len + q_idx - qo_len) { - att[kv_idx] = 0.0f; - } - else { - att[kv_idx] = std::exp(att[kv_idx] - max_val); - sum_exp += att[kv_idx]; - } - } - - // Normalize - for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { - if (sum_exp > 0.0f) { - att[kv_idx] /= sum_exp; - } - } - - // 3. Compute weighted sum of values - for (size_t d = 0; d < head_dim; ++d) { - float weighted_sum = 0.0f; - - for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { - float v_val = 0.0f; - - // Get V value - depends on layout - if (kv_layout == flashinfer::QKVLayout::kNHD) { - size_t v_offset = kv_idx * num_kv_heads * head_dim + - kv_head_idx * head_dim + d; - v_val = static_cast(v[v_offset]); - } - else { - size_t v_offset = kv_head_idx * kv_len * head_dim + - kv_idx * head_dim + d; - v_val = static_cast(v[v_offset]); - } - - weighted_sum += att[kv_idx] * v_val; - } - - // Store result in output - size_t o_offset = q_idx * num_qo_heads * head_dim + - qo_head_idx * head_dim + d; - o[o_offset] = static_cast(weighted_sum); - } +std::vector single_mha(const std::vector& q, const std::vector& k, const std::vector& v, + size_t qo_len, size_t kv_len, size_t num_qo_heads, size_t num_kv_heads, + size_t head_dim, bool causal, flashinfer::QKVLayout kv_layout, + flashinfer::PosEncodingMode pos_encoding_mode, float rope_scale = 1.0f, + float rope_theta = 10000.0f) { + float sm_scale = 1.0f / std::sqrt(static_cast(head_dim)); + std::vector o(qo_len * num_qo_heads * head_dim, static_cast(0.0f)); + std::vector att(kv_len); + size_t group_size = num_qo_heads / num_kv_heads; + + for (size_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) { + size_t kv_head_idx = qo_head_idx / group_size; + + for (size_t q_idx = 0; q_idx < qo_len; ++q_idx) { + // 1. 
Compute attention scores + float max_val = -5e4f; + + for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { + if (causal && kv_idx > kv_len + q_idx - qo_len) { + att[kv_idx] = -5e4f; + continue; } - } - return o; -} + // Compute dot product between Q and K + float score = 0.0f; + for (size_t d = 0; d < head_dim; ++d) { + float q_val = 0.0f; + float k_val = 0.0f; + + // Get Q value - always NHD layout + size_t q_offset = q_idx * num_qo_heads * head_dim + qo_head_idx * head_dim + d; + q_val = static_cast(q[q_offset]); + + // Get K value - depends on layout + if (kv_layout == flashinfer::QKVLayout::kNHD) { + size_t k_offset = kv_idx * num_kv_heads * head_dim + kv_head_idx * head_dim + d; + k_val = static_cast(k[k_offset]); + } else { + size_t k_offset = kv_head_idx * kv_len * head_dim + kv_idx * head_dim + d; + k_val = static_cast(k[k_offset]); + } + + score += q_val * k_val; + } + score *= sm_scale; + + att[kv_idx] = score; + max_val = std::max(max_val, score); + } + + // 2. Apply softmax + float sum_exp = 0.0f; + for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { + if (causal && kv_idx > kv_len + q_idx - qo_len) { + att[kv_idx] = 0.0f; + } else { + att[kv_idx] = std::exp(att[kv_idx] - max_val); + sum_exp += att[kv_idx]; + } + } -} // namespace reference + // Normalize + for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { + if (sum_exp > 0.0f) { + att[kv_idx] /= sum_exp; + } + } -// Helper function to generate random data (without Thrust) -void generate_random_data(half *data, - size_t size, - float min_val = -1.0f, - float max_val = 1.0f) -{ - std::vector host_data(size); - std::mt19937 rng(42); // Fixed seed for reproducibility - std::uniform_real_distribution dist(min_val, max_val); - - for (size_t i = 0; i < size; ++i) { - host_data[i] = static_cast(dist(rng)); - } + // 3. 
Compute weighted sum of values + for (size_t d = 0; d < head_dim; ++d) { + float weighted_sum = 0.0f; - // Copy to device - FI_GPU_CALL(gpuMemcpy(data, host_data.data(), size * sizeof(half), - gpuMemcpyHostToDevice)); -} + for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { + float v_val = 0.0f; -// Function to validate GPU results against CPU reference (simplified) -bool validate_results(const half *gpu_output, - size_t gpu_size, - const std::vector &cpu_output, - float rtol = 1e-3f, - float atol = 1e-3f) -{ - if (gpu_size != cpu_output.size()) { - std::cerr << "Size mismatch: GPU=" << gpu_size - << " vs CPU=" << cpu_output.size() << std::endl; - return false; - } + // Get V value - depends on layout + if (kv_layout == flashinfer::QKVLayout::kNHD) { + size_t v_offset = kv_idx * num_kv_heads * head_dim + kv_head_idx * head_dim + d; + v_val = static_cast(v[v_offset]); + } else { + size_t v_offset = kv_head_idx * kv_len * head_dim + kv_idx * head_dim + d; + v_val = static_cast(v[v_offset]); + } - // Copy GPU data to host for comparison - std::vector host_output(gpu_size); - FI_GPU_CALL(gpuMemcpy(host_output.data(), gpu_output, - gpu_size * sizeof(half), gpuMemcpyDeviceToHost)); - - int errors = 0; - float max_diff = 0.0f; - float max_rel_diff = 0.0f; - - for (size_t i = 0; i < gpu_size; ++i) { - float gpu_val = static_cast(host_output[i]); - float cpu_val = static_cast(cpu_output[i]); - float abs_diff = std::abs(gpu_val - cpu_val); - float rel_diff = - (cpu_val != 0.0f) ? 
abs_diff / std::abs(cpu_val) : abs_diff; - - max_diff = std::max(max_diff, abs_diff); - max_rel_diff = std::max(max_rel_diff, rel_diff); - - bool close = (abs_diff <= atol + rtol * std::abs(cpu_val)); - if (!close) { - errors++; - if (errors <= 10) { // Print just a few examples - std::cerr << "Mismatch at " << i << ": GPU=" << gpu_val - << " CPU=" << cpu_val << " (diff=" << abs_diff << ")" - << std::endl; - } + weighted_sum += att[kv_idx] * v_val; } + + // Store result in output + size_t o_offset = q_idx * num_qo_heads * head_dim + qo_head_idx * head_dim + d; + o[o_offset] = static_cast(weighted_sum); + } } + } - float error_rate = static_cast(errors) / gpu_size; - std::cout << "\nValidation Results:" << std::endl; - std::cout << " Max absolute difference: " << max_diff << std::endl; - std::cout << " Max relative difference: " << max_rel_diff << std::endl; - std::cout << " Error rate: " << (error_rate * 100) << "% (" << errors - << " / " << gpu_size << " elements)" << std::endl; - std::cout << " Status: " << (error_rate < 0.05 ? "PASSED" : "FAILED") - << std::endl; - - // Allow up to 5% error rate - return error_rate < 0.05; + return o; } -using namespace flashinfer; +} // namespace reference -// Helper class to convert strings to parameters -class ArgParser -{ -public: - static bool get_bool(const char *arg, bool default_val) - { - return arg == nullptr - ? default_val - : (std::string(arg) == "1" || std::string(arg) == "true"); - } +// Helper function to generate random data (without Thrust) +void generate_random_data(half* data, size_t size, float min_val = -1.0f, float max_val = 1.0f) { + std::vector host_data(size); + std::mt19937 rng(42); // Fixed seed for reproducibility + std::uniform_real_distribution dist(min_val, max_val); - static int get_int(const char *arg, int default_val) - { - return arg == nullptr ? 
default_val : std::atoi(arg); - } + for (size_t i = 0; i < size; ++i) { + host_data[i] = static_cast(dist(rng)); + } - static float get_float(const char *arg, float default_val) - { - return arg == nullptr ? default_val : std::atof(arg); - } + // Copy to device + FI_GPU_CALL(gpuMemcpy(data, host_data.data(), size * sizeof(half), gpuMemcpyHostToDevice)); +} - static PosEncodingMode get_pos_encoding_mode(const char *arg) - { - if (arg == nullptr) - return PosEncodingMode::kNone; - std::string str_val = arg; - if (str_val == "none") - return PosEncodingMode::kNone; - if (str_val == "rope") - return PosEncodingMode::kRoPELlama; - if (str_val == "alibi") - return PosEncodingMode::kALiBi; - return PosEncodingMode::kNone; +// Function to validate GPU results against CPU reference (simplified) +bool validate_results(const half* gpu_output, size_t gpu_size, const std::vector& cpu_output, + float rtol = 1e-3f, float atol = 1e-3f) { + if (gpu_size != cpu_output.size()) { + std::cerr << "Size mismatch: GPU=" << gpu_size << " vs CPU=" << cpu_output.size() << std::endl; + return false; + } + + // Copy GPU data to host for comparison + std::vector host_output(gpu_size); + FI_GPU_CALL( + gpuMemcpy(host_output.data(), gpu_output, gpu_size * sizeof(half), gpuMemcpyDeviceToHost)); + + int errors = 0; + float max_diff = 0.0f; + float max_rel_diff = 0.0f; + + for (size_t i = 0; i < gpu_size; ++i) { + float gpu_val = static_cast(host_output[i]); + float cpu_val = static_cast(cpu_output[i]); + float abs_diff = std::abs(gpu_val - cpu_val); + float rel_diff = (cpu_val != 0.0f) ? 
abs_diff / std::abs(cpu_val) : abs_diff; + + max_diff = std::max(max_diff, abs_diff); + max_rel_diff = std::max(max_rel_diff, rel_diff); + + bool close = (abs_diff <= atol + rtol * std::abs(cpu_val)); + if (!close) { + errors++; + if (errors <= 10) { // Print just a few examples + std::cerr << "Mismatch at " << i << ": GPU=" << gpu_val << " CPU=" << cpu_val + << " (diff=" << abs_diff << ")" << std::endl; + } } + } + + float error_rate = static_cast(errors) / gpu_size; + std::cout << "\nValidation Results:" << std::endl; + std::cout << " Max absolute difference: " << max_diff << std::endl; + std::cout << " Max relative difference: " << max_rel_diff << std::endl; + std::cout << " Error rate: " << (error_rate * 100) << "% (" << errors << " / " << gpu_size + << " elements)" << std::endl; + std::cout << " Status: " << (error_rate < 0.05 ? "PASSED" : "FAILED") << std::endl; + + // Allow up to 5% error rate + return error_rate < 0.05; +} - static QKVLayout get_layout(const char *arg) - { - if (arg == nullptr) - return QKVLayout::kNHD; - std::string str_val = arg; - if (str_val == "nhd") - return QKVLayout::kNHD; - if (str_val == "hnd") - return QKVLayout::kHND; - return QKVLayout::kNHD; - } +using namespace flashinfer; + +// Helper class to convert strings to parameters +class ArgParser { + public: + static bool get_bool(const char* arg, bool default_val) { + return arg == nullptr ? default_val : (std::string(arg) == "1" || std::string(arg) == "true"); + } + + static int get_int(const char* arg, int default_val) { + return arg == nullptr ? default_val : std::atoi(arg); + } + + static float get_float(const char* arg, float default_val) { + return arg == nullptr ? 
default_val : std::atof(arg); + } + + static PosEncodingMode get_pos_encoding_mode(const char* arg) { + if (arg == nullptr) return PosEncodingMode::kNone; + std::string str_val = arg; + if (str_val == "none") return PosEncodingMode::kNone; + if (str_val == "rope") return PosEncodingMode::kRoPELlama; + if (str_val == "alibi") return PosEncodingMode::kALiBi; + return PosEncodingMode::kNone; + } + + static QKVLayout get_layout(const char* arg) { + if (arg == nullptr) return QKVLayout::kNHD; + std::string str_val = arg; + if (str_val == "nhd") return QKVLayout::kNHD; + if (str_val == "hnd") return QKVLayout::kHND; + return QKVLayout::kNHD; + } }; // Dispatch function for half precision -gpuError_t dispatch_single_prefill(half *q_ptr, - half *k_ptr, - half *v_ptr, - half *o_ptr, - half *tmp_ptr, - float *lse_ptr, - uint32_t num_qo_heads, - uint32_t num_kv_heads, - uint32_t qo_len, - uint32_t kv_len, - uint32_t head_dim, - QKVLayout kv_layout, - PosEncodingMode pos_encoding_mode, - bool causal, - bool use_fp16_qk_reduction, - double sm_scale, - int32_t window_left, - double rope_scale, - double rope_theta, - gpuStream_t stream) -{ - // Compute strides based on layout - uint32_t q_stride_n = num_qo_heads * head_dim; - uint32_t q_stride_h = head_dim; - uint32_t k_stride_n, k_stride_h, v_stride_n, v_stride_h; - - if (kv_layout == QKVLayout::kNHD) { - k_stride_n = num_kv_heads * head_dim; - k_stride_h = head_dim; - v_stride_n = num_kv_heads * head_dim; - v_stride_h = head_dim; - } - else { - k_stride_h = kv_len * head_dim; - k_stride_n = head_dim; - v_stride_h = kv_len * head_dim; - v_stride_n = head_dim; - } - - // Configure mask mode - const MaskMode mask_mode = causal ? 
MaskMode::kCausal : MaskMode::kNone; - - // Constants for prefill kernel - constexpr uint32_t HEAD_DIM_QK = 128; - constexpr uint32_t HEAD_DIM_VO = 128; - constexpr PosEncodingMode POS_ENCODING_MODE = PosEncodingMode::kRoPELlama; - constexpr bool USE_FP16_QK_REDUCTION = false; - - gpuError_t status = gpuSuccess; - - if (causal) { - // Causal attention - using AttentionVariantType = - DefaultAttention; - using Params = SinglePrefillParams; - - Params params; - params.q = q_ptr; - params.k = k_ptr; - params.v = v_ptr; - params.o = o_ptr; - params.lse = lse_ptr; - params.num_qo_heads = num_qo_heads; - params.num_kv_heads = num_kv_heads; - params.group_size = uint_fastdiv(num_qo_heads / num_kv_heads); - params.qo_len = qo_len; - params.kv_len = kv_len; - params.q_stride_n = q_stride_n; - params.q_stride_h = q_stride_h; - params.k_stride_n = k_stride_n; - params.k_stride_h = k_stride_h; - params.v_stride_n = v_stride_n; - params.v_stride_h = v_stride_h; - params.head_dim = head_dim; - params.window_left = window_left; - params.partition_kv = false; - params.maybe_custom_mask = nullptr; - params.maybe_alibi_slopes = nullptr; - params.logits_soft_cap = 0.0; - params.sm_scale = sm_scale; - params.rope_rcp_scale = 1.0 / rope_scale; - params.rope_rcp_theta = 1.0 / rope_theta; - - status = SinglePrefillWithKVCacheDispatched< - HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, USE_FP16_QK_REDUCTION, - MaskMode::kCausal, AttentionVariantType>(params, tmp_ptr, stream); - } - else { - // Non-causal attention - using AttentionVariantType = - DefaultAttention; - using Params = SinglePrefillParams; - - Params params; - params.q = q_ptr; - params.k = k_ptr; - params.v = v_ptr; - params.o = o_ptr; - params.lse = lse_ptr; - params.num_qo_heads = num_qo_heads; - params.num_kv_heads = num_kv_heads; - params.group_size = uint_fastdiv(num_qo_heads / num_kv_heads); - params.qo_len = qo_len; - params.kv_len = kv_len; - params.q_stride_n = q_stride_n; - params.q_stride_h = q_stride_h; - 
params.k_stride_n = k_stride_n; - params.k_stride_h = k_stride_h; - params.v_stride_n = v_stride_n; - params.v_stride_h = v_stride_h; - params.head_dim = head_dim; - params.window_left = window_left; - params.partition_kv = false; - params.maybe_custom_mask = nullptr; - params.maybe_alibi_slopes = nullptr; - params.logits_soft_cap = 0.0; - params.sm_scale = sm_scale; - params.rope_rcp_scale = 1.0 / rope_scale; - params.rope_rcp_theta = 1.0 / rope_theta; - - status = SinglePrefillWithKVCacheDispatched< - HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, USE_FP16_QK_REDUCTION, - MaskMode::kNone, AttentionVariantType>(params, tmp_ptr, stream); - } - - return status; +gpuError_t dispatch_single_prefill(half* q_ptr, half* k_ptr, half* v_ptr, half* o_ptr, + half* tmp_ptr, float* lse_ptr, uint32_t num_qo_heads, + uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len, + uint32_t head_dim, QKVLayout kv_layout, + PosEncodingMode pos_encoding_mode, bool causal, + bool use_fp16_qk_reduction, double sm_scale, int32_t window_left, + double rope_scale, double rope_theta, gpuStream_t stream) { + // Compute strides based on layout + uint32_t q_stride_n = num_qo_heads * head_dim; + uint32_t q_stride_h = head_dim; + uint32_t k_stride_n, k_stride_h, v_stride_n, v_stride_h; + + if (kv_layout == QKVLayout::kNHD) { + k_stride_n = num_kv_heads * head_dim; + k_stride_h = head_dim; + v_stride_n = num_kv_heads * head_dim; + v_stride_h = head_dim; + } else { + k_stride_h = kv_len * head_dim; + k_stride_n = head_dim; + v_stride_h = kv_len * head_dim; + v_stride_n = head_dim; + } + + // Configure mask mode + const MaskMode mask_mode = causal ? 
MaskMode::kCausal : MaskMode::kNone; + + // Constants for prefill kernel + constexpr uint32_t HEAD_DIM_QK = 128; + constexpr uint32_t HEAD_DIM_VO = 128; + constexpr PosEncodingMode POS_ENCODING_MODE = PosEncodingMode::kRoPELlama; + constexpr bool USE_FP16_QK_REDUCTION = false; + + gpuError_t status = gpuSuccess; + + if (causal) { + // Causal attention + using AttentionVariantType = DefaultAttention; + using Params = SinglePrefillParams; + + Params params; + params.q = q_ptr; + params.k = k_ptr; + params.v = v_ptr; + params.o = o_ptr; + params.lse = lse_ptr; + params.num_qo_heads = num_qo_heads; + params.num_kv_heads = num_kv_heads; + params.group_size = uint_fastdiv(num_qo_heads / num_kv_heads); + params.qo_len = qo_len; + params.kv_len = kv_len; + params.q_stride_n = q_stride_n; + params.q_stride_h = q_stride_h; + params.k_stride_n = k_stride_n; + params.k_stride_h = k_stride_h; + params.v_stride_n = v_stride_n; + params.v_stride_h = v_stride_h; + params.head_dim = head_dim; + params.window_left = window_left; + params.partition_kv = false; + params.maybe_custom_mask = nullptr; + params.maybe_alibi_slopes = nullptr; + params.logits_soft_cap = 0.0; + params.sm_scale = sm_scale; + params.rope_rcp_scale = 1.0 / rope_scale; + params.rope_rcp_theta = 1.0 / rope_theta; + + status = SinglePrefillWithKVCacheDispatched(params, tmp_ptr, stream); + } else { + // Non-causal attention + using AttentionVariantType = DefaultAttention; + using Params = SinglePrefillParams; + + Params params; + params.q = q_ptr; + params.k = k_ptr; + params.v = v_ptr; + params.o = o_ptr; + params.lse = lse_ptr; + params.num_qo_heads = num_qo_heads; + params.num_kv_heads = num_kv_heads; + params.group_size = uint_fastdiv(num_qo_heads / num_kv_heads); + params.qo_len = qo_len; + params.kv_len = kv_len; + params.q_stride_n = q_stride_n; + params.q_stride_h = q_stride_h; + params.k_stride_n = k_stride_n; + params.k_stride_h = k_stride_h; + params.v_stride_n = v_stride_n; + params.v_stride_h = 
v_stride_h; + params.head_dim = head_dim; + params.window_left = window_left; + params.partition_kv = false; + params.maybe_custom_mask = nullptr; + params.maybe_alibi_slopes = nullptr; + params.logits_soft_cap = 0.0; + params.sm_scale = sm_scale; + params.rope_rcp_scale = 1.0 / rope_scale; + params.rope_rcp_theta = 1.0 / rope_theta; + + status = SinglePrefillWithKVCacheDispatched(params, tmp_ptr, stream); + } + + return status; } // Function to calculate FLOPs for single_prefill -double calculate_flops(uint32_t qo_len, - uint32_t kv_len, - uint32_t num_qo_heads, - uint32_t head_dim, - bool causal) -{ - double flops; - if (causal) { - // For causal attention: qo_len * (2 * kv_len - qo_len) * 2 * - // num_qo_heads * head_dim - flops = static_cast(qo_len) * (2.0 * kv_len - qo_len) * 2.0 * - num_qo_heads * head_dim; - } - else { - // For non-causal attention: qo_len * kv_len * 4 * num_qo_heads * - // head_dim - flops = static_cast(qo_len) * kv_len * 4.0 * num_qo_heads * - head_dim; - } - return flops; +double calculate_flops(uint32_t qo_len, uint32_t kv_len, uint32_t num_qo_heads, uint32_t head_dim, + bool causal) { + double flops; + if (causal) { + // For causal attention: qo_len * (2 * kv_len - qo_len) * 2 * + // num_qo_heads * head_dim + flops = static_cast(qo_len) * (2.0 * kv_len - qo_len) * 2.0 * num_qo_heads * head_dim; + } else { + // For non-causal attention: qo_len * kv_len * 4 * num_qo_heads * + // head_dim + flops = static_cast(qo_len) * kv_len * 4.0 * num_qo_heads * head_dim; + } + return flops; } -void print_usage(const char *program_name) -{ - std::cerr - << "Usage: " << program_name << " [options]\n" - << "Options:\n" - << " --qo_len : Query sequence length (default: " - "512)\n" - << " --kv_len : Key/value sequence length (default: " - "512)\n" - << " --num_qo_heads : Number of query heads (default: 32)\n" - << " --num_kv_heads : Number of key/value heads (default: " - "32)\n" - << " --head_dim : Head dimension (default: 128)\n" - << " --layout : KV 
tensor layout (default: nhd)\n" - << " --pos_encoding : Position encoding mode " - "(default: none)\n" - << " --causal <0|1> : Use causal mask (default: 1)\n" - << " --use_fp16_qk <0|1> : Use FP16 for QK reduction (default: " - "0)\n" - << " --window_left : Window left size (default: -1)\n" - << " --rope_scale : RoPE scale factor (default: 1.0)\n" - << " --rope_theta : RoPE theta (default: 10000.0)\n" - << " --iterations : Number of iterations for timing " - "(default: 10)\n" - << " --warmup : Number of warmup iterations " - "(default: 5)\n" - << " --validate <0|1> : Validate against CPU reference " - "(default: 0)\n"; +void print_usage(const char* program_name) { + std::cerr << "Usage: " << program_name << " [options]\n" + << "Options:\n" + << " --qo_len : Query sequence length (default: " + "512)\n" + << " --kv_len : Key/value sequence length (default: " + "512)\n" + << " --num_qo_heads : Number of query heads (default: 32)\n" + << " --num_kv_heads : Number of key/value heads (default: " + "32)\n" + << " --head_dim : Head dimension (default: 128)\n" + << " --layout : KV tensor layout (default: nhd)\n" + << " --pos_encoding : Position encoding mode " + "(default: none)\n" + << " --causal <0|1> : Use causal mask (default: 1)\n" + << " --use_fp16_qk <0|1> : Use FP16 for QK reduction (default: " + "0)\n" + << " --window_left : Window left size (default: -1)\n" + << " --rope_scale : RoPE scale factor (default: 1.0)\n" + << " --rope_theta : RoPE theta (default: 10000.0)\n" + << " --iterations : Number of iterations for timing " + "(default: 10)\n" + << " --warmup : Number of warmup iterations " + "(default: 5)\n" + << " --validate <0|1> : Validate against CPU reference " + "(default: 0)\n"; } // Main function with simplified memory management -int main(int argc, char *argv[]) -{ - if (argc > 1 && - (std::string(argv[1]) == "--help" || std::string(argv[1]) == "-h")) - { - print_usage(argv[0]); - return 0; +int main(int argc, char* argv[]) { + if (argc > 1 && 
(std::string(argv[1]) == "--help" || std::string(argv[1]) == "-h")) { + print_usage(argv[0]); + return 0; + } + + // Process parameter pairs (--param value) + uint32_t qo_len = 512; + uint32_t kv_len = 512; + uint32_t num_qo_heads = 32; + uint32_t num_kv_heads = 32; + uint32_t head_dim = 128; + QKVLayout kv_layout = QKVLayout::kNHD; + PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone; + bool causal = true; + bool use_fp16_qk_reduction = false; + int32_t window_left = -1; + float rope_scale = 1.0f; + float rope_theta = 10000.0f; + int iterations = 10; + int warmup = 5; + bool validate = false; + + for (int i = 1; i < argc; i += 2) { + std::string arg = argv[i]; + if (i + 1 >= argc && arg != "--help" && arg != "-h") { + std::cerr << "Missing value for parameter " << arg << std::endl; + print_usage(argv[0]); + return 1; } - // Process parameter pairs (--param value) - uint32_t qo_len = 512; - uint32_t kv_len = 512; - uint32_t num_qo_heads = 32; - uint32_t num_kv_heads = 32; - uint32_t head_dim = 128; - QKVLayout kv_layout = QKVLayout::kNHD; - PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone; - bool causal = true; - bool use_fp16_qk_reduction = false; - int32_t window_left = -1; - float rope_scale = 1.0f; - float rope_theta = 10000.0f; - int iterations = 10; - int warmup = 5; - bool validate = false; - - for (int i = 1; i < argc; i += 2) { - std::string arg = argv[i]; - if (i + 1 >= argc && arg != "--help" && arg != "-h") { - std::cerr << "Missing value for parameter " << arg << std::endl; - print_usage(argv[0]); - return 1; - } - - if (arg == "--qo_len") { - qo_len = ArgParser::get_int(argv[i + 1], 512); - } - else if (arg == "--kv_len") { - kv_len = ArgParser::get_int(argv[i + 1], 512); - } - else if (arg == "--num_qo_heads") { - num_qo_heads = ArgParser::get_int(argv[i + 1], 32); - } - else if (arg == "--num_kv_heads") { - num_kv_heads = ArgParser::get_int(argv[i + 1], 32); - } - else if (arg == "--head_dim") { - head_dim = 
ArgParser::get_int(argv[i + 1], 128); - } - else if (arg == "--layout") { - kv_layout = ArgParser::get_layout(argv[i + 1]); - } - else if (arg == "--pos_encoding") { - pos_encoding_mode = ArgParser::get_pos_encoding_mode(argv[i + 1]); - } - else if (arg == "--causal") { - causal = ArgParser::get_bool(argv[i + 1], true); - } - else if (arg == "--use_fp16_qk") { - use_fp16_qk_reduction = ArgParser::get_bool(argv[i + 1], false); - } - else if (arg == "--window_left") { - window_left = ArgParser::get_int(argv[i + 1], -1); - } - else if (arg == "--rope_scale") { - rope_scale = ArgParser::get_float(argv[i + 1], 1.0f); - } - else if (arg == "--rope_theta") { - rope_theta = ArgParser::get_float(argv[i + 1], 10000.0f); - } - else if (arg == "--iterations") { - iterations = ArgParser::get_int(argv[i + 1], 10); - } - else if (arg == "--warmup") { - warmup = ArgParser::get_int(argv[i + 1], 5); - } - else if (arg == "--validate") { - validate = ArgParser::get_bool(argv[i + 1], false); - } - else { - std::cerr << "Unknown parameter: " << arg << std::endl; - print_usage(argv[0]); - return 1; - } + if (arg == "--qo_len") { + qo_len = ArgParser::get_int(argv[i + 1], 512); + } else if (arg == "--kv_len") { + kv_len = ArgParser::get_int(argv[i + 1], 512); + } else if (arg == "--num_qo_heads") { + num_qo_heads = ArgParser::get_int(argv[i + 1], 32); + } else if (arg == "--num_kv_heads") { + num_kv_heads = ArgParser::get_int(argv[i + 1], 32); + } else if (arg == "--head_dim") { + head_dim = ArgParser::get_int(argv[i + 1], 128); + } else if (arg == "--layout") { + kv_layout = ArgParser::get_layout(argv[i + 1]); + } else if (arg == "--pos_encoding") { + pos_encoding_mode = ArgParser::get_pos_encoding_mode(argv[i + 1]); + } else if (arg == "--causal") { + causal = ArgParser::get_bool(argv[i + 1], true); + } else if (arg == "--use_fp16_qk") { + use_fp16_qk_reduction = ArgParser::get_bool(argv[i + 1], false); + } else if (arg == "--window_left") { + window_left = ArgParser::get_int(argv[i + 
1], -1); + } else if (arg == "--rope_scale") { + rope_scale = ArgParser::get_float(argv[i + 1], 1.0f); + } else if (arg == "--rope_theta") { + rope_theta = ArgParser::get_float(argv[i + 1], 10000.0f); + } else if (arg == "--iterations") { + iterations = ArgParser::get_int(argv[i + 1], 10); + } else if (arg == "--warmup") { + warmup = ArgParser::get_int(argv[i + 1], 5); + } else if (arg == "--validate") { + validate = ArgParser::get_bool(argv[i + 1], false); + } else { + std::cerr << "Unknown parameter: " << arg << std::endl; + print_usage(argv[0]); + return 1; } - - // Print configuration - std::cout << "Configuration:" << std::endl - << " QO Length: " << qo_len << std::endl - << " KV Length: " << kv_len << std::endl - << " QO Heads: " << num_qo_heads << std::endl - << " KV Heads: " << num_kv_heads << std::endl - << " Head Dimension: " << head_dim << std::endl - << " KV Layout: " - << (kv_layout == QKVLayout::kNHD ? "NHD" : "HND") << std::endl - << " Position Encoding: " - << (pos_encoding_mode == PosEncodingMode::kNone ? "None" - : pos_encoding_mode == PosEncodingMode::kRoPELlama ? "RoPE" - : "ALiBi") - << std::endl - << " Causal: " << (causal ? "Yes" : "No") << std::endl - << " Use FP16 QK Reduction: " - << (use_fp16_qk_reduction ? "Yes" : "No") << std::endl - << " Window Left: " << window_left << std::endl - << " RoPE Scale: " << rope_scale << std::endl - << " RoPE Theta: " << rope_theta << std::endl - << " Iterations: " << iterations << std::endl - << " Warmup: " << warmup << std::endl - << " Validation: " << (validate ? 
"Yes" : "No") << std::endl; - - // Create stream - gpuStream_t stream; - FI_GPU_CALL(gpuStreamCreate(&stream)); - - // Allocate device memory using gpuMalloc instead of Thrust - half *q_dev, *k_dev, *v_dev, *o_dev, *tmp_dev; - float *lse_dev; - - size_t q_size = qo_len * num_qo_heads * head_dim; - size_t k_size = kv_len * num_kv_heads * head_dim; - size_t v_size = kv_len * num_kv_heads * head_dim; - size_t o_size = qo_len * num_qo_heads * head_dim; - size_t lse_size = qo_len * num_qo_heads; - - FI_GPU_CALL(gpuMalloc(&q_dev, q_size * sizeof(half))); - FI_GPU_CALL(gpuMalloc(&k_dev, k_size * sizeof(half))); - FI_GPU_CALL(gpuMalloc(&v_dev, v_size * sizeof(half))); - FI_GPU_CALL(gpuMalloc(&o_dev, o_size * sizeof(half))); - FI_GPU_CALL(gpuMalloc(&tmp_dev, o_size * sizeof(half))); - FI_GPU_CALL(gpuMalloc(&lse_dev, lse_size * sizeof(float))); - - // Initialize data - generate_random_data(q_dev, q_size); - generate_random_data(k_dev, k_size); - generate_random_data(v_dev, v_size); - - // Zero out output arrays - FI_GPU_CALL(gpuMemset(o_dev, 0, o_size * sizeof(half))); - FI_GPU_CALL(gpuMemset(tmp_dev, 0, o_size * sizeof(half))); - FI_GPU_CALL(gpuMemset(lse_dev, 0, lse_size * sizeof(float))); - - // Calculate SM scale - float sm_scale = 1.0f / std::sqrt(static_cast(head_dim)); - - // Warmup runs - for (int i = 0; i < warmup; ++i) { - gpuError_t status = dispatch_single_prefill( - q_dev, k_dev, v_dev, o_dev, tmp_dev, lse_dev, num_qo_heads, - num_kv_heads, qo_len, kv_len, head_dim, kv_layout, - pos_encoding_mode, causal, use_fp16_qk_reduction, sm_scale, - window_left, rope_scale, rope_theta, stream); - - if (status != gpuSuccess) { - std::cerr << "Error during warmup: " << gpuGetErrorString(status) - << std::endl; - return 1; - } + } + + // Print configuration + std::cout << "Configuration:" << std::endl + << " QO Length: " << qo_len << std::endl + << " KV Length: " << kv_len << std::endl + << " QO Heads: " << num_qo_heads << std::endl + << " KV Heads: " << num_kv_heads << 
std::endl + << " Head Dimension: " << head_dim << std::endl + << " KV Layout: " << (kv_layout == QKVLayout::kNHD ? "NHD" : "HND") << std::endl + << " Position Encoding: " + << (pos_encoding_mode == PosEncodingMode::kNone ? "None" + : pos_encoding_mode == PosEncodingMode::kRoPELlama ? "RoPE" + : "ALiBi") + << std::endl + << " Causal: " << (causal ? "Yes" : "No") << std::endl + << " Use FP16 QK Reduction: " << (use_fp16_qk_reduction ? "Yes" : "No") << std::endl + << " Window Left: " << window_left << std::endl + << " RoPE Scale: " << rope_scale << std::endl + << " RoPE Theta: " << rope_theta << std::endl + << " Iterations: " << iterations << std::endl + << " Warmup: " << warmup << std::endl + << " Validation: " << (validate ? "Yes" : "No") << std::endl; + + // Create stream + gpuStream_t stream; + FI_GPU_CALL(gpuStreamCreate(&stream)); + + // Allocate device memory using gpuMalloc instead of Thrust + half *q_dev, *k_dev, *v_dev, *o_dev, *tmp_dev; + float* lse_dev; + + size_t q_size = qo_len * num_qo_heads * head_dim; + size_t k_size = kv_len * num_kv_heads * head_dim; + size_t v_size = kv_len * num_kv_heads * head_dim; + size_t o_size = qo_len * num_qo_heads * head_dim; + size_t lse_size = qo_len * num_qo_heads; + + FI_GPU_CALL(gpuMalloc(&q_dev, q_size * sizeof(half))); + FI_GPU_CALL(gpuMalloc(&k_dev, k_size * sizeof(half))); + FI_GPU_CALL(gpuMalloc(&v_dev, v_size * sizeof(half))); + FI_GPU_CALL(gpuMalloc(&o_dev, o_size * sizeof(half))); + FI_GPU_CALL(gpuMalloc(&tmp_dev, o_size * sizeof(half))); + FI_GPU_CALL(gpuMalloc(&lse_dev, lse_size * sizeof(float))); + + // Initialize data + generate_random_data(q_dev, q_size); + generate_random_data(k_dev, k_size); + generate_random_data(v_dev, v_size); + + // Zero out output arrays + FI_GPU_CALL(gpuMemset(o_dev, 0, o_size * sizeof(half))); + FI_GPU_CALL(gpuMemset(tmp_dev, 0, o_size * sizeof(half))); + FI_GPU_CALL(gpuMemset(lse_dev, 0, lse_size * sizeof(float))); + + // Calculate SM scale + float sm_scale = 1.0f / 
std::sqrt(static_cast(head_dim)); + + // Warmup runs + for (int i = 0; i < warmup; ++i) { + gpuError_t status = dispatch_single_prefill( + q_dev, k_dev, v_dev, o_dev, tmp_dev, lse_dev, num_qo_heads, num_kv_heads, qo_len, kv_len, + head_dim, kv_layout, pos_encoding_mode, causal, use_fp16_qk_reduction, sm_scale, + window_left, rope_scale, rope_theta, stream); + + if (status != gpuSuccess) { + std::cerr << "Error during warmup: " << gpuGetErrorString(status) << std::endl; + return 1; } + } - // Timing runs - gpuEvent_t start, stop; - FI_GPU_CALL(gpuEventCreate(&start)); - FI_GPU_CALL(gpuEventCreate(&stop)); - - FI_GPU_CALL(gpuEventRecord(start, stream)); + // Timing runs + gpuEvent_t start, stop; + FI_GPU_CALL(gpuEventCreate(&start)); + FI_GPU_CALL(gpuEventCreate(&stop)); - for (int i = 0; i < iterations; ++i) { - gpuError_t status = dispatch_single_prefill( - q_dev, k_dev, v_dev, o_dev, tmp_dev, lse_dev, num_qo_heads, - num_kv_heads, qo_len, kv_len, head_dim, kv_layout, - pos_encoding_mode, causal, use_fp16_qk_reduction, sm_scale, - window_left, rope_scale, rope_theta, stream); + FI_GPU_CALL(gpuEventRecord(start, stream)); - if (status != gpuSuccess) { - std::cerr << "Error during benchmark: " << gpuGetErrorString(status) - << std::endl; - return 1; - } - } + for (int i = 0; i < iterations; ++i) { + gpuError_t status = dispatch_single_prefill( + q_dev, k_dev, v_dev, o_dev, tmp_dev, lse_dev, num_qo_heads, num_kv_heads, qo_len, kv_len, + head_dim, kv_layout, pos_encoding_mode, causal, use_fp16_qk_reduction, sm_scale, + window_left, rope_scale, rope_theta, stream); - FI_GPU_CALL(gpuEventRecord(stop, stream)); - FI_GPU_CALL(gpuEventSynchronize(stop)); - - float elapsed_ms; - FI_GPU_CALL(gpuEventElapsedTime(&elapsed_ms, start, stop)); - float avg_ms = elapsed_ms / iterations; - - // Calculate and report performance - double flops = - calculate_flops(qo_len, kv_len, num_qo_heads, head_dim, causal); - double tflops = flops / (avg_ms * 1e-3) / 1e12; - - // Report results - 
std::cout << std::fixed << std::setprecision(4); - std::cout << "Performance Results:" << std::endl; - std::cout << " Average time: " << avg_ms << " ms" << std::endl; - std::cout << " Performance: " << tflops << " TFLOPS" << std::endl; - - // Run validation if requested - if (validate) { - std::cout << "\nRunning validation..." << std::endl; - - // Copy input data to host for CPU reference - std::vector h_q(q_size), h_k(k_size), h_v(v_size); - FI_GPU_CALL(gpuMemcpy(h_q.data(), q_dev, q_size * sizeof(half), - gpuMemcpyHostToDevice)); - FI_GPU_CALL(gpuMemcpy(h_k.data(), k_dev, k_size * sizeof(half), - gpuMemcpyHostToDevice)); - FI_GPU_CALL(gpuMemcpy(h_v.data(), v_dev, v_size * sizeof(half), - gpuMemcpyHostToDevice)); - - // Compute reference output on CPU - std::vector ref_output = reference::single_mha( - h_q, h_k, h_v, qo_len, kv_len, num_qo_heads, num_kv_heads, head_dim, - causal, kv_layout, pos_encoding_mode, rope_scale, rope_theta); - - // Validate results - bool validation_passed = validate_results(o_dev, o_size, ref_output); - - // Report validation status - std::cout << "Validation " << (validation_passed ? 
"PASSED" : "FAILED") - << std::endl; + if (status != gpuSuccess) { + std::cerr << "Error during benchmark: " << gpuGetErrorString(status) << std::endl; + return 1; } - - // Cleanup - FI_GPU_CALL(gpuEventDestroy(start)); - FI_GPU_CALL(gpuEventDestroy(stop)); - FI_GPU_CALL(gpuStreamDestroy(stream)); - FI_GPU_CALL(gpuFree(q_dev)); - FI_GPU_CALL(gpuFree(k_dev)); - FI_GPU_CALL(gpuFree(v_dev)); - FI_GPU_CALL(gpuFree(o_dev)); - FI_GPU_CALL(gpuFree(tmp_dev)); - FI_GPU_CALL(gpuFree(lse_dev)); - - return 0; + } + + FI_GPU_CALL(gpuEventRecord(stop, stream)); + FI_GPU_CALL(gpuEventSynchronize(stop)); + + float elapsed_ms; + FI_GPU_CALL(gpuEventElapsedTime(&elapsed_ms, start, stop)); + float avg_ms = elapsed_ms / iterations; + + // Calculate and report performance + double flops = calculate_flops(qo_len, kv_len, num_qo_heads, head_dim, causal); + double tflops = flops / (avg_ms * 1e-3) / 1e12; + + // Report results + std::cout << std::fixed << std::setprecision(4); + std::cout << "Performance Results:" << std::endl; + std::cout << " Average time: " << avg_ms << " ms" << std::endl; + std::cout << " Performance: " << tflops << " TFLOPS" << std::endl; + + // Run validation if requested + if (validate) { + std::cout << "\nRunning validation..." 
<< std::endl; + + // Copy input data to host for CPU reference + std::vector h_q(q_size), h_k(k_size), h_v(v_size); + FI_GPU_CALL(gpuMemcpy(h_q.data(), q_dev, q_size * sizeof(half), gpuMemcpyHostToDevice)); + FI_GPU_CALL(gpuMemcpy(h_k.data(), k_dev, k_size * sizeof(half), gpuMemcpyHostToDevice)); + FI_GPU_CALL(gpuMemcpy(h_v.data(), v_dev, v_size * sizeof(half), gpuMemcpyHostToDevice)); + + // Compute reference output on CPU + std::vector ref_output = + reference::single_mha(h_q, h_k, h_v, qo_len, kv_len, num_qo_heads, num_kv_heads, head_dim, + causal, kv_layout, pos_encoding_mode, rope_scale, rope_theta); + + // Validate results + bool validation_passed = validate_results(o_dev, o_size, ref_output); + + // Report validation status + std::cout << "Validation " << (validation_passed ? "PASSED" : "FAILED") << std::endl; + } + + // Cleanup + FI_GPU_CALL(gpuEventDestroy(start)); + FI_GPU_CALL(gpuEventDestroy(stop)); + FI_GPU_CALL(gpuStreamDestroy(stream)); + FI_GPU_CALL(gpuFree(q_dev)); + FI_GPU_CALL(gpuFree(k_dev)); + FI_GPU_CALL(gpuFree(v_dev)); + FI_GPU_CALL(gpuFree(o_dev)); + FI_GPU_CALL(gpuFree(tmp_dev)); + FI_GPU_CALL(gpuFree(lse_dev)); + + return 0; } diff --git a/examples/test_batch_decode_example.py b/examples/test_batch_decode_example.py index 9c17356f97..8da21b4c9e 100644 --- a/examples/test_batch_decode_example.py +++ b/examples/test_batch_decode_example.py @@ -1,17 +1,22 @@ import torch + import flashinfer + def verify_tensors(tensor1, tensor2, rtol=1e-3, atol=1e-3): - + for i in range(tensor1.shape[0]): for j in range(tensor1.shape[1]): - if torch.abs(tensor1[i][j] - tensor2[i][j]) > atol + rtol * torch.abs(tensor2[i][j]): + if torch.abs(tensor1[i][j] - tensor2[i][j]) > atol + rtol * torch.abs( + tensor2[i][j] + ): print(f"Error at {i}, {j}") print(f"Expected: {tensor2[i][j]}") print(f"Got: {tensor1[i][j]}") return False return True + def test_batch_decode_with_paged_kv_cache( batch_size, kv_len, @@ -119,10 +124,7 @@ def 
test_batch_decode_with_paged_kv_cache( dim=0, ).to(kv_dtype) # print(qi.shape, ki.shape, vi.shape) - o_ref_i = flashinfer.single_decode_with_kv_cache( - qi, - ki, - vi) + o_ref_i = flashinfer.single_decode_with_kv_cache(qi, ki, vi) # torch.testing.assert_close(o[i], o_ref_i, rtol=1e-3, atol=1e-3) result += verify_tensors(o[i], o_ref_i, rtol=1e-3, atol=1e-3) @@ -136,14 +138,15 @@ def test_batch_decode_with_paged_kv_cache( else: print("FAIL") + if __name__ == "__main__": batch_size = 256 - page_size = 8 + page_size = 8 # # This configuration works - num_qo_heads = 32 - num_kv_heads = 4 + num_qo_heads = 32 + num_kv_heads = 4 head_dim = 256 kv_len = 512 @@ -152,7 +155,7 @@ def test_batch_decode_with_paged_kv_cache( # num_kv_heads = 8 # head_dim = 128 # kv_len = 54 - + kv_layout = "NHD" pos_encoding_mode = "NONE" logits_soft_cap = 0.0 @@ -160,9 +163,9 @@ def test_batch_decode_with_paged_kv_cache( q_dtype = torch.float16 kv_dtype = torch.float16 contiguous_kv = True - - num_qo_heads = 32 - num_kv_heads = 4 + + num_qo_heads = 32 + num_kv_heads = 4 head_dim = 256 kv_len = 512 test_batch_decode_with_paged_kv_cache( @@ -178,5 +181,5 @@ def test_batch_decode_with_paged_kv_cache( return_lse, q_dtype, kv_dtype, - contiguous_kv) - + contiguous_kv, + ) diff --git a/libflashinfer/include/flashinfer/attention/generic/default_prefill_params.cuh b/libflashinfer/include/flashinfer/attention/generic/default_prefill_params.cuh index 55942d775d..5fdcdf52c8 100644 --- a/libflashinfer/include/flashinfer/attention/generic/default_prefill_params.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/default_prefill_params.cuh @@ -16,342 +16,380 @@ #ifndef FLASHINFER_PREFILL_PARAMS_CUH_ #define FLASHINFER_PREFILL_PARAMS_CUH_ -#include "gpu_iface/gpu_runtime_compat.hpp" - #include #include +#include "gpu_iface/gpu_runtime_compat.hpp" #include "page.cuh" -namespace flashinfer -{ +namespace flashinfer { template -struct SinglePrefillParams -{ - using DTypeQ = DTypeQ_; - using DTypeKV = 
DTypeKV_; - using DTypeO = DTypeO_; - using IdType = int32_t; - DTypeQ *q; - DTypeKV *k; - DTypeKV *v; - uint8_t *maybe_custom_mask; - DTypeO *o; - float *lse; - float *maybe_alibi_slopes; - uint_fastdiv group_size; - uint32_t num_qo_heads; - uint32_t num_kv_heads; - uint32_t qo_len; - uint32_t kv_len; - uint32_t q_stride_n; - uint32_t q_stride_h; - uint32_t k_stride_n; - uint32_t k_stride_h; - uint32_t v_stride_n; - uint32_t v_stride_h; - uint32_t head_dim; - int32_t window_left; - float logits_soft_cap; - float sm_scale; - float rope_rcp_scale; - float rope_rcp_theta; - uint32_t debug_thread_id; - uint32_t debug_warp_id; +struct SinglePrefillParams { + using DTypeQ = DTypeQ_; + using DTypeKV = DTypeKV_; + using DTypeO = DTypeO_; + using IdType = int32_t; + DTypeQ* q; + DTypeKV* k; + DTypeKV* v; + uint8_t* maybe_custom_mask; + DTypeO* o; + float* lse; + float* maybe_alibi_slopes; + uint_fastdiv group_size; + uint32_t num_qo_heads; + uint32_t num_kv_heads; + uint32_t qo_len; + uint32_t kv_len; + uint32_t q_stride_n; + uint32_t q_stride_h; + uint32_t k_stride_n; + uint32_t k_stride_h; + uint32_t v_stride_n; + uint32_t v_stride_h; + uint32_t head_dim; + int32_t window_left; + float logits_soft_cap; + float sm_scale; + float rope_rcp_scale; + float rope_rcp_theta; + uint32_t debug_thread_id; + uint32_t debug_warp_id; - uint32_t partition_kv; + uint32_t partition_kv; - __host__ SinglePrefillParams() - : q(nullptr), k(nullptr), v(nullptr), maybe_custom_mask(nullptr), - o(nullptr), lse(nullptr), maybe_alibi_slopes(nullptr), group_size(), - qo_len(0), kv_len(0), num_qo_heads(0), num_kv_heads(0), q_stride_n(0), - q_stride_h(0), k_stride_n(0), k_stride_h(0), v_stride_n(0), - v_stride_h(0), head_dim(0), window_left(0), logits_soft_cap(0.0f), - sm_scale(0.0f), rope_rcp_scale(0.0f), rope_rcp_theta(0.0f), - partition_kv(false) - { - } + __host__ SinglePrefillParams() + : q(nullptr), + k(nullptr), + v(nullptr), + maybe_custom_mask(nullptr), + o(nullptr), + lse(nullptr), + 
maybe_alibi_slopes(nullptr), + group_size(), + qo_len(0), + kv_len(0), + num_qo_heads(0), + num_kv_heads(0), + q_stride_n(0), + q_stride_h(0), + k_stride_n(0), + k_stride_h(0), + v_stride_n(0), + v_stride_h(0), + head_dim(0), + window_left(0), + logits_soft_cap(0.0f), + sm_scale(0.0f), + rope_rcp_scale(0.0f), + rope_rcp_theta(0.0f), + partition_kv(false) {} - __host__ SinglePrefillParams(DTypeQ *q, - DTypeKV *k, - DTypeKV *v, - uint8_t *maybe_custom_mask, - DTypeO *o, - float *lse, - float *maybe_alibi_slopes, - uint32_t num_qo_heads, - uint32_t num_kv_heads, - uint32_t qo_len, - uint32_t kv_len, - uint32_t q_stride_n, - uint32_t q_stride_h, - uint32_t kv_stride_n, - uint32_t kv_stride_h, - uint32_t head_dim, - int32_t window_left, - float logits_soft_cap, - float sm_scale, - float rope_scale, - float rope_theta, - uint32_t debug_thread_id, - uint32_t debug_warp_id) - : q(q), k(k), v(v), maybe_custom_mask(maybe_custom_mask), o(o), - lse(lse), maybe_alibi_slopes(maybe_alibi_slopes), - group_size(num_qo_heads / num_kv_heads), num_qo_heads(num_qo_heads), - num_kv_heads(num_kv_heads), qo_len(qo_len), kv_len(kv_len), - q_stride_n(q_stride_n), q_stride_h(q_stride_h), - k_stride_n(kv_stride_n), k_stride_h(kv_stride_h), - v_stride_n(kv_stride_n), v_stride_h(kv_stride_h), head_dim(head_dim), - window_left(window_left), logits_soft_cap(logits_soft_cap), - sm_scale(sm_scale), rope_rcp_scale(1. / rope_scale), - rope_rcp_theta(1. 
/ rope_theta), debug_thread_id(debug_thread_id), - debug_warp_id(debug_warp_id), partition_kv(false) - { - } + __host__ SinglePrefillParams(DTypeQ* q, DTypeKV* k, DTypeKV* v, uint8_t* maybe_custom_mask, + DTypeO* o, float* lse, float* maybe_alibi_slopes, + uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t qo_len, + uint32_t kv_len, uint32_t q_stride_n, uint32_t q_stride_h, + uint32_t kv_stride_n, uint32_t kv_stride_h, uint32_t head_dim, + int32_t window_left, float logits_soft_cap, float sm_scale, + float rope_scale, float rope_theta, uint32_t debug_thread_id, + uint32_t debug_warp_id) + : q(q), + k(k), + v(v), + maybe_custom_mask(maybe_custom_mask), + o(o), + lse(lse), + maybe_alibi_slopes(maybe_alibi_slopes), + group_size(num_qo_heads / num_kv_heads), + num_qo_heads(num_qo_heads), + num_kv_heads(num_kv_heads), + qo_len(qo_len), + kv_len(kv_len), + q_stride_n(q_stride_n), + q_stride_h(q_stride_h), + k_stride_n(kv_stride_n), + k_stride_h(kv_stride_h), + v_stride_n(kv_stride_n), + v_stride_h(kv_stride_h), + head_dim(head_dim), + window_left(window_left), + logits_soft_cap(logits_soft_cap), + sm_scale(sm_scale), + rope_rcp_scale(1. / rope_scale), + rope_rcp_theta(1. 
/ rope_theta), + debug_thread_id(debug_thread_id), + debug_warp_id(debug_warp_id), + partition_kv(false) {} - __host__ __device__ __forceinline__ uint32_t - get_qo_len(uint32_t batch_idx) const - { - return qo_len; - } + __host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t batch_idx) const { + return qo_len; + } - __host__ __device__ __forceinline__ uint32_t - get_kv_len(uint32_t batch_idx) const - { - return kv_len; - } + __host__ __device__ __forceinline__ uint32_t get_kv_len(uint32_t batch_idx) const { + return kv_len; + } }; -template -struct BatchPrefillRaggedParams -{ - using DTypeQ = DTypeQ_; - using DTypeKV = DTypeKV_; - using DTypeO = DTypeO_; - using IdType = IdType_; +template +struct BatchPrefillRaggedParams { + using DTypeQ = DTypeQ_; + using DTypeKV = DTypeKV_; + using DTypeO = DTypeO_; + using IdType = IdType_; - DTypeQ *q; - DTypeKV *k; - DTypeKV *v; - uint8_t *maybe_custom_mask; - IdType *q_indptr; - IdType *kv_indptr; - IdType *maybe_mask_indptr; - IdType *maybe_q_rope_offset; // maybe_q_rope_offset is only used for - // fused-rope attention - IdType *maybe_k_rope_offset; // maybe_k_rope_offset is only used for - // fused-rope attention - DTypeO *o; - float *lse; - float *maybe_alibi_slopes; - uint_fastdiv group_size; - uint32_t num_qo_heads; - uint32_t num_kv_heads; - uint32_t q_stride_n; - uint32_t q_stride_h; - uint32_t k_stride_n; - uint32_t k_stride_h; - uint32_t v_stride_n; - uint32_t v_stride_h; - int32_t window_left; - float logits_soft_cap; - float sm_scale; - float rope_rcp_scale; - float rope_rcp_theta; + DTypeQ* q; + DTypeKV* k; + DTypeKV* v; + uint8_t* maybe_custom_mask; + IdType* q_indptr; + IdType* kv_indptr; + IdType* maybe_mask_indptr; + IdType* maybe_q_rope_offset; // maybe_q_rope_offset is only used for + // fused-rope attention + IdType* maybe_k_rope_offset; // maybe_k_rope_offset is only used for + // fused-rope attention + DTypeO* o; + float* lse; + float* maybe_alibi_slopes; + uint_fastdiv group_size; + uint32_t 
num_qo_heads; + uint32_t num_kv_heads; + uint32_t q_stride_n; + uint32_t q_stride_h; + uint32_t k_stride_n; + uint32_t k_stride_h; + uint32_t v_stride_n; + uint32_t v_stride_h; + int32_t window_left; + float logits_soft_cap; + float sm_scale; + float rope_rcp_scale; + float rope_rcp_theta; - IdType *request_indices; - IdType *qo_tile_indices; - IdType *kv_tile_indices; - IdType *merge_indptr; - IdType *o_indptr; - IdType *kv_chunk_size_ptr; - bool *block_valid_mask; - uint32_t max_total_num_rows; - uint32_t *total_num_rows; - uint32_t padded_batch_size; - bool partition_kv; + IdType* request_indices; + IdType* qo_tile_indices; + IdType* kv_tile_indices; + IdType* merge_indptr; + IdType* o_indptr; + IdType* kv_chunk_size_ptr; + bool* block_valid_mask; + uint32_t max_total_num_rows; + uint32_t* total_num_rows; + uint32_t padded_batch_size; + bool partition_kv; - __host__ BatchPrefillRaggedParams() - : q(nullptr), k(nullptr), v(nullptr), maybe_custom_mask(nullptr), - q_indptr(nullptr), kv_indptr(nullptr), maybe_mask_indptr(nullptr), - maybe_q_rope_offset(nullptr), maybe_k_rope_offset(nullptr), - o(nullptr), lse(nullptr), maybe_alibi_slopes(nullptr), group_size(), - num_qo_heads(0), num_kv_heads(0), q_stride_n(0), q_stride_h(0), - k_stride_n(0), k_stride_h(0), v_stride_n(0), v_stride_h(0), - window_left(0), logits_soft_cap(0.0f), sm_scale(0.0f), - rope_rcp_scale(0.0f), rope_rcp_theta(0.0f), request_indices(nullptr), - qo_tile_indices(nullptr), kv_tile_indices(nullptr), - merge_indptr(nullptr), o_indptr(nullptr), kv_chunk_size_ptr(nullptr), - block_valid_mask(nullptr), max_total_num_rows(0), - total_num_rows(nullptr), padded_batch_size(0), partition_kv(false) - { - } + __host__ BatchPrefillRaggedParams() + : q(nullptr), + k(nullptr), + v(nullptr), + maybe_custom_mask(nullptr), + q_indptr(nullptr), + kv_indptr(nullptr), + maybe_mask_indptr(nullptr), + maybe_q_rope_offset(nullptr), + maybe_k_rope_offset(nullptr), + o(nullptr), + lse(nullptr), + 
maybe_alibi_slopes(nullptr), + group_size(), + num_qo_heads(0), + num_kv_heads(0), + q_stride_n(0), + q_stride_h(0), + k_stride_n(0), + k_stride_h(0), + v_stride_n(0), + v_stride_h(0), + window_left(0), + logits_soft_cap(0.0f), + sm_scale(0.0f), + rope_rcp_scale(0.0f), + rope_rcp_theta(0.0f), + request_indices(nullptr), + qo_tile_indices(nullptr), + kv_tile_indices(nullptr), + merge_indptr(nullptr), + o_indptr(nullptr), + kv_chunk_size_ptr(nullptr), + block_valid_mask(nullptr), + max_total_num_rows(0), + total_num_rows(nullptr), + padded_batch_size(0), + partition_kv(false) {} - __host__ BatchPrefillRaggedParams(DTypeQ *q, - DTypeKV *k, - DTypeKV *v, - uint8_t *maybe_custom_mask, - IdType *q_indptr, - IdType *kv_indptr, - IdType *maybe_mask_indptr, - IdType *maybe_q_rope_offset, - IdType *maybe_k_rope_offset, - DTypeO *o, - float *lse, - float *maybe_alibi_slopes, - uint32_t num_qo_heads, - uint32_t num_kv_heads, - uint32_t q_stride_n, - uint32_t q_stride_h, - uint32_t kv_stride_n, - uint32_t kv_stride_h, - int32_t window_left, - float logits_soft_cap, - float sm_scale, - float rope_scale, - float rope_theta) - : q(q), k(k), v(v), maybe_custom_mask(maybe_custom_mask), - q_indptr(q_indptr), kv_indptr(kv_indptr), - maybe_mask_indptr(maybe_mask_indptr), - maybe_q_rope_offset(maybe_q_rope_offset), - maybe_k_rope_offset(maybe_k_rope_offset), o(o), lse(lse), - maybe_alibi_slopes(maybe_alibi_slopes), - group_size(num_qo_heads / num_kv_heads), num_qo_heads(num_qo_heads), - num_kv_heads(num_kv_heads), q_stride_n(q_stride_n), - q_stride_h(q_stride_h), k_stride_n(kv_stride_n), - k_stride_h(kv_stride_h), v_stride_n(kv_stride_n), - v_stride_h(kv_stride_h), window_left(window_left), - logits_soft_cap(logits_soft_cap), sm_scale(sm_scale), - rope_rcp_scale(1.f / rope_scale), rope_rcp_theta(1.f / rope_theta), - request_indices(nullptr), qo_tile_indices(nullptr), - kv_tile_indices(nullptr), merge_indptr(nullptr), o_indptr(nullptr), - kv_chunk_size_ptr(nullptr), 
block_valid_mask(nullptr), - max_total_num_rows(0), total_num_rows(nullptr), padded_batch_size(0), - partition_kv(false) - { - } + __host__ BatchPrefillRaggedParams(DTypeQ* q, DTypeKV* k, DTypeKV* v, uint8_t* maybe_custom_mask, + IdType* q_indptr, IdType* kv_indptr, IdType* maybe_mask_indptr, + IdType* maybe_q_rope_offset, IdType* maybe_k_rope_offset, + DTypeO* o, float* lse, float* maybe_alibi_slopes, + uint32_t num_qo_heads, uint32_t num_kv_heads, + uint32_t q_stride_n, uint32_t q_stride_h, uint32_t kv_stride_n, + uint32_t kv_stride_h, int32_t window_left, + float logits_soft_cap, float sm_scale, float rope_scale, + float rope_theta) + : q(q), + k(k), + v(v), + maybe_custom_mask(maybe_custom_mask), + q_indptr(q_indptr), + kv_indptr(kv_indptr), + maybe_mask_indptr(maybe_mask_indptr), + maybe_q_rope_offset(maybe_q_rope_offset), + maybe_k_rope_offset(maybe_k_rope_offset), + o(o), + lse(lse), + maybe_alibi_slopes(maybe_alibi_slopes), + group_size(num_qo_heads / num_kv_heads), + num_qo_heads(num_qo_heads), + num_kv_heads(num_kv_heads), + q_stride_n(q_stride_n), + q_stride_h(q_stride_h), + k_stride_n(kv_stride_n), + k_stride_h(kv_stride_h), + v_stride_n(kv_stride_n), + v_stride_h(kv_stride_h), + window_left(window_left), + logits_soft_cap(logits_soft_cap), + sm_scale(sm_scale), + rope_rcp_scale(1.f / rope_scale), + rope_rcp_theta(1.f / rope_theta), + request_indices(nullptr), + qo_tile_indices(nullptr), + kv_tile_indices(nullptr), + merge_indptr(nullptr), + o_indptr(nullptr), + kv_chunk_size_ptr(nullptr), + block_valid_mask(nullptr), + max_total_num_rows(0), + total_num_rows(nullptr), + padded_batch_size(0), + partition_kv(false) {} - __host__ __device__ __forceinline__ uint32_t - get_qo_len(uint32_t batch_idx) const - { - return q_indptr[batch_idx + 1] - q_indptr[batch_idx]; - } + __host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t batch_idx) const { + return q_indptr[batch_idx + 1] - q_indptr[batch_idx]; + } - __host__ __device__ __forceinline__ uint32_t 
- get_kv_len(uint32_t batch_idx) const - { - return kv_indptr[batch_idx + 1] - kv_indptr[batch_idx]; - } + __host__ __device__ __forceinline__ uint32_t get_kv_len(uint32_t batch_idx) const { + return kv_indptr[batch_idx + 1] - kv_indptr[batch_idx]; + } }; -template -struct BatchPrefillPagedParams -{ - using DTypeQ = DTypeQ_; - using DTypeKV = DTypeKV_; - using DTypeO = DTypeO_; - using IdType = IdType_; +template +struct BatchPrefillPagedParams { + using DTypeQ = DTypeQ_; + using DTypeKV = DTypeKV_; + using DTypeO = DTypeO_; + using IdType = IdType_; - DTypeQ *q; - paged_kv_t paged_kv; - uint8_t *maybe_custom_mask; - IdType *q_indptr; - IdType *maybe_mask_indptr; - IdType *maybe_q_rope_offset; // maybe_q_rope_offset is only used for - // fused-rope attention - DTypeO *o; - float *lse; - float *maybe_alibi_slopes; - uint_fastdiv group_size; - uint32_t num_qo_heads; - IdType q_stride_n; - IdType q_stride_h; - int32_t window_left; - float logits_soft_cap; - float sm_scale; - float rope_rcp_scale; - float rope_rcp_theta; + DTypeQ* q; + paged_kv_t paged_kv; + uint8_t* maybe_custom_mask; + IdType* q_indptr; + IdType* maybe_mask_indptr; + IdType* maybe_q_rope_offset; // maybe_q_rope_offset is only used for + // fused-rope attention + DTypeO* o; + float* lse; + float* maybe_alibi_slopes; + uint_fastdiv group_size; + uint32_t num_qo_heads; + IdType q_stride_n; + IdType q_stride_h; + int32_t window_left; + float logits_soft_cap; + float sm_scale; + float rope_rcp_scale; + float rope_rcp_theta; - IdType *request_indices; - IdType *qo_tile_indices; - IdType *kv_tile_indices; - IdType *merge_indptr; - IdType *o_indptr; - bool *block_valid_mask; - IdType *kv_chunk_size_ptr; - uint32_t max_total_num_rows; - uint32_t *total_num_rows; - uint32_t padded_batch_size; - bool partition_kv; + IdType* request_indices; + IdType* qo_tile_indices; + IdType* kv_tile_indices; + IdType* merge_indptr; + IdType* o_indptr; + bool* block_valid_mask; + IdType* kv_chunk_size_ptr; + uint32_t 
max_total_num_rows; + uint32_t* total_num_rows; + uint32_t padded_batch_size; + bool partition_kv; - __host__ BatchPrefillPagedParams() - : q(nullptr), paged_kv(), maybe_custom_mask(nullptr), q_indptr(nullptr), - maybe_mask_indptr(nullptr), maybe_q_rope_offset(nullptr), o(nullptr), - lse(nullptr), maybe_alibi_slopes(nullptr), group_size(), - num_qo_heads(0), q_stride_n(0), q_stride_h(0), window_left(0), - logits_soft_cap(0.0f), sm_scale(0.0f), rope_rcp_scale(0.0f), - rope_rcp_theta(0.0f), request_indices(nullptr), - qo_tile_indices(nullptr), kv_tile_indices(nullptr), - merge_indptr(nullptr), o_indptr(nullptr), block_valid_mask(nullptr), - kv_chunk_size_ptr(nullptr), max_total_num_rows(0), - total_num_rows(nullptr), padded_batch_size(0), partition_kv(false) - { - } + __host__ BatchPrefillPagedParams() + : q(nullptr), + paged_kv(), + maybe_custom_mask(nullptr), + q_indptr(nullptr), + maybe_mask_indptr(nullptr), + maybe_q_rope_offset(nullptr), + o(nullptr), + lse(nullptr), + maybe_alibi_slopes(nullptr), + group_size(), + num_qo_heads(0), + q_stride_n(0), + q_stride_h(0), + window_left(0), + logits_soft_cap(0.0f), + sm_scale(0.0f), + rope_rcp_scale(0.0f), + rope_rcp_theta(0.0f), + request_indices(nullptr), + qo_tile_indices(nullptr), + kv_tile_indices(nullptr), + merge_indptr(nullptr), + o_indptr(nullptr), + block_valid_mask(nullptr), + kv_chunk_size_ptr(nullptr), + max_total_num_rows(0), + total_num_rows(nullptr), + padded_batch_size(0), + partition_kv(false) {} - __host__ BatchPrefillPagedParams(DTypeQ *q, - paged_kv_t paged_kv, - uint8_t *maybe_custom_mask, - IdType *q_indptr, - IdType *maybe_mask_indptr, - IdType *maybe_q_rope_offset, - DTypeO *o, - float *lse, - float *maybe_alibi_slopes, - uint32_t num_qo_heads, - IdType q_stride_n, - IdType q_stride_h, - int32_t window_left, - float logits_soft_cap, - float sm_scale, - float rope_scale, - float rope_theta) - : q(q), paged_kv(paged_kv), maybe_custom_mask(maybe_custom_mask), - q_indptr(q_indptr), 
maybe_mask_indptr(maybe_mask_indptr), - maybe_q_rope_offset(maybe_q_rope_offset), o(o), lse(lse), - maybe_alibi_slopes(maybe_alibi_slopes), - group_size(num_qo_heads / paged_kv.num_heads), - num_qo_heads(num_qo_heads), q_stride_n(q_stride_n), - q_stride_h(q_stride_h), window_left(window_left), - logits_soft_cap(logits_soft_cap), sm_scale(sm_scale), - rope_rcp_scale(1.f / rope_scale), rope_rcp_theta(1.f / rope_theta), - request_indices(nullptr), qo_tile_indices(nullptr), - kv_tile_indices(nullptr), merge_indptr(nullptr), o_indptr(nullptr), - block_valid_mask(nullptr), kv_chunk_size_ptr(nullptr), - max_total_num_rows(0), total_num_rows(nullptr), padded_batch_size(0), - partition_kv(false) - { - } + __host__ BatchPrefillPagedParams(DTypeQ* q, paged_kv_t paged_kv, + uint8_t* maybe_custom_mask, IdType* q_indptr, + IdType* maybe_mask_indptr, IdType* maybe_q_rope_offset, + DTypeO* o, float* lse, float* maybe_alibi_slopes, + uint32_t num_qo_heads, IdType q_stride_n, IdType q_stride_h, + int32_t window_left, float logits_soft_cap, float sm_scale, + float rope_scale, float rope_theta) + : q(q), + paged_kv(paged_kv), + maybe_custom_mask(maybe_custom_mask), + q_indptr(q_indptr), + maybe_mask_indptr(maybe_mask_indptr), + maybe_q_rope_offset(maybe_q_rope_offset), + o(o), + lse(lse), + maybe_alibi_slopes(maybe_alibi_slopes), + group_size(num_qo_heads / paged_kv.num_heads), + num_qo_heads(num_qo_heads), + q_stride_n(q_stride_n), + q_stride_h(q_stride_h), + window_left(window_left), + logits_soft_cap(logits_soft_cap), + sm_scale(sm_scale), + rope_rcp_scale(1.f / rope_scale), + rope_rcp_theta(1.f / rope_theta), + request_indices(nullptr), + qo_tile_indices(nullptr), + kv_tile_indices(nullptr), + merge_indptr(nullptr), + o_indptr(nullptr), + block_valid_mask(nullptr), + kv_chunk_size_ptr(nullptr), + max_total_num_rows(0), + total_num_rows(nullptr), + padded_batch_size(0), + partition_kv(false) {} - __host__ __device__ __forceinline__ uint32_t - get_qo_len(uint32_t batch_idx) const - 
{ - return q_indptr[batch_idx + 1] - q_indptr[batch_idx]; - } + __host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t batch_idx) const { + return q_indptr[batch_idx + 1] - q_indptr[batch_idx]; + } - __host__ __device__ __forceinline__ uint32_t - get_kv_len(uint32_t batch_idx) const - { - return paged_kv.get_length(batch_idx); - } + __host__ __device__ __forceinline__ uint32_t get_kv_len(uint32_t batch_idx) const { + return paged_kv.get_length(batch_idx); + } }; -} // namespace flashinfer +} // namespace flashinfer -#endif // FLASHINFER_DECODE_PARAMS_CUH_ +#endif // FLASHINFER_DECODE_PARAMS_CUH_ diff --git a/libflashinfer/include/flashinfer/attention/generic/dispatch.cuh b/libflashinfer/include/flashinfer/attention/generic/dispatch.cuh index d36536a6f8..abe0a3020e 100644 --- a/libflashinfer/include/flashinfer/attention/generic/dispatch.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/dispatch.cuh @@ -8,249 +8,209 @@ #include "gpu_iface/enums.hpp" #include "gpu_iface/exception.h" -#define DISPATCH_USE_FP16_QK_REDUCTION(use_fp16_qk_reduction, \ - USE_FP16_QK_REDUCTION, ...) \ - if (use_fp16_qk_reduction) { \ - FLASHINFER_ERROR("FP16_QK_REDUCTION disabled at compile time"); \ - } \ - else { \ - constexpr bool USE_FP16_QK_REDUCTION = false; \ - __VA_ARGS__ \ - } +#define DISPATCH_USE_FP16_QK_REDUCTION(use_fp16_qk_reduction, USE_FP16_QK_REDUCTION, ...) \ + if (use_fp16_qk_reduction) { \ + FLASHINFER_ERROR("FP16_QK_REDUCTION disabled at compile time"); \ + } else { \ + constexpr bool USE_FP16_QK_REDUCTION = false; \ + __VA_ARGS__ \ + } -#define DISPATCH_NUM_MMA_Q(num_mma_q, NUM_MMA_Q, ...) 
\ - if (num_mma_q == 1) { \ - constexpr size_t NUM_MMA_Q = 1; \ - __VA_ARGS__ \ - } \ - else if (num_mma_q == 2) { \ - constexpr size_t NUM_MMA_Q = 2; \ - __VA_ARGS__ \ - } \ - else { \ - std::ostringstream err_msg; \ - err_msg << "Unsupported num_mma_q: " << num_mma_q; \ - FLASHINFER_ERROR(err_msg.str()); \ - } +#define DISPATCH_NUM_MMA_Q(num_mma_q, NUM_MMA_Q, ...) \ + if (num_mma_q == 1) { \ + constexpr size_t NUM_MMA_Q = 1; \ + __VA_ARGS__ \ + } else if (num_mma_q == 2) { \ + constexpr size_t NUM_MMA_Q = 2; \ + __VA_ARGS__ \ + } else { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported num_mma_q: " << num_mma_q; \ + FLASHINFER_ERROR(err_msg.str()); \ + } -#define DISPATCH_NUM_MMA_KV(max_mma_kv, NUM_MMA_KV, ...) \ - if (max_mma_kv >= 8) { \ - constexpr size_t NUM_MMA_KV = 8; \ - __VA_ARGS__ \ - } \ - else if (max_mma_kv >= 4) { \ - constexpr size_t NUM_MMA_KV = 4; \ - __VA_ARGS__ \ - } \ - else if (max_mma_kv >= 2) { \ - constexpr size_t NUM_MMA_KV = 2; \ - __VA_ARGS__ \ - } \ - else if (max_mma_kv >= 1) { \ - constexpr size_t NUM_MMA_KV = 1; \ - __VA_ARGS__ \ - } \ - else { \ - std::ostringstream err_msg; \ - err_msg << "Unsupported max_mma_kv: " << max_mma_kv; \ - FLASHINFER_ERROR(err_msg.str()); \ - } +#define DISPATCH_NUM_MMA_KV(max_mma_kv, NUM_MMA_KV, ...) \ + if (max_mma_kv >= 8) { \ + constexpr size_t NUM_MMA_KV = 8; \ + __VA_ARGS__ \ + } else if (max_mma_kv >= 4) { \ + constexpr size_t NUM_MMA_KV = 4; \ + __VA_ARGS__ \ + } else if (max_mma_kv >= 2) { \ + constexpr size_t NUM_MMA_KV = 2; \ + __VA_ARGS__ \ + } else if (max_mma_kv >= 1) { \ + constexpr size_t NUM_MMA_KV = 1; \ + __VA_ARGS__ \ + } else { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported max_mma_kv: " << max_mma_kv; \ + FLASHINFER_ERROR(err_msg.str()); \ + } -#define DISPATCH_CTA_TILE_Q(cta_tile_q, CTA_TILE_Q, ...) 
\ - switch (cta_tile_q) { \ - case 128: \ - { \ - constexpr uint32_t CTA_TILE_Q = 128; \ - __VA_ARGS__ \ - break; \ - } \ - case 64: \ - { \ - constexpr uint32_t CTA_TILE_Q = 64; \ - __VA_ARGS__ \ - break; \ - } \ - case 16: \ - { \ - constexpr uint32_t CTA_TILE_Q = 16; \ - __VA_ARGS__ \ - break; \ - } \ - default: \ - { \ - std::ostringstream err_msg; \ - err_msg << "Unsupported cta_tile_q: " << cta_tile_q; \ - FLASHINFER_ERROR(err_msg.str()); \ - } \ - } +#define DISPATCH_CTA_TILE_Q(cta_tile_q, CTA_TILE_Q, ...) \ + switch (cta_tile_q) { \ + case 128: { \ + constexpr uint32_t CTA_TILE_Q = 128; \ + __VA_ARGS__ \ + break; \ + } \ + case 64: { \ + constexpr uint32_t CTA_TILE_Q = 64; \ + __VA_ARGS__ \ + break; \ + } \ + case 16: { \ + constexpr uint32_t CTA_TILE_Q = 16; \ + __VA_ARGS__ \ + break; \ + } \ + default: { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported cta_tile_q: " << cta_tile_q; \ + FLASHINFER_ERROR(err_msg.str()); \ + } \ + } -#define DISPATCH_GQA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \ - if (group_size == 1) { \ - constexpr size_t GROUP_SIZE = 1; \ - __VA_ARGS__ \ - } \ - else if (group_size == 2) { \ - constexpr size_t GROUP_SIZE = 2; \ - __VA_ARGS__ \ - } \ - else if (group_size == 3) { \ - constexpr size_t GROUP_SIZE = 3; \ - __VA_ARGS__ \ - } \ - else if (group_size == 4) { \ - constexpr size_t GROUP_SIZE = 4; \ - __VA_ARGS__ \ - } \ - else if (group_size == 8) { \ - constexpr size_t GROUP_SIZE = 8; \ - __VA_ARGS__ \ - } \ - else { \ - std::ostringstream err_msg; \ - err_msg << "Unsupported group_size: " << group_size; \ - FLASHINFER_ERROR(err_msg.str()); \ - } +#define DISPATCH_GQA_GROUP_SIZE(group_size, GROUP_SIZE, ...) 
\ + if (group_size == 1) { \ + constexpr size_t GROUP_SIZE = 1; \ + __VA_ARGS__ \ + } else if (group_size == 2) { \ + constexpr size_t GROUP_SIZE = 2; \ + __VA_ARGS__ \ + } else if (group_size == 3) { \ + constexpr size_t GROUP_SIZE = 3; \ + __VA_ARGS__ \ + } else if (group_size == 4) { \ + constexpr size_t GROUP_SIZE = 4; \ + __VA_ARGS__ \ + } else if (group_size == 8) { \ + constexpr size_t GROUP_SIZE = 8; \ + __VA_ARGS__ \ + } else { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported group_size: " << group_size; \ + FLASHINFER_ERROR(err_msg.str()); \ + } -#define DISPATCH_MASK_MODE(mask_mode, MASK_MODE, ...) \ - switch (mask_mode) { \ - case MaskMode::kNone: \ - { \ - constexpr MaskMode MASK_MODE = MaskMode::kNone; \ - __VA_ARGS__ \ - break; \ - } \ - case MaskMode::kCausal: \ - { \ - constexpr MaskMode MASK_MODE = MaskMode::kCausal; \ - __VA_ARGS__ \ - break; \ - } \ - case MaskMode::kCustom: \ - { \ - constexpr MaskMode MASK_MODE = MaskMode::kCustom; \ - __VA_ARGS__ \ - break; \ - } \ - default: \ - { \ - std::ostringstream err_msg; \ - err_msg << "Unsupported mask_mode: " << int(mask_mode); \ - FLASHINFER_ERROR(err_msg.str()); \ - } \ - } +#define DISPATCH_MASK_MODE(mask_mode, MASK_MODE, ...) \ + switch (mask_mode) { \ + case MaskMode::kNone: { \ + constexpr MaskMode MASK_MODE = MaskMode::kNone; \ + __VA_ARGS__ \ + break; \ + } \ + case MaskMode::kCausal: { \ + constexpr MaskMode MASK_MODE = MaskMode::kCausal; \ + __VA_ARGS__ \ + break; \ + } \ + case MaskMode::kCustom: { \ + constexpr MaskMode MASK_MODE = MaskMode::kCustom; \ + __VA_ARGS__ \ + break; \ + } \ + default: { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported mask_mode: " << int(mask_mode); \ + FLASHINFER_ERROR(err_msg.str()); \ + } \ + } // convert head_dim to compile-time constant -#define DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...) 
\ - switch (head_dim) { \ - case 64: \ - { \ - constexpr size_t HEAD_DIM = 64; \ - __VA_ARGS__ \ - break; \ - } \ - case 128: \ - { \ - constexpr size_t HEAD_DIM = 128; \ - __VA_ARGS__ \ - break; \ - } \ - case 256: \ - { \ - constexpr size_t HEAD_DIM = 256; \ - __VA_ARGS__ \ - break; \ - } \ - case 512: \ - { \ - constexpr size_t HEAD_DIM = 512; \ - __VA_ARGS__ \ - break; \ - } \ - default: \ - { \ - std::ostringstream err_msg; \ - err_msg << "Unsupported head_dim: " << head_dim; \ - FLASHINFER_ERROR(err_msg.str()); \ - } \ - } +#define DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...) \ + switch (head_dim) { \ + case 64: { \ + constexpr size_t HEAD_DIM = 64; \ + __VA_ARGS__ \ + break; \ + } \ + case 128: { \ + constexpr size_t HEAD_DIM = 128; \ + __VA_ARGS__ \ + break; \ + } \ + case 256: { \ + constexpr size_t HEAD_DIM = 256; \ + __VA_ARGS__ \ + break; \ + } \ + case 512: { \ + constexpr size_t HEAD_DIM = 512; \ + __VA_ARGS__ \ + break; \ + } \ + default: { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported head_dim: " << head_dim; \ + FLASHINFER_ERROR(err_msg.str()); \ + } \ + } -#define DISPATCH_POS_ENCODING_MODE(pos_encoding_mode, POS_ENCODING_MODE, ...) \ - switch (pos_encoding_mode) { \ - case PosEncodingMode::kNone: \ - { \ - constexpr PosEncodingMode POS_ENCODING_MODE = PosEncodingMode::kNone; \ - __VA_ARGS__ \ - break; \ - } \ - case PosEncodingMode::kRoPELlama: \ - { \ - constexpr PosEncodingMode POS_ENCODING_MODE = \ - PosEncodingMode::kRoPELlama; \ - __VA_ARGS__ \ - break; \ - } \ - case PosEncodingMode::kALiBi: \ - { \ - constexpr PosEncodingMode POS_ENCODING_MODE = PosEncodingMode::kALiBi; \ - __VA_ARGS__ \ - break; \ - } \ - default: \ - { \ - std::ostringstream err_msg; \ - err_msg << "Unsupported pos_encoding_mode: " \ - << int(pos_encoding_mode); \ - FLASHINFER_ERROR(err_msg.str()); \ - } \ - } +#define DISPATCH_POS_ENCODING_MODE(pos_encoding_mode, POS_ENCODING_MODE, ...) 
\ + switch (pos_encoding_mode) { \ + case PosEncodingMode::kNone: { \ + constexpr PosEncodingMode POS_ENCODING_MODE = PosEncodingMode::kNone; \ + __VA_ARGS__ \ + break; \ + } \ + case PosEncodingMode::kRoPELlama: { \ + constexpr PosEncodingMode POS_ENCODING_MODE = PosEncodingMode::kRoPELlama; \ + __VA_ARGS__ \ + break; \ + } \ + case PosEncodingMode::kALiBi: { \ + constexpr PosEncodingMode POS_ENCODING_MODE = PosEncodingMode::kALiBi; \ + __VA_ARGS__ \ + break; \ + } \ + default: { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported pos_encoding_mode: " << int(pos_encoding_mode); \ + FLASHINFER_ERROR(err_msg.str()); \ + } \ + } -#define DISPATCH_ALIGNED_VEC_SIZE(aligned_vec_size, ALIGNED_VEC_SIZE, ...) \ - switch (aligned_vec_size) { \ - case 16: \ - { \ - constexpr size_t ALIGNED_VEC_SIZE = 16; \ - __VA_ARGS__ \ - break; \ - } \ - case 8: \ - { \ - constexpr size_t ALIGNED_VEC_SIZE = 8; \ - __VA_ARGS__ \ - break; \ - } \ - case 4: \ - { \ - constexpr size_t ALIGNED_VEC_SIZE = 4; \ - __VA_ARGS__ \ - break; \ - } \ - case 2: \ - { \ - constexpr size_t ALIGNED_VEC_SIZE = 2; \ - __VA_ARGS__ \ - break; \ - } \ - case 1: \ - { \ - constexpr size_t ALIGNED_VEC_SIZE = 1; \ - __VA_ARGS__ \ - break; \ - } \ - default: \ - { \ - std::ostringstream err_msg; \ - err_msg << "Unsupported aligned_vec_size: " << aligned_vec_size; \ - FLASHINFER_ERROR(err_msg.str()); \ - } \ - } +#define DISPATCH_ALIGNED_VEC_SIZE(aligned_vec_size, ALIGNED_VEC_SIZE, ...) 
\ + switch (aligned_vec_size) { \ + case 16: { \ + constexpr size_t ALIGNED_VEC_SIZE = 16; \ + __VA_ARGS__ \ + break; \ + } \ + case 8: { \ + constexpr size_t ALIGNED_VEC_SIZE = 8; \ + __VA_ARGS__ \ + break; \ + } \ + case 4: { \ + constexpr size_t ALIGNED_VEC_SIZE = 4; \ + __VA_ARGS__ \ + break; \ + } \ + case 2: { \ + constexpr size_t ALIGNED_VEC_SIZE = 2; \ + __VA_ARGS__ \ + break; \ + } \ + case 1: { \ + constexpr size_t ALIGNED_VEC_SIZE = 1; \ + __VA_ARGS__ \ + break; \ + } \ + default: { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported aligned_vec_size: " << aligned_vec_size; \ + FLASHINFER_ERROR(err_msg.str()); \ + } \ + } -#define DISPATCH_COMPUTE_CAP_DECODE_NUM_STAGES_SMEM(compute_capacity, \ - NUM_STAGES_SMEM, ...) \ - if (compute_capacity.first >= 8) { \ - constexpr uint32_t NUM_STAGES_SMEM = 2; \ - __VA_ARGS__ \ - } \ - else { \ - constexpr uint32_t NUM_STAGES_SMEM = 1; \ - __VA_ARGS__ \ - } +#define DISPATCH_COMPUTE_CAP_DECODE_NUM_STAGES_SMEM(compute_capacity, NUM_STAGES_SMEM, ...) 
\ + if (compute_capacity.first >= 8) { \ + constexpr uint32_t NUM_STAGES_SMEM = 2; \ + __VA_ARGS__ \ + } else { \ + constexpr uint32_t NUM_STAGES_SMEM = 1; \ + __VA_ARGS__ \ + } diff --git a/scripts/run_hip_tests.sh b/scripts/run_hip_tests.sh index 73a17db3e1..7061277da8 100755 --- a/scripts/run_hip_tests.sh +++ b/scripts/run_hip_tests.sh @@ -10,4 +10,3 @@ python -m pytest ../tests/test_sliding_window_hip.py \ ../tests/test_norm_hip.py \ ../tests/test_logits_cap_hip.py \ ../tests/test_non_contiguous_decode_hip.py \ - diff --git a/tests/test_batch_decode_kernels_hip.py b/tests/test_batch_decode_kernels_hip.py index d70b3b47ab..e0db60a548 100644 --- a/tests/test_batch_decode_kernels_hip.py +++ b/tests/test_batch_decode_kernels_hip.py @@ -16,12 +16,11 @@ import pytest import torch -from jit_utils import ( - jit_decode_attention_func_args -) +from jit_utils import jit_decode_attention_func_args import flashinfer + @pytest.fixture(autouse=True, scope="module") def warmup_jit(): if flashinfer.jit.has_prebuilt_ops: @@ -318,6 +317,7 @@ def test_batch_decode_with_tuple_paged_kv_cache( ) torch.testing.assert_close(o[i], o_ref_i, rtol=1e-3, atol=1e-3) + if __name__ == "__main__": test_batch_decode_with_paged_kv_cache( 256, @@ -395,4 +395,3 @@ def test_batch_decode_with_tuple_paged_kv_cache( torch.float16, True, ) - diff --git a/tests/test_logits_cap_hip.py b/tests/test_logits_cap_hip.py index c3a1f475c9..7c43284b29 100644 --- a/tests/test_logits_cap_hip.py +++ b/tests/test_logits_cap_hip.py @@ -18,9 +18,7 @@ import pytest import torch -from jit_utils import ( - jit_decode_attention_func_args -) +from jit_utils import jit_decode_attention_func_args import flashinfer @@ -75,5 +73,6 @@ def test_single_decode_logits_soft_cap( o_ref = attention_logits_soft_cap_torch(q.unsqueeze(0), k, v, soft_cap).squeeze(0) torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3) + if __name__ == "__main__": test_single_decode_logits_soft_cap(9, 32, 128, 30.0) diff --git 
a/tests/test_non_contiguous_decode_hip.py b/tests/test_non_contiguous_decode_hip.py index a526eb48d5..1836c054eb 100644 --- a/tests/test_non_contiguous_decode_hip.py +++ b/tests/test_non_contiguous_decode_hip.py @@ -1,8 +1,6 @@ import pytest import torch -from jit_utils import ( - jit_decode_attention_func_args -) +from jit_utils import jit_decode_attention_func_args import flashinfer diff --git a/tests/test_norm_hip.py b/tests/test_norm_hip.py index 3c5fbf891d..41f4faca8e 100644 --- a/tests/test_norm_hip.py +++ b/tests/test_norm_hip.py @@ -103,7 +103,6 @@ def test_fused_add_rmsnorm(batch_size, hidden_size, dtype, enable_pdl, contiguou x = torch.randn(batch_size, hidden_size * 2, device="cuda").to(dtype) x = x[:, :hidden_size] - residual = torch.randn_like(x) weight = torch.randn(hidden_size, dtype=dtype, device="cuda") diff --git a/tests/test_rope.py b/tests/test_rope.py index dbca4670b4..4e0c40b1cc 100644 --- a/tests/test_rope.py +++ b/tests/test_rope.py @@ -301,7 +301,6 @@ def forward_cuda( (256, 128, 4096, 9231, False, torch.bfloat16, "cuda", 3, 231, 4, 2), ], ) - def test_rope_cos_sin_cache( head_size: int, rotary_dim: int, diff --git a/tests/test_sliding_window_hip.py b/tests/test_sliding_window_hip.py index 0699fe98b7..39fa15d5cd 100644 --- a/tests/test_sliding_window_hip.py +++ b/tests/test_sliding_window_hip.py @@ -16,9 +16,7 @@ import pytest import torch -from jit_utils import ( - jit_decode_attention_func_args -) +from jit_utils import jit_decode_attention_func_args import flashinfer From 68d1bc2ffae2e0c9777ca5fbd37a2c5c724f5f2d Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Thu, 2 Oct 2025 13:01:32 -0400 Subject: [PATCH 094/109] WIP --- .../flashinfer/attention/generic/permuted_smem.cuh | 2 +- libflashinfer/include/flashinfer/attention/prefill.cuh | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh 
b/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh index 536888f434..e6a2013000 100644 --- a/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh @@ -141,7 +141,7 @@ struct smem_t { __device__ __forceinline__ void load_fragment_4x4_transposed(uint32_t offset, T* frag) { #if defined(PLATFORM_HIP_DEVICE) auto smem_t_ptr = reinterpret_cast(base + offset); - flashinfer::gpu_iface::mma::load_fragment_transpose_4x4_half_registers(frag, smem_t_ptr); + flashinfer::gpu_iface::mma::load_quad_transposed_fragment(frag, smem_t_ptr); #else static_assert(false, "Not supported on current platform"); #endif diff --git a/libflashinfer/include/flashinfer/attention/prefill.cuh b/libflashinfer/include/flashinfer/attention/prefill.cuh index 19200b3792..05d0083a8c 100644 --- a/libflashinfer/include/flashinfer/attention/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/prefill.cuh @@ -1397,8 +1397,17 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice( get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + 8 * (lane_idx / 16) + lane_idx % 8, (lane_idx % 16) / 8), +#if defined(PLATFROM_CUDA_DEVICE) v_smem_offset_r = v_smem.template get_permuted_offset( get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + lane_idx % 16, lane_idx / 16), +#elif defined(PLATFORM_HIP_DEVICE) + v_smem_offset_r = v_smem.template get_permuted_offset(), + + // ((threadIdx.x % 4) + 4 * (threadIdx.x / 16)) * LDB + ((threadIdx.x % 16) / 4) * 4; + v_smem_offset_r = ((lane_idx % 4) + 4 * get_warp_idx_q(tid.y)) * UPCAST_STRIDE_V + + get_warp_idx_kv(tid.z) * 16 * UPCAST_STRIDE_V + + ((lane_idx % 16) / 4) * 4, +#endif k_smem_offset_w = k_smem.template get_permuted_offset( warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL, lane_idx % KV_THR_LAYOUT_COL), From 23863416c39f0504b30725eea0a3b208d36bcab5 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Thu, 2 Oct 2025 16:36:20 -0400 Subject: [PATCH 095/109] 
Revert changes to wrong file --- libflashinfer/include/flashinfer/attention/prefill.cuh | 9 --------- 1 file changed, 9 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/prefill.cuh b/libflashinfer/include/flashinfer/attention/prefill.cuh index 05d0083a8c..19200b3792 100644 --- a/libflashinfer/include/flashinfer/attention/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/prefill.cuh @@ -1397,17 +1397,8 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice( get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + 8 * (lane_idx / 16) + lane_idx % 8, (lane_idx % 16) / 8), -#if defined(PLATFROM_CUDA_DEVICE) v_smem_offset_r = v_smem.template get_permuted_offset( get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + lane_idx % 16, lane_idx / 16), -#elif defined(PLATFORM_HIP_DEVICE) - v_smem_offset_r = v_smem.template get_permuted_offset(), - - // ((threadIdx.x % 4) + 4 * (threadIdx.x / 16)) * LDB + ((threadIdx.x % 16) / 4) * 4; - v_smem_offset_r = ((lane_idx % 4) + 4 * get_warp_idx_q(tid.y)) * UPCAST_STRIDE_V + - get_warp_idx_kv(tid.z) * 16 * UPCAST_STRIDE_V + - ((lane_idx % 16) / 4) * 4, -#endif k_smem_offset_w = k_smem.template get_permuted_offset( warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL, lane_idx % KV_THR_LAYOUT_COL), From bd598c1821b83540aa9ecd6f77e3571df648f7f8 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Thu, 2 Oct 2025 16:41:19 -0400 Subject: [PATCH 096/109] Update from amd-integration --- .../attention/generic/permuted_smem.cuh | 36 ++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh b/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh index e6a2013000..25aab13c91 100644 --- a/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh @@ -36,6 +36,7 @@ using b64_t = uint2; /*! 
* \brief Compute the number of elements that can be stored in a b128_t. * \tparam T The data type of the elements. + * \tparam VectorWidthBits The width in bits for vector operations (64 or 128). */ template constexpr __host__ __device__ __forceinline__ uint32_t upcast_size() { @@ -137,8 +138,41 @@ struct smem_t { #endif } + /*! + * \brief Loads a fragment from shared memory and performs an in-register transpose across a quad. + * \details This function is designed to prepare the B-matrix operand for a CDNA3 MFMA + * instruction. + * It performs two actions in sequence for a quad of 4 threads: + * 1. Each thread loads a row-oriented fragment (e.g., 4 `half` values) from shared + * memory. + * 2. It then calls `transpose_intra_quad_fragments` to perform an in-register transpose + * of this data among the 4 threads. + * + * The result is that each thread's registers are populated with a column-oriented + * fragment, which is the required layout for the B-operand in a + * row-major(A) x col-major(B) MFMA. + * + * Visual Representation: + * If `[a,b,c,d]` are the 4 `half` values loaded by Thread 0: + * + * Data in Shared Memory (conceptually): + * Row 0: [a, b, c, d] + * Row 1: [e, f, g, h] + * Row 2: [i, j, k, l] + * Row 3: [m, n, o, p] + * + * After this function, registers hold: + * Thread 0: [a, e, i, m] (Column 0) + * Thread 1: [b, f, j, n] (Column 1) + * Thread 2: [c, g, k, o] (Column 2) + * Thread 3: [d, h, l, p] (Column 3) + * + * \tparam T The type of the register fragment (e.g., uint32_t). + * \param offset The starting offset in shared memory for the quad to begin loading. + * \param frag A pointer to the thread's local registers to store the resulting column fragment. 
+ */ template - __device__ __forceinline__ void load_fragment_4x4_transposed(uint32_t offset, T* frag) { + __device__ __forceinline__ void load_fragment_and_quad_transpose(uint32_t offset, T* frag) { #if defined(PLATFORM_HIP_DEVICE) auto smem_t_ptr = reinterpret_cast(base + offset); flashinfer::gpu_iface::mma::load_quad_transposed_fragment(frag, smem_t_ptr); From f0e8e7204e2868c9f7740a80dfcf111aa46873fb Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Thu, 2 Oct 2025 16:43:06 -0400 Subject: [PATCH 097/109] Update from amd-integration --- .../include/flashinfer/attention/generic/permuted_smem.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh b/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh index 25aab13c91..2fd12b924f 100644 --- a/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/permuted_smem.cuh @@ -177,7 +177,7 @@ struct smem_t { auto smem_t_ptr = reinterpret_cast(base + offset); flashinfer::gpu_iface::mma::load_quad_transposed_fragment(frag, smem_t_ptr); #else - static_assert(false, "Not supported on current platform"); + static_assert(sizeof(T) == 0, "Not supported on current platform"); #endif } From 0b9db157be7659ffb6529d068f738d3119e35266 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Mon, 6 Oct 2025 12:58:58 -0400 Subject: [PATCH 098/109] Compilation fixes --- libflashinfer/include/flashinfer/attention/generic/prefill.cuh | 2 +- libflashinfer/utils/utils_hip.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index 4dabe61fb8..9c48a450bf 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -1267,7 +1267,7 @@ __device__ 
__forceinline__ void compute_sfm_v( #endif } else { #if defined(PLATFORM_HIP_DEVICE) - v_smem->load_fragment_4x4_transposed(*v_smem_offset_r, b_frag); + v_smem->load_fragment_and_quad_transpose(*v_smem_offset_r, b_frag); #else v_smem->ldmatrix_m8n8x4_trans(*v_smem_offset_r, b_frag); #endif diff --git a/libflashinfer/utils/utils_hip.h b/libflashinfer/utils/utils_hip.h index c858ef82ae..de6b6956c8 100644 --- a/libflashinfer/utils/utils_hip.h +++ b/libflashinfer/utils/utils_hip.h @@ -60,7 +60,7 @@ enum Predicate { template void generate_data(std::vector& vec) { if constexpr (Pred == Predicate::Linear) { - assert(vec.size() <= 0); + assert(vec.size() > 0); for (int i = 0; i < vec.size(); i++) { vec[i] = fi::con::explicit_casting(static_cast(i)); } From 87e6f5559e98adfd664de18e9f492be815700113 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Wed, 8 Oct 2025 23:53:36 -0400 Subject: [PATCH 099/109] Improved debugging --- .../flashinfer/attention/generic/prefill.cuh | 111 ++++++------------ 1 file changed, 39 insertions(+), 72 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index 9c48a450bf..c8c6123b02 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -26,6 +26,10 @@ #include "pos_enc.cuh" #include "variants.cuh" +#if Debug +#include "gpu_iface/backend/hip/mma_debug_utils_hip.h" +#endif + namespace flashinfer { DEFINE_HAS_MEMBER(maybe_q_rope_offset) @@ -67,9 +71,6 @@ struct SharedStorageQKVO { struct { alignas(16) DTypeQ q_smem[CTA_TILE_Q * HEAD_DIM_QK]; alignas(16) DTypeKV k_smem[CTA_TILE_KV * HEAD_DIM_QK]; -#if Debug - alignas(16) DTypeKV qk_scratch[CTA_TILE_Q * CTA_TILE_KV]; -#endif alignas(16) DTypeKV v_smem[CTA_TILE_KV * HEAD_DIM_VO]; }; struct { // NOTE(Zihao): synchronize attention states across warps @@ -1067,13 +1068,6 @@ __device__ __forceinline__ void 
update_mdo_states( float o_scale = gpu_iface::math::ptx_exp2(m_prev * sm_scale - m[mma_q][j] * sm_scale); d[mma_q][j] *= o_scale; -#if Debug - if (warp_idx == 0 && lane_idx == 0) { - printf("Max value %f, m_prev %f, o_scale %f, d %f\n", m[mma_q][j], m_prev, o_scale, - float(d[mma_q][j])); - printf("-------------\n"); - } -#endif #pragma unroll for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { o_frag[mma_q][mma_d][j] *= o_scale; @@ -1116,13 +1110,6 @@ __device__ __forceinline__ void update_mdo_states( } #endif } -#if Debug1 - if (warp_idx == 0 && lane_idx == 0) { - printf("d[0] %f d[1] %f d[2] %f d[3]%f\n", float(d[mma_q][0]), float(d[mma_q][0]), - float(d[mma_q][0]), float(d[mma_q][0])); - printf("-------------\n"); - } -#endif } } else if constexpr (std::is_same_v) { #if defined(PLATFORM_HIP_DEVICE) @@ -1234,13 +1221,6 @@ __device__ __forceinline__ void compute_sfm_v( #endif } } -#if Debug - if (debug_warp_idx == 0 && debug_lane_idx == 0) { - printf("After rowsum: d[0] %f d[1] %f d[2] %f d[3] %f\n", float(d[mma_q][0]), - float(d[mma_q][0]), float(d[mma_q][0]), float(d[mma_q][0])); - printf("-------------\n"); - } -#endif } } @@ -1749,6 +1729,12 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice( } init_states(variant, o_frag, m, d); +#if Debug + // Statically allocate a shared memory array specifically for debugging s_frag. + // This avoids modifying the main SharedStorage union. 
+ __shared__ DTypeQKAccum qk_scratch[CTA_TILE_Q * CTA_TILE_KV]; +#endif + // cooperative fetch q fragment from gmem to reg const uint32_t qo_packed_idx_base = (bx * NUM_WARPS_Q + get_warp_idx_q(tid.y)) * NUM_MMA_Q * 16; @@ -1811,16 +1797,18 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice( #if defined(PLATFORM_HIP_DEVICE) uint32_t k_smem_offset_r = k_smem.template get_permuted_offset( get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + lane_idx % 16, (lane_idx / 16)); + uint32_t v_smem_offset_r = v_smem.template get_permuted_offset( + get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + (lane_idx % 4) + 4 * (lane_idx / 16), + lane_idx / 4); #elif defined(PLATFORM_CUDA_DEVICE) uint32_t k_smem_offset_r = k_smem.template get_permuted_offset( get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + 8 * (lane_idx / 16) + lane_idx % 8, (lane_idx % 16) / 8); + uint32_t v_smem_offset_r = v_smem.template get_permuted_offset( + get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + (lane_idx % 4) + 4 * (lane_idx / 16), + lane_idx / 4); #endif - uint32_t v_smem_offset_r = v_smem.template get_permuted_offset( - get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + (lane_idx % 4) + - 4 * (lane_idx / 16), - lane_idx / 4), - k_smem_offset_w = k_smem.template get_permuted_offset( + uint32_t k_smem_offset_w = k_smem.template get_permuted_offset( warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL, lane_idx % KV_THR_LAYOUT_COL), v_smem_offset_w = v_smem.template get_permuted_offset( @@ -1832,11 +1820,7 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice( produce_kv(v_smem, &v_smem_offset_w, &v_ptr, v_stride_n, 0, chunk_size, tid); memory::commit_group(); - #if Debug - - smem_t scratch(smem_storage.qk_scratch); - // if (warp_idx == 0 && lane_idx == 0) { // printf("partition_kv : %d\n", partition_kv); // printf("kv_len : %d\n", kv_len); @@ -1871,16 +1855,16 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice( // Prints the (NUM_MMA_KV*16) x (NUM_MMA_D*16) matrix from global mem. 
if (warp_idx == 0 && lane_idx == 0) { - // printf("\n DEBUG K Global (HIP):\n"); - // printf("k_stride_n : %d\n", k_stride_n); - // printf("k_stride_h : %d\n", k_stride_h); - // printf("kv_head_idx : %d\n", kv_head_idx); - // printf("num_qo_heads : %d\n", num_qo_heads); - // printf("num_kv_heads : %d\n", num_kv_heads); - // printf("k_stride_n : %d\n", k_stride_n); - // printf("KTraits::NUM_MMA_D_QK : %d\n", KTraits::NUM_MMA_D_QK); - // printf("NUM_MMA_KV : %d\n", NUM_MMA_KV); - // printf("NUM_MMA_Q : %d\n", NUM_MMA_Q); + printf("\n DEBUG K Global (HIP):\n"); + printf("k_stride_n : %d\n", k_stride_n); + printf("k_stride_h : %d\n", k_stride_h); + printf("kv_head_idx : %d\n", kv_head_idx); + printf("num_qo_heads : %d\n", num_qo_heads); + printf("num_kv_heads : %d\n", num_kv_heads); + printf("k_stride_n : %d\n", k_stride_n); + printf("KTraits::NUM_MMA_D_QK : %d\n", KTraits::NUM_MMA_D_QK); + printf("NUM_MMA_KV : %d\n", NUM_MMA_KV); + printf("NUM_MMA_Q : %d\n", NUM_MMA_Q); #if 0 DTypeKV *k_ptr_tmp = k + (chunk_start + warp_idx * KV_THR_LAYOUT_ROW + @@ -1967,28 +1951,24 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice( // compute attention score compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); -#if Debug1 - if (params.debug_thread_id == lane_idx && params.debug_warp_id == warp_idx) { - printf("After compute_qk\n"); - } - debug_write_sfrag_to_scratch(s_frag, tid, params.debug_thread_id, - params.debug_warp_id); +#if Debug + flashinfer::gpu_iface::debug_utils::hip::write_s_frag_to_lds< + DTypeQKAccum, NUM_MMA_Q, NUM_MMA_KV, HALF_ELEMS_PER_THREAD>(s_frag, qk_scratch, + CTA_TILE_KV, tid); + + // a) Print thread 0's registers to see the source data. + flashinfer::gpu_iface::debug_utils::hip::print_s_frag_register< + DTypeQKAccum, NUM_MMA_Q, NUM_MMA_KV, HALF_ELEMS_PER_THREAD>(s_frag, tid); + + // b) Print the materialized LDS array to see the final result for this iteration. 
+ flashinfer::gpu_iface::debug_utils::hip::print_lds_array(qk_scratch, CTA_TILE_Q, CTA_TILE_KV); + #endif logits_transform( params, variant, /*batch_idx=*/0, qo_packed_idx_base, chunk_start + (iter * NUM_WARPS_KV + get_warp_idx_kv(tid.z)) * NUM_MMA_KV * 16, qo_len, kv_len, group_size, s_frag, tid, kv_head_idx); -#if Debug1 - if (params.debug_thread_id == lane_idx && params.debug_warp_id == warp_idx) { - printf("params.sm_scale %f, params.logits_soft_cap %f\n", params.sm_scale, - params.logits_soft_cap); - printf("After logits_transform\n"); - } - debug_write_sfrag_to_scratch(s_frag, tid, params.debug_thread_id, - params.debug_warp_id); -#endif - // apply mask if (MASK_MODE == MaskMode::kCustom || (iter >= mask_iteration || iter < window_iteration)) { logits_mask( @@ -1997,21 +1977,8 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice( qo_len, kv_len, chunk_end, group_size, s_frag, tid, kv_head_idx); } -#if Debug1 - // if(params.debug_thread_id == lane_idx && params.debug_warp_id == warp_idx) { - // printf("Before update_mdo_states\n"); - // } - debug_write_sfrag_to_scratch(s_frag, tid, params.debug_thread_id, - params.debug_warp_id); -#endif - // compute m,d states in online softmax update_mdo_states(variant, s_frag, o_frag, m, d, warp_idx, lane_idx); - -#if Debug1 - debug_write_sfrag_to_scratch(s_frag, tid, params.debug_thread_id, - params.debug_warp_id); -#endif block.sync(); produce_kv( k_smem, &k_smem_offset_w, &k_ptr, k_stride_n, (iter + 1) * CTA_TILE_KV, chunk_size, tid); From 28a0355b1db79a14b127579b59bd2c878bd7a9bc Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Wed, 8 Oct 2025 23:54:18 -0400 Subject: [PATCH 100/109] Improved s_frag debug printer --- .../backend/hip/mma_debug_utils_hip.h | 176 +++++++++++++----- 1 file changed, 129 insertions(+), 47 deletions(-) diff --git a/libflashinfer/include/gpu_iface/backend/hip/mma_debug_utils_hip.h b/libflashinfer/include/gpu_iface/backend/hip/mma_debug_utils_hip.h index cf1aa52629..f8bc7dd1d2 100644 
--- a/libflashinfer/include/gpu_iface/backend/hip/mma_debug_utils_hip.h +++ b/libflashinfer/include/gpu_iface/backend/hip/mma_debug_utils_hip.h @@ -7,6 +7,11 @@ #include "gpu_iface/backend/hip/mma_hip.h" #include "gpu_iface/gpu_runtime_compat.hpp" +namespace { +constexpr uint32_t MMA_COLS = 16; +constexpr uint32_t MMA_ROWS_PER_THREAD = 4; +} // namespace + namespace flashinfer::gpu_iface::debug_utils::hip { enum class MatrixLayout { A, B }; @@ -37,8 +42,8 @@ template __device__ void load_amatrix_layout(T* lds_array, uint32_t* R, uint32_t dimX) { static_assert(std::is_same_v, "Only supported for __half types"); const int lane_id = threadIdx.x % 64; - const int row = lane_id % 16; - const int col_start = (lane_id / 16) * 4; + const int row = lane_id % MMA_COLS; + const int col_start = (lane_id / MMA_COLS) * MMA_ROWS_PER_THREAD; auto offset = lds_array + row * dimX + col_start; mma_impl::hip::load_fragment(R, offset); @@ -56,7 +61,9 @@ template __device__ void load_bmatrix_layout(T* arr, uint32_t* R, uint32_t dimY) { static_assert(std::is_same_v, "Only supported for __half types"); const int lane_id = threadIdx.x % 64; - int b_idx = ((lane_id % 4) + 4 * (lane_id / 16)) * dimY + ((lane_id % 16) / 4) * 4; + int b_idx = + ((lane_id % MMA_ROWS_PER_THREAD) + MMA_ROWS_PER_THREAD * (lane_id / MMA_COLS)) * dimY + + ((lane_id % MMA_COLS) / MMA_ROWS_PER_THREAD) * MMA_ROWS_PER_THREAD; mma_impl::hip::load_quad_transposed_fragment<__half>(R, &arr[b_idx]); } @@ -71,16 +78,35 @@ __device__ void print_register(uint32_t* R) { __half2float(values[2]), __half2float(values[3])); } +/// @brief Prints the full s_frag array from a single thread's registers. 
+template +__device__ void print_s_frag_register(const T (*s_frag)[NUM_MMA_KV][ELEMS_PER_FRAGMENT], + const dim3 tid = threadIdx) { + if (tid.x == 0 && tid.y == 0 && tid.z == 0) { + printf("Thread (0,0,0) s_frag registers:\n"); + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { + for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) { + const T* values = s_frag[mma_q][mma_kv]; + printf(" frag[%u][%u]: [%8.3f, %8.3f, %8.3f, %8.3f]\n", mma_q, mma_kv, float(values[0]), + float(values[1]), float(values[2]), float(values[3])); + } + } + printf("\n"); + } + __syncthreads(); +} + /// @brief Prints a 2D LDS array to the console from a single thread. /// @tparam T The data type of the LDS array, must be `__half`. /// @param lds_array Pointer to the shared memory array. /// @param dimY The height of the 2D array. /// @param dimX The width of the 2D array. template -__device__ void print_lds_array(T* lds_array, uint32_t dimY, uint32_t dimX) { +__device__ void print_lds_array(T* lds_array, uint32_t dimY, uint32_t dimX, + const char* title = "LDS Array") { static_assert(std::is_same_v, "Only supported for __half types"); if (threadIdx.x == 0) { - printf("LDS Array (%dx%d):\n", dimX, dimY); + printf("%s (%dx%d):\n", title, dimX, dimY); for (int y = 0; y < dimY; ++y) { for (int x = 0; x < dimX; ++x) { printf("%5.1f ", __half2float(lds_array[y * dimX + x])); @@ -92,51 +118,107 @@ __device__ void print_lds_array(T* lds_array, uint32_t dimY, uint32_t dimX) { __syncthreads(); } -/// @brief Writes the 4 `half` values from each thread's registers back to LDS. -/// @details This function is the inverse of the `load_*_layout` functions. It materializes -/// the in-register matrix layout into shared memory. -/// -/// A-Layout Pattern: -/// Each thread `T_(16*c + r)` writes its 4 values to `LDS[r, 4*c : 4*c+3]`. -/// This reconstructs the standard row-major matrix. 
-/// -/// B-Layout Pattern: -/// Each thread `T_(16*br + 4*bc + ti)` (where br=block_row, bc=block_col, -/// ti=thread_in_block) writes its 4 values to `LDS[4*br + ti, 4*bc : 4*bc+3]`. This -/// creates a block-transposed matrix in shared memory. -/// @tparam T The data type of the LDS array, must be `__half`. -/// @param R Pointer to the thread's registers (uint32_t[2]). -/// @param lds_array Pointer to the shared memory array. -/// @param dimY The height of the LDS array. -/// @param dimX The width of the LDS array. -/// @param layout The target memory layout (A or B) to use for writing. -template -__device__ void write_matrix_frag_to_lds(const uint32_t* R, T* lds_array, uint32_t dimY, - uint32_t dimX, MatrixLayout layout) { - static_assert(std::is_same_v, "Only supported for __half types"); +/// @brief Prints a 2D LDS array of floats to the console from a single thread. +__device__ void print_lds_array(float* lds_array, uint32_t dimY, uint32_t dimX, + const char* title = "LDS Array (float)") { + if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) { + printf("%s (%dx%d):\n", title, dimX, dimY); + for (int y = 0; y < dimY; ++y) { + for (int x = 0; x < dimX; ++x) { + printf("%8.3f ", lds_array[y * dimX + x]); + } + printf("\n"); + } + printf("\n"); + } + __syncthreads(); +} - const int lane_id = threadIdx.x % 64; - const T* values = reinterpret_cast(R); - int row, col_start; - - if (layout == MatrixLayout::A) { - // A-matrix layout: each thread owns a 1x4 strip of a row - row = lane_id % 16; - col_start = (lane_id / 16) * 4; - } else { // MatrixLayout::B - // B-matrix layout: each thread owns a 1x4 strip of a column block - const uint32_t block_row = (lane_id % 16) / 4; - const uint32_t block_col = (lane_id / 16); - const uint32_t thread_in_block = lane_id % 4; - row = block_col * 4 + thread_in_block; - col_start = block_row * 4; +/// @brief Materializes a 2D array of accumulator fragments from each thread's registers into a +/// 2D shared memory array. 
+/// @details This function is the inverse of the hardware's distribution of accumulator results. +/// It reconstructs a logical tile of the S = Q * K^T matrix in shared memory, +/// accounting for the partitioning of work across multiple warps. +/// @tparam T The data type of the fragments and LDS array (e.g., float or half). +/// @tparam NUM_MMA_Q The number of fragments along the Q dimension (rows) per thread. +/// @tparam NUM_MMA_KV The number of fragments along the KV dimension (columns) per thread. +/// @tparam ELEMS_PER_FRAGMENT The number of elements per fragment (typically 4 for float/half). +/// @param s_frag The 3D fragment array from the thread's registers. +/// @param lds_scratchpad Pointer to the shared memory array. +/// @param lds_stride The width/stride of the lds_scratchpad (e.g., CTA_TILE_KV). +/// @param tid The thread's index within the block (threadIdx). +template +__device__ void write_s_frag_to_lds(const T (*s_frag)[NUM_MMA_KV][ELEMS_PER_FRAGMENT], + T* lds_scratchpad, const uint32_t lds_stride, + const dim3 tid = threadIdx) { + const int lane_id = tid.x % 64; + const int warp_idx_q = tid.y; + + // Calculate the starting row in the LDS tile for this entire warp. + const uint32_t warp_base_row = warp_idx_q * NUM_MMA_Q * MMA_COLS; + +#pragma unroll + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) { + // -- Calculate the top-left corner of the 16x16 fragment this thread contributes to -- + const uint32_t frag_row_offset = mma_q * MMA_COLS; + const uint32_t frag_col_offset = mma_kv * MMA_COLS; + + // -- Calculate the specific 4x1 element strip this thread writes within that fragment -- + // This logic correctly materializes a B-layout fragment (column strip). + // Each thread T_c handles column 'c' of the fragment. + // The 4 threads in a "column" of the warp (e.g., lanes 0, 16, 32, 48) + // handle the 4 rows of that column strip. 
+ const uint32_t thread_start_row_in_frag = (lane_id / MMA_COLS) * MMA_ROWS_PER_THREAD; + const uint32_t thread_col_in_frag = (lane_id % MMA_COLS); + + // -- Combine all offsets and write the 4x1 column strip to LDS -- + const T* values = s_frag[mma_q][mma_kv]; + for (int i = 0; i < MMA_ROWS_PER_THREAD; ++i) { + // The row for this element is the thread's starting row + the element's index in the strip. + const uint32_t final_row = warp_base_row + frag_row_offset + thread_start_row_in_frag + i; + // The column is fixed for all 4 elements in the strip. + const uint32_t final_col = frag_col_offset + thread_col_in_frag; + + // Calculate destination and write the value. + T* dest = lds_scratchpad + final_row * lds_stride + final_col; + *dest = values[i]; + } + } } +} + +template +__device__ void write_m_new_to_lds(const T (*m)[NUM_ACCUM_ROWS_PER_THREAD], T* lds_scratchpad, + const dim3 tid = threadIdx) { + const int lane_idx = tid.x; + const int warp_idx_q = tid.y; + + // Each group of 16 threads (a "row group") computes the max for 4 rows. + // We only need one thread from each group to write the results. + if (lane_idx % MMA_COLS == 0) { + // Base row index for this warp's Q tile + const uint32_t warp_base_row = warp_idx_q * NUM_MMA_Q * MMA_COLS; + +#pragma unroll + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { + // Base row for this specific MMA instruction within the warp's tile + const uint32_t mma_base_row = mma_q * MMA_COLS; - half* offset = lds_array + row * dimX + col_start; - offset[0] = values[0]; - offset[1] = values[1]; - offset[2] = values[2]; - offset[3] = values[3]; +#pragma unroll + for (uint32_t j = 0; j < NUM_ACCUM_ROWS_PER_THREAD; ++j) { + // The thread's lane_idx determines which group of 4 rows it is in. + // e.g., lane 0 is in group 0, lane 16 is in group 1, etc. 
+ const uint32_t row_group_offset = (lane_idx / MMA_COLS) * NUM_ACCUM_ROWS_PER_THREAD; + + // The final row index in the logical S matrix + const uint32_t final_row_idx = warp_base_row + mma_base_row + row_group_offset + j; + + lds_scratchpad[final_row_idx] = m[mma_q][j]; + } + } + } } } // namespace flashinfer::gpu_iface::debug_utils::hip From b77748d4abcbcdfdb4bdd3d947e8a353b2619c49 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Fri, 10 Oct 2025 17:54:26 -0400 Subject: [PATCH 101/109] Remove unused debug function --- .../flashinfer/attention/generic/prefill.cuh | 45 +------------------ 1 file changed, 2 insertions(+), 43 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index c8c6123b02..38c0ec5d5c 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -1166,8 +1166,7 @@ __device__ __forceinline__ void compute_sfm_v( uint32_t* v_smem_offset_r, typename KTraits::DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD], float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD], - float (*d)[KTraits::NUM_ACCUM_ROWS_PER_THREAD], const dim3 tid = threadIdx, - uint32_t debug_warp_idx = 0, uint32_t debug_lane_idx = 0) { + float (*d)[KTraits::NUM_ACCUM_ROWS_PER_THREAD], const dim3 tid = threadIdx) { constexpr uint32_t UPCAST_STRIDE_V = KTraits::UPCAST_STRIDE_V; constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; constexpr uint32_t INT32_ELEMS_PER_THREAD = KTraits::INT32_ELEMS_PER_THREAD; @@ -1186,25 +1185,6 @@ __device__ __forceinline__ void compute_sfm_v( } } -#if Debug1 - // Debug the state of attention score matrix before rowsum to compute denom - constexpr uint32_t NUM_MMA_Q = KTraits::NUM_MMA_Q; - constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; - const uint32_t warp_idx = get_warp_idx(tid.y, tid.z), lane_idx = 
tid.x; - - // Write all thread's fragments to shared memory - for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { - for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) { - if (lane_idx == debug_lane_idx && warp_idx == debug_warp_idx) { - printf("%.6f %.6f %.6f %.6f\n", s_frag[mma_q][mma_kv][0], s_frag[mma_q][mma_kv][1], - s_frag[mma_q][mma_kv][2], s_frag[mma_q][mma_kv][3]); - } - } - } - __syncthreads(); - -#endif - if constexpr (KTraits::AttentionVariant::use_softmax) { #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { @@ -1600,26 +1580,6 @@ __device__ __forceinline__ void write_o_reg_gmem( } // namespace -template -__device__ __forceinline__ void debug_write_sfrag_to_scratch( - typename KTraits::DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD], - const dim3 tid = threadIdx, uint32_t debug_thread_id = 0, uint32_t debug_warp_id = 0) { - constexpr uint32_t NUM_MMA_Q = KTraits::NUM_MMA_Q; - constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; - const uint32_t warp_idx = get_warp_idx(tid.y, tid.z), lane_idx = tid.x; - - // Write all thread's fragments to shared memory - for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { - for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) { - if (lane_idx == debug_thread_id && warp_idx == debug_warp_id) { - printf("%.6f %.6f %.6f %.6f\n", s_frag[mma_q][mma_kv][0], s_frag[mma_q][mma_kv][1], - s_frag[mma_q][mma_kv][2], s_frag[mma_q][mma_kv][3]); - } - } - } - __syncthreads(); -} - /*! * \brief FlashAttention prefill CUDA kernel for a single request. * \tparam partition_kv Whether to split kv_len into chunks. 
@@ -1987,8 +1947,7 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice( block.sync(); // compute sfm*v - compute_sfm_v(&v_smem, &v_smem_offset_r, s_frag, o_frag, d, tid, - params.debug_warp_id, params.debug_thread_id); + compute_sfm_v(&v_smem, &v_smem_offset_r, s_frag, o_frag, d, tid); block.sync(); produce_kv( v_smem, &v_smem_offset_w, &v_ptr, v_stride_n, (iter + 1) * CTA_TILE_KV, chunk_size, tid); From 3b52620792274757353d3741d3d4c895b4bf8929 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Fri, 10 Oct 2025 17:57:44 -0400 Subject: [PATCH 102/109] Clean up tests --- .../tests/hip/test_load_q_global_smem.cpp | 42 ------------------- .../tests/hip/test_mfma_fp32_16x16x16fp16.cpp | 12 ------ 2 files changed, 54 deletions(-) delete mode 100644 libflashinfer/tests/hip/test_load_q_global_smem.cpp diff --git a/libflashinfer/tests/hip/test_load_q_global_smem.cpp b/libflashinfer/tests/hip/test_load_q_global_smem.cpp deleted file mode 100644 index d4e381aa62..0000000000 --- a/libflashinfer/tests/hip/test_load_q_global_smem.cpp +++ /dev/null @@ -1,42 +0,0 @@ -// SPDX - FileCopyrightText : 2025 Advanced Micro Devices, Inc. -// -// SPDX - License - Identifier : Apache 2.0 - -#include -#include - -#include -#include -#include -#include - -#include "flashinfer/attention/generic/default_prefill_params.cuh" -#include "flashinfer/attention/generic/prefill.cuh" -#include "flashinfer/attention/generic/variants.cuh" -#include "utils/cpu_reference_hip.h" -#include "utils/utils_hip.h" // vec_normal_ - -namespace { -constexpr uint32_t qo_len = 64; -constexpr uint32_t num_qo_heads = 1; -constexpr uint32_t head_dim = 64; -} // namespace - -// CPU reference implementation that creates a Q matrix with a kNHD layout and -// initializes. -void initialize_cpu_q() { - std::vector q(qo_len * num_qo_heads * head_dim); - utils::vec_normal_(q); -} - -// Validates the original Q matrix on CPU with the copied over data from GPU. 
-// Ensures that the copied over data matches both the CDNA3 A-matrix layout and -// also validates with the original Q matrix. - -// GPU kernel that launches exactly one warp and calls prefill.cuh's -// load_q_global_smem to populate a LDS array from a global array. Then copies -// back the shared memory array to another output global array. - -// Laucher of GPU kernel. -// Copies the Q array from the CPU reference to GPU and then calls the kernel -// to copy from global to shared memory. diff --git a/libflashinfer/tests/hip/test_mfma_fp32_16x16x16fp16.cpp b/libflashinfer/tests/hip/test_mfma_fp32_16x16x16fp16.cpp index 5c4f37c333..fd638a75d0 100644 --- a/libflashinfer/tests/hip/test_mfma_fp32_16x16x16fp16.cpp +++ b/libflashinfer/tests/hip/test_mfma_fp32_16x16x16fp16.cpp @@ -24,18 +24,6 @@ } \ } -namespace { - -__device__ void print_register(uint32_t* R) { - auto values = reinterpret_cast<__half*>(R); - printf("[%f %f %f %f]\n", __half2float(values[0]), __half2float(values[1]), - __half2float(values[2]), __half2float(values[3])); -} - -__device__ void print_register(float* R) { printf("[%f %f %f %f]\n", R[0], R[1], R[3], R[4]); } - -} // namespace - // Dimensions for our test matrices constexpr int M = 16; constexpr int N = 16; From 3e4e12b105f0fcc8b41fd56dd4678a5cb4be1e3c Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Fri, 10 Oct 2025 17:59:55 -0400 Subject: [PATCH 103/109] Remove more unused stuff --- .../tests/hip/test_apply_llama_rope.cpp | 355 ------------------ 1 file changed, 355 deletions(-) delete mode 100644 libflashinfer/tests/hip/test_apply_llama_rope.cpp diff --git a/libflashinfer/tests/hip/test_apply_llama_rope.cpp b/libflashinfer/tests/hip/test_apply_llama_rope.cpp deleted file mode 100644 index 3620388510..0000000000 --- a/libflashinfer/tests/hip/test_apply_llama_rope.cpp +++ /dev/null @@ -1,355 +0,0 @@ -// SPDX - FileCopyrightText : 2025 Advanced Micro Devices, Inc. 
-// -// SPDX - License - Identifier : Apache 2.0 - -#include - -#include -#include - -#include "../../utils/cpu_reference_hip.h" -#include "../../utils/utils_hip.h" -#include "flashinfer/attention/generic/prefill.cuh" -#include "gpu_iface/fastdiv.cuh" -#include "gpu_iface/gpu_runtime_compat.hpp" - -namespace { -using QParamType = std::tuple; - -template -struct TestKernelTraits { - static constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM / 16; - static constexpr uint32_t NUM_MMA_D_VO = HEAD_DIM / 16; -}; - -template -__global__ void test_init_rope_freq_kernel(float* output_freq, float rope_rcp_scale, - float rope_rcp_theta) { - using KTraits = TestKernelTraits; - - // Allocate local frequency array - float rope_freq[KTraits::NUM_MMA_D_VO / 2][4]; // [2][4] for HEAD_DIM=64 - - // Call the init_rope_freq function from prefill.cuh - flashinfer::init_rope_freq(rope_freq, rope_rcp_scale, rope_rcp_theta, threadIdx.x); - - // Write frequencies to their correct feature indices - const uint32_t lane_idx = threadIdx.x; - if (lane_idx < 64) { // Only write for valid threads - for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO / 2; ++mma_d) { - for (uint32_t j = 0; j < 4; ++j) { - // Calculate the actual feature index this frequency corresponds - // to - uint32_t feature_idx = flashinfer::get_feature_index(mma_d, lane_idx, j); - - // Write frequency to the correct feature index in global array - if (feature_idx < HEAD_DIM) { - output_freq[feature_idx] = rope_freq[mma_d][j]; - if (feature_idx + HEAD_DIM / 2 < HEAD_DIM) { - output_freq[feature_idx + HEAD_DIM / 2] = rope_freq[mma_d][j]; - } - } - } - } - } -} - -template -__global__ void test_q_frag_apply_llama_rope_kernel(__half* q_input, __half* q_output, - uint32_t qo_len, uint32_t num_qo_heads, - uint32_t kv_len, float rope_rcp_scale, - float rope_rcp_theta, - flashinfer::uint_fastdiv group_size_fastdiv) { - using KTraits = TestKernelTraits; - constexpr uint32_t HALF_ELEMS_PER_THREAD = 4; - constexpr uint32_t 
INT32_ELEMS_PER_THREAD = 2; - constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM / 16; - - float rope_freq[KTraits::NUM_MMA_D_VO / 2][4]; - flashinfer::init_rope_freq(rope_freq, rope_rcp_scale, rope_rcp_theta, threadIdx.x); - - const uint32_t lane_idx = threadIdx.x; - const uint32_t warp_idx = blockIdx.x; - - // TODO: Need to check that qo_len is evenly divisible by 16. - for (uint32_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) { - for (uint32_t seq_chunk = 0; seq_chunk < qo_len; seq_chunk += 16) { - uint32_t seq_idx = seq_chunk + (lane_idx % 16); - if (seq_idx >= qo_len) continue; - - uint32_t abs_position = seq_idx + kv_len - qo_len; - // Each iteration processes 16*2=32 features (first_half + - // second_half) - for (uint32_t feat_chunk = 0; feat_chunk < NUM_MMA_D_QK / 2; ++feat_chunk) { - uint32_t feat_offset_first = feat_chunk * 32; - uint32_t feat_offset_second = feat_offset_first + HEAD_DIM / 2; - - // Load fragments from global memory - __half q_frag_first[HALF_ELEMS_PER_THREAD]; - __half q_frag_second[HALF_ELEMS_PER_THREAD]; - - // Calculate base address for this sequence and head - uint32_t base_offset = qo_head_idx * HEAD_DIM + seq_idx * (num_qo_heads * HEAD_DIM); - - // Load first half (4 consecutive features per thread) - for (uint32_t i = 0; i < HALF_ELEMS_PER_THREAD; ++i) { - uint32_t feat_idx1 = flashinfer::get_feature_index(feat_chunk, lane_idx, i); - uint32_t feat_idx2 = feat_idx1 + HEAD_DIM / 2; - q_frag_first[i] = *(q_input + base_offset + feat_idx1); - q_frag_second[i] = *(q_input + base_offset + feat_idx2); - } - - // Apply RoPE using the validated function - uint32_t mma_di = feat_chunk; - flashinfer::q_frag_apply_llama_rope<__half, HALF_ELEMS_PER_THREAD>( - q_frag_first, q_frag_second, rope_freq[mma_di % (KTraits::NUM_MMA_D_VO / 2)], - abs_position, group_size_fastdiv); - - // Store results back to global memory - for (uint32_t i = 0; i < HALF_ELEMS_PER_THREAD; ++i) { - uint32_t feat_idx1 = flashinfer::get_feature_index(feat_chunk, 
lane_idx, i); - uint32_t feat_idx2 = feat_idx1 + HEAD_DIM / 2; - *(q_output + base_offset + feat_idx1) = q_frag_first[i]; - *(q_output + base_offset + feat_idx2) = q_frag_second[i]; - } - } - } - } -} - -template -class LLamaRopeTestFixture : public ::testing::TestWithParam { - protected: - uint32_t qo_len, num_qo_heads, head_dim; - std::vector q; - - LLamaRopeTestFixture() { - const auto& params = GetParam(); - qo_len = std::get<0>(params); - num_qo_heads = std::get<1>(params); - head_dim = std::get<2>(params); - q.resize(qo_len * num_qo_heads * head_dim); - } - - void SetUp() override { utils::vec_normal_(q); } - - void TearDown() override {} - - std::vector apply_cpu_rope(size_t offset, float rope_scale = 1.0f, - float rope_theta = 10000.0f) { - return cpu_reference::apply_llama_rope(q.data(), head_dim, offset, rope_scale, rope_theta); - } - - std::vector get_cpu_rope_frequencies(float rope_scale = 1.0f, - float rope_theta = 10000.0f) { - std::vector frequencies(head_dim); - - for (size_t k = 0; k < head_dim; ++k) { - // Extract ONLY the frequency calculation (without position/offset) - float freq_base = float(2 * (k % (head_dim / 2))) / float(head_dim); - float frequency = (1.0f / rope_scale) / std::pow(rope_theta, freq_base); - frequencies[k] = frequency; - } - - return frequencies; - } - - std::vector get_gpu_rope_frequencies(float rope_scale = 1.0f, - float rope_theta = 10000.0f) { - // Convert to reciprocal values as expected by GPU kernel - float rope_rcp_scale = 1.0f / rope_scale; - float rope_rcp_theta = 1.0f / rope_theta; - - // Allocate GPU memory for output (one frequency per feature) - float* d_output_freq; - size_t output_size = head_dim * sizeof(float); - FI_GPU_CALL(hipMalloc(&d_output_freq, output_size)); - FI_GPU_CALL(hipMemset(d_output_freq, 0, output_size)); - - // Launch kernel with 64 threads - dim3 grid(1); - dim3 block(64); - - if (head_dim == 64) { - test_init_rope_freq_kernel<64> - <<>>(d_output_freq, rope_rcp_scale, rope_rcp_theta); - } 
- - FI_GPU_CALL(hipDeviceSynchronize()); - - // Copy all frequencies back - std::vector gpu_frequencies(head_dim); - FI_GPU_CALL( - hipMemcpy(gpu_frequencies.data(), d_output_freq, output_size, hipMemcpyDeviceToHost)); - - FI_GPU_CALL(hipFree(d_output_freq)); - return gpu_frequencies; - } - - std::vector> apply_cpu_rope_all_sequences(size_t kv_len = 1000, - float rope_scale = 1.0f, - float rope_theta = 10000.0f) { - std::vector> results; - - DISPATCH_head_dim(head_dim, HEAD_DIM, { - using namespace flashinfer; - tensor_info_t info(qo_len, kv_len, num_qo_heads, num_qo_heads, QKVLayout::kHND, HEAD_DIM); - - // Apply RoPE to all sequences and heads - for (size_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) { - for (size_t q_idx = 0; q_idx < qo_len; ++q_idx) { - size_t offset = q_idx + kv_len - qo_len; - - // Apply RoPE to this specific Q sequence/head - auto q_rotary_local = cpu_reference::apply_llama_rope_debug( - q.data() + info.get_q_elem_offset(q_idx, qo_head_idx, 0), head_dim, offset, - rope_scale, rope_theta); - - results.push_back(std::move(q_rotary_local)); - } - } - }); - - return results; - } - - std::vector test_gpu_q_frag_apply_rope(size_t kv_len = 1000, float rope_scale = 1.0f, - float rope_theta = 10000.0f) { - // Convert to reciprocal values - float rope_rcp_scale = 1.0f / rope_scale; - float rope_rcp_theta = 1.0f / rope_theta; - uint32_t group_size = 1; // Simple case for now - - // Allocate GPU memory for input and output - __half *d_q_input, *d_q_output; - size_t q_size = q.size() * sizeof(__half); - - FI_GPU_CALL(hipMalloc(&d_q_input, q_size)); - FI_GPU_CALL(hipMalloc(&d_q_output, q_size)); - - // Copy input Q to GPU - FI_GPU_CALL(hipMemcpy(d_q_input, q.data(), q_size, hipMemcpyHostToDevice)); - FI_GPU_CALL(hipMemset(d_q_output, 0, q_size)); - - // Launch kernel - one block with 64 threads - dim3 grid(1); // Single block for simplicity - dim3 block(64); // CDNA3 wavefront size - - if (head_dim == 64) { - 
test_q_frag_apply_llama_rope_kernel<64><<>>(d_q_input, d_q_output, qo_len, - num_qo_heads, kv_len, rope_rcp_scale, - rope_rcp_theta, group_size); - } - - FI_GPU_CALL(hipDeviceSynchronize()); - - // Copy results back to CPU - std::vector<__half> gpu_output(q.size()); - FI_GPU_CALL(hipMemcpy(gpu_output.data(), d_q_output, q_size, hipMemcpyDeviceToHost)); - - // Convert to float for comparison - std::vector result(head_dim); - for (size_t i = 0; i < head_dim; ++i) { - result[i] = float(gpu_output[i]); // First sequence, first head - } - - FI_GPU_CALL(hipFree(d_q_input)); - FI_GPU_CALL(hipFree(d_q_output)); - - return result; - } -}; - -using LLamaRopeTestWithFP16 = LLamaRopeTestFixture<__half>; -} // namespace - -// Wrapper to validate freq application -// call q_smem_inplace_apply_rotary and copy back results to CPU. - -// Test 1. Copy CPU Q matrix to GPU call freq init validator -// launch kernel - -// Test 2. Copy CPU Q matrix to GPU call freq apply validator -// launch kernel - -TEST_P(LLamaRopeTestWithFP16, TestInitRopeFreq) { - constexpr float RELATIVE_EPSILON = 1e-6f; - size_t num_mismatches = 0; - auto cpu_frequencies = this->get_cpu_rope_frequencies(); - auto gpu_frequencies = this->get_gpu_rope_frequencies(); - - // Print side-by-side comparison for easier visual inspection - std::cout << "\nSide-by-side comparison:\n"; - std::cout << "Index\tCPU\t\tGPU\t\tDifference\n"; - std::cout << "-----\t---\t\t---\t\t----------\n"; - - for (size_t i = 0; i < std::min(16u, this->head_dim); ++i) { - float diff = std::abs(cpu_frequencies[i] - gpu_frequencies[i]); - std::cout << i << "\t" << cpu_frequencies[i] << "\t\t" << gpu_frequencies[i] << "\t\t" << diff - << std::endl; - } - - ASSERT_EQ(cpu_frequencies.size(), this->head_dim); - ASSERT_EQ(gpu_frequencies.size(), this->head_dim); - - for (auto i = 0ul; i < cpu_frequencies.size(); ++i) { - auto diff = std::abs(cpu_frequencies[i] - gpu_frequencies[i]); - if (diff >= RELATIVE_EPSILON) { - std::cout << "Diff : " << diff 
<< " at feature index " << i << " " - << "cpu_frequencies[i]: " << cpu_frequencies[i] << " " - << "gpu_frequencies[i]: " << gpu_frequencies[i] << '\n'; - ++num_mismatches; - } - } - - ASSERT_EQ(num_mismatches, 0); -} - -TEST_P(LLamaRopeTestWithFP16, VectorSizeIsCorrect) { - const auto& params = GetParam(); - size_t expected_size = std::get<0>(params) * std::get<1>(params) * std::get<2>(params); - ASSERT_EQ(this->q.size(), expected_size); -} - -TEST_P(LLamaRopeTestWithFP16, TestQFragApplyRopeComparison) { - constexpr float RELATIVE_EPSILON = 1e-2f; - - auto cpu_result = this->apply_cpu_rope(744); - auto gpu_result = this->test_gpu_q_frag_apply_rope(); - - std::cout << "\n=== CPU vs GPU RoPE Application Comparison ===\n"; - std::cout << "CPU result (offset=1000, first 8 features): "; - for (size_t i = 0; i < std::min(8u, this->head_dim); ++i) { - std::cout << cpu_result[i] << " "; - } - std::cout << std::endl; - - std::cout << "GPU result (offset=1000, first 8 features): "; - for (size_t i = 0; i < std::min(8u, this->head_dim); ++i) { - std::cout << gpu_result[i] << " "; - } - std::cout << std::endl; - - // Compare element by element - size_t num_mismatches = 0; - for (size_t i = 0; i < std::min(cpu_result.size(), gpu_result.size()); ++i) { - float diff = std::abs(cpu_result[i] - gpu_result[i]); - float rel_diff = (std::abs(cpu_result[i]) > 1e-6f) ? 
diff / std::abs(cpu_result[i]) : diff; - - if (rel_diff > RELATIVE_EPSILON) { - std::cout << "Mismatch at feature " << i << ": CPU=" << cpu_result[i] - << " GPU=" << gpu_result[i] << " diff=" << diff << " rel_diff=" << rel_diff - << std::endl; - ++num_mismatches; - } - } - - std::cout << "Total mismatches: " << num_mismatches << " out of " << head_dim << std::endl; - - EXPECT_EQ(num_mismatches, 0) << "Found mismatches between CPU and GPU RoPE application"; -} - -INSTANTIATE_TEST_SUITE_P( - LLamaRopeTestWithFP16, LLamaRopeTestWithFP16, - ::testing::Values(std::make_tuple(256, 1, 64) // qo_len=256, num_qo_heads=1, head_dim=64 - )); From 38363818a39286d801ef48819c69f3520f1e4967 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Fri, 10 Oct 2025 18:03:22 -0400 Subject: [PATCH 104/109] Remove unused file --- examples/cpp/standalone_single_prefill.cu | 636 ---------------------- 1 file changed, 636 deletions(-) delete mode 100644 examples/cpp/standalone_single_prefill.cu diff --git a/examples/cpp/standalone_single_prefill.cu b/examples/cpp/standalone_single_prefill.cu deleted file mode 100644 index 53181273af..0000000000 --- a/examples/cpp/standalone_single_prefill.cu +++ /dev/null @@ -1,636 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -// GPU interface headers -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace flashinfer { - -// Parameter struct for SinglePrefill -template -struct SinglePrefillParams { - using DTypeQ = half; - using DTypeKV = half; - using DTypeO = DTypeOs; - using IdType = IdTypes; - - half* q; - half* k; - half* v; - DTypeO* o; - float* lse; - uint_fastdiv group_size; - - uint8_t* maybe_custom_mask; - float* maybe_alibi_slopes; - double logits_soft_cap; - double sm_scale; - double rope_rcp_scale; - double rope_rcp_theta; - - uint32_t qo_len; - uint32_t kv_len; - uint32_t num_qo_heads; - uint32_t num_kv_heads; - uint32_t q_stride_n; - 
uint32_t q_stride_h; - uint32_t k_stride_n; - uint32_t k_stride_h; - uint32_t v_stride_n; - uint32_t v_stride_h; - uint32_t head_dim; - int32_t window_left; - - bool partition_kv; - - __host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t batch_idx) const { - return qo_len; - } - - __host__ __device__ __forceinline__ uint32_t get_kv_len(uint32_t batch_idx) const { - return kv_len; - } -}; - -} // namespace flashinfer - -// CPU reference implementation for validation -namespace reference { - -template -std::vector single_mha(const std::vector& q, const std::vector& k, const std::vector& v, - size_t qo_len, size_t kv_len, size_t num_qo_heads, size_t num_kv_heads, - size_t head_dim, bool causal, flashinfer::QKVLayout kv_layout, - flashinfer::PosEncodingMode pos_encoding_mode, float rope_scale = 1.0f, - float rope_theta = 10000.0f) { - float sm_scale = 1.0f / std::sqrt(static_cast(head_dim)); - std::vector o(qo_len * num_qo_heads * head_dim, static_cast(0.0f)); - std::vector att(kv_len); - size_t group_size = num_qo_heads / num_kv_heads; - - for (size_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) { - size_t kv_head_idx = qo_head_idx / group_size; - - for (size_t q_idx = 0; q_idx < qo_len; ++q_idx) { - // 1. 
Compute attention scores - float max_val = -5e4f; - - for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { - if (causal && kv_idx > kv_len + q_idx - qo_len) { - att[kv_idx] = -5e4f; - continue; - } - - // Compute dot product between Q and K - float score = 0.0f; - for (size_t d = 0; d < head_dim; ++d) { - float q_val = 0.0f; - float k_val = 0.0f; - - // Get Q value - always NHD layout - size_t q_offset = q_idx * num_qo_heads * head_dim + qo_head_idx * head_dim + d; - q_val = static_cast(q[q_offset]); - - // Get K value - depends on layout - if (kv_layout == flashinfer::QKVLayout::kNHD) { - size_t k_offset = kv_idx * num_kv_heads * head_dim + kv_head_idx * head_dim + d; - k_val = static_cast(k[k_offset]); - } else { - size_t k_offset = kv_head_idx * kv_len * head_dim + kv_idx * head_dim + d; - k_val = static_cast(k[k_offset]); - } - - score += q_val * k_val; - } - score *= sm_scale; - - att[kv_idx] = score; - max_val = std::max(max_val, score); - } - - // 2. Apply softmax - float sum_exp = 0.0f; - for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { - if (causal && kv_idx > kv_len + q_idx - qo_len) { - att[kv_idx] = 0.0f; - } else { - att[kv_idx] = std::exp(att[kv_idx] - max_val); - sum_exp += att[kv_idx]; - } - } - - // Normalize - for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { - if (sum_exp > 0.0f) { - att[kv_idx] /= sum_exp; - } - } - - // 3. 
Compute weighted sum of values - for (size_t d = 0; d < head_dim; ++d) { - float weighted_sum = 0.0f; - - for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { - float v_val = 0.0f; - - // Get V value - depends on layout - if (kv_layout == flashinfer::QKVLayout::kNHD) { - size_t v_offset = kv_idx * num_kv_heads * head_dim + kv_head_idx * head_dim + d; - v_val = static_cast(v[v_offset]); - } else { - size_t v_offset = kv_head_idx * kv_len * head_dim + kv_idx * head_dim + d; - v_val = static_cast(v[v_offset]); - } - - weighted_sum += att[kv_idx] * v_val; - } - - // Store result in output - size_t o_offset = q_idx * num_qo_heads * head_dim + qo_head_idx * head_dim + d; - o[o_offset] = static_cast(weighted_sum); - } - } - } - - return o; -} - -} // namespace reference - -// Helper function to generate random data (without Thrust) -void generate_random_data(half* data, size_t size, float min_val = -1.0f, float max_val = 1.0f) { - std::vector host_data(size); - std::mt19937 rng(42); // Fixed seed for reproducibility - std::uniform_real_distribution dist(min_val, max_val); - - for (size_t i = 0; i < size; ++i) { - host_data[i] = static_cast(dist(rng)); - } - - // Copy to device - FI_GPU_CALL(gpuMemcpy(data, host_data.data(), size * sizeof(half), gpuMemcpyHostToDevice)); -} - -// Function to validate GPU results against CPU reference (simplified) -bool validate_results(const half* gpu_output, size_t gpu_size, const std::vector& cpu_output, - float rtol = 1e-3f, float atol = 1e-3f) { - if (gpu_size != cpu_output.size()) { - std::cerr << "Size mismatch: GPU=" << gpu_size << " vs CPU=" << cpu_output.size() << std::endl; - return false; - } - - // Copy GPU data to host for comparison - std::vector host_output(gpu_size); - FI_GPU_CALL( - gpuMemcpy(host_output.data(), gpu_output, gpu_size * sizeof(half), gpuMemcpyDeviceToHost)); - - int errors = 0; - float max_diff = 0.0f; - float max_rel_diff = 0.0f; - - for (size_t i = 0; i < gpu_size; ++i) { - float gpu_val = 
static_cast(host_output[i]); - float cpu_val = static_cast(cpu_output[i]); - float abs_diff = std::abs(gpu_val - cpu_val); - float rel_diff = (cpu_val != 0.0f) ? abs_diff / std::abs(cpu_val) : abs_diff; - - max_diff = std::max(max_diff, abs_diff); - max_rel_diff = std::max(max_rel_diff, rel_diff); - - bool close = (abs_diff <= atol + rtol * std::abs(cpu_val)); - if (!close) { - errors++; - if (errors <= 10) { // Print just a few examples - std::cerr << "Mismatch at " << i << ": GPU=" << gpu_val << " CPU=" << cpu_val - << " (diff=" << abs_diff << ")" << std::endl; - } - } - } - - float error_rate = static_cast(errors) / gpu_size; - std::cout << "\nValidation Results:" << std::endl; - std::cout << " Max absolute difference: " << max_diff << std::endl; - std::cout << " Max relative difference: " << max_rel_diff << std::endl; - std::cout << " Error rate: " << (error_rate * 100) << "% (" << errors << " / " << gpu_size - << " elements)" << std::endl; - std::cout << " Status: " << (error_rate < 0.05 ? "PASSED" : "FAILED") << std::endl; - - // Allow up to 5% error rate - return error_rate < 0.05; -} - -using namespace flashinfer; - -// Helper class to convert strings to parameters -class ArgParser { - public: - static bool get_bool(const char* arg, bool default_val) { - return arg == nullptr ? default_val : (std::string(arg) == "1" || std::string(arg) == "true"); - } - - static int get_int(const char* arg, int default_val) { - return arg == nullptr ? default_val : std::atoi(arg); - } - - static float get_float(const char* arg, float default_val) { - return arg == nullptr ? 
default_val : std::atof(arg); - } - - static PosEncodingMode get_pos_encoding_mode(const char* arg) { - if (arg == nullptr) return PosEncodingMode::kNone; - std::string str_val = arg; - if (str_val == "none") return PosEncodingMode::kNone; - if (str_val == "rope") return PosEncodingMode::kRoPELlama; - if (str_val == "alibi") return PosEncodingMode::kALiBi; - return PosEncodingMode::kNone; - } - - static QKVLayout get_layout(const char* arg) { - if (arg == nullptr) return QKVLayout::kNHD; - std::string str_val = arg; - if (str_val == "nhd") return QKVLayout::kNHD; - if (str_val == "hnd") return QKVLayout::kHND; - return QKVLayout::kNHD; - } -}; - -// Dispatch function for half precision -gpuError_t dispatch_single_prefill(half* q_ptr, half* k_ptr, half* v_ptr, half* o_ptr, - half* tmp_ptr, float* lse_ptr, uint32_t num_qo_heads, - uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len, - uint32_t head_dim, QKVLayout kv_layout, - PosEncodingMode pos_encoding_mode, bool causal, - bool use_fp16_qk_reduction, double sm_scale, int32_t window_left, - double rope_scale, double rope_theta, gpuStream_t stream) { - // Compute strides based on layout - uint32_t q_stride_n = num_qo_heads * head_dim; - uint32_t q_stride_h = head_dim; - uint32_t k_stride_n, k_stride_h, v_stride_n, v_stride_h; - - if (kv_layout == QKVLayout::kNHD) { - k_stride_n = num_kv_heads * head_dim; - k_stride_h = head_dim; - v_stride_n = num_kv_heads * head_dim; - v_stride_h = head_dim; - } else { - k_stride_h = kv_len * head_dim; - k_stride_n = head_dim; - v_stride_h = kv_len * head_dim; - v_stride_n = head_dim; - } - - // Configure mask mode - const MaskMode mask_mode = causal ? 
MaskMode::kCausal : MaskMode::kNone; - - // Constants for prefill kernel - constexpr uint32_t HEAD_DIM_QK = 128; - constexpr uint32_t HEAD_DIM_VO = 128; - constexpr PosEncodingMode POS_ENCODING_MODE = PosEncodingMode::kRoPELlama; - constexpr bool USE_FP16_QK_REDUCTION = false; - - gpuError_t status = gpuSuccess; - - if (causal) { - // Causal attention - using AttentionVariantType = DefaultAttention; - using Params = SinglePrefillParams; - - Params params; - params.q = q_ptr; - params.k = k_ptr; - params.v = v_ptr; - params.o = o_ptr; - params.lse = lse_ptr; - params.num_qo_heads = num_qo_heads; - params.num_kv_heads = num_kv_heads; - params.group_size = uint_fastdiv(num_qo_heads / num_kv_heads); - params.qo_len = qo_len; - params.kv_len = kv_len; - params.q_stride_n = q_stride_n; - params.q_stride_h = q_stride_h; - params.k_stride_n = k_stride_n; - params.k_stride_h = k_stride_h; - params.v_stride_n = v_stride_n; - params.v_stride_h = v_stride_h; - params.head_dim = head_dim; - params.window_left = window_left; - params.partition_kv = false; - params.maybe_custom_mask = nullptr; - params.maybe_alibi_slopes = nullptr; - params.logits_soft_cap = 0.0; - params.sm_scale = sm_scale; - params.rope_rcp_scale = 1.0 / rope_scale; - params.rope_rcp_theta = 1.0 / rope_theta; - - status = SinglePrefillWithKVCacheDispatched(params, tmp_ptr, stream); - } else { - // Non-causal attention - using AttentionVariantType = DefaultAttention; - using Params = SinglePrefillParams; - - Params params; - params.q = q_ptr; - params.k = k_ptr; - params.v = v_ptr; - params.o = o_ptr; - params.lse = lse_ptr; - params.num_qo_heads = num_qo_heads; - params.num_kv_heads = num_kv_heads; - params.group_size = uint_fastdiv(num_qo_heads / num_kv_heads); - params.qo_len = qo_len; - params.kv_len = kv_len; - params.q_stride_n = q_stride_n; - params.q_stride_h = q_stride_h; - params.k_stride_n = k_stride_n; - params.k_stride_h = k_stride_h; - params.v_stride_n = v_stride_n; - params.v_stride_h = 
v_stride_h; - params.head_dim = head_dim; - params.window_left = window_left; - params.partition_kv = false; - params.maybe_custom_mask = nullptr; - params.maybe_alibi_slopes = nullptr; - params.logits_soft_cap = 0.0; - params.sm_scale = sm_scale; - params.rope_rcp_scale = 1.0 / rope_scale; - params.rope_rcp_theta = 1.0 / rope_theta; - - status = SinglePrefillWithKVCacheDispatched(params, tmp_ptr, stream); - } - - return status; -} - -// Function to calculate FLOPs for single_prefill -double calculate_flops(uint32_t qo_len, uint32_t kv_len, uint32_t num_qo_heads, uint32_t head_dim, - bool causal) { - double flops; - if (causal) { - // For causal attention: qo_len * (2 * kv_len - qo_len) * 2 * - // num_qo_heads * head_dim - flops = static_cast(qo_len) * (2.0 * kv_len - qo_len) * 2.0 * num_qo_heads * head_dim; - } else { - // For non-causal attention: qo_len * kv_len * 4 * num_qo_heads * - // head_dim - flops = static_cast(qo_len) * kv_len * 4.0 * num_qo_heads * head_dim; - } - return flops; -} - -void print_usage(const char* program_name) { - std::cerr << "Usage: " << program_name << " [options]\n" - << "Options:\n" - << " --qo_len : Query sequence length (default: " - "512)\n" - << " --kv_len : Key/value sequence length (default: " - "512)\n" - << " --num_qo_heads : Number of query heads (default: 32)\n" - << " --num_kv_heads : Number of key/value heads (default: " - "32)\n" - << " --head_dim : Head dimension (default: 128)\n" - << " --layout : KV tensor layout (default: nhd)\n" - << " --pos_encoding : Position encoding mode " - "(default: none)\n" - << " --causal <0|1> : Use causal mask (default: 1)\n" - << " --use_fp16_qk <0|1> : Use FP16 for QK reduction (default: " - "0)\n" - << " --window_left : Window left size (default: -1)\n" - << " --rope_scale : RoPE scale factor (default: 1.0)\n" - << " --rope_theta : RoPE theta (default: 10000.0)\n" - << " --iterations : Number of iterations for timing " - "(default: 10)\n" - << " --warmup : Number of warmup iterations 
" - "(default: 5)\n" - << " --validate <0|1> : Validate against CPU reference " - "(default: 0)\n"; -} - -// Main function with simplified memory management -int main(int argc, char* argv[]) { - if (argc > 1 && (std::string(argv[1]) == "--help" || std::string(argv[1]) == "-h")) { - print_usage(argv[0]); - return 0; - } - - // Process parameter pairs (--param value) - uint32_t qo_len = 512; - uint32_t kv_len = 512; - uint32_t num_qo_heads = 32; - uint32_t num_kv_heads = 32; - uint32_t head_dim = 128; - QKVLayout kv_layout = QKVLayout::kNHD; - PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone; - bool causal = true; - bool use_fp16_qk_reduction = false; - int32_t window_left = -1; - float rope_scale = 1.0f; - float rope_theta = 10000.0f; - int iterations = 10; - int warmup = 5; - bool validate = false; - - for (int i = 1; i < argc; i += 2) { - std::string arg = argv[i]; - if (i + 1 >= argc && arg != "--help" && arg != "-h") { - std::cerr << "Missing value for parameter " << arg << std::endl; - print_usage(argv[0]); - return 1; - } - - if (arg == "--qo_len") { - qo_len = ArgParser::get_int(argv[i + 1], 512); - } else if (arg == "--kv_len") { - kv_len = ArgParser::get_int(argv[i + 1], 512); - } else if (arg == "--num_qo_heads") { - num_qo_heads = ArgParser::get_int(argv[i + 1], 32); - } else if (arg == "--num_kv_heads") { - num_kv_heads = ArgParser::get_int(argv[i + 1], 32); - } else if (arg == "--head_dim") { - head_dim = ArgParser::get_int(argv[i + 1], 128); - } else if (arg == "--layout") { - kv_layout = ArgParser::get_layout(argv[i + 1]); - } else if (arg == "--pos_encoding") { - pos_encoding_mode = ArgParser::get_pos_encoding_mode(argv[i + 1]); - } else if (arg == "--causal") { - causal = ArgParser::get_bool(argv[i + 1], true); - } else if (arg == "--use_fp16_qk") { - use_fp16_qk_reduction = ArgParser::get_bool(argv[i + 1], false); - } else if (arg == "--window_left") { - window_left = ArgParser::get_int(argv[i + 1], -1); - } else if (arg == 
"--rope_scale") { - rope_scale = ArgParser::get_float(argv[i + 1], 1.0f); - } else if (arg == "--rope_theta") { - rope_theta = ArgParser::get_float(argv[i + 1], 10000.0f); - } else if (arg == "--iterations") { - iterations = ArgParser::get_int(argv[i + 1], 10); - } else if (arg == "--warmup") { - warmup = ArgParser::get_int(argv[i + 1], 5); - } else if (arg == "--validate") { - validate = ArgParser::get_bool(argv[i + 1], false); - } else { - std::cerr << "Unknown parameter: " << arg << std::endl; - print_usage(argv[0]); - return 1; - } - } - - // Print configuration - std::cout << "Configuration:" << std::endl - << " QO Length: " << qo_len << std::endl - << " KV Length: " << kv_len << std::endl - << " QO Heads: " << num_qo_heads << std::endl - << " KV Heads: " << num_kv_heads << std::endl - << " Head Dimension: " << head_dim << std::endl - << " KV Layout: " << (kv_layout == QKVLayout::kNHD ? "NHD" : "HND") << std::endl - << " Position Encoding: " - << (pos_encoding_mode == PosEncodingMode::kNone ? "None" - : pos_encoding_mode == PosEncodingMode::kRoPELlama ? "RoPE" - : "ALiBi") - << std::endl - << " Causal: " << (causal ? "Yes" : "No") << std::endl - << " Use FP16 QK Reduction: " << (use_fp16_qk_reduction ? "Yes" : "No") << std::endl - << " Window Left: " << window_left << std::endl - << " RoPE Scale: " << rope_scale << std::endl - << " RoPE Theta: " << rope_theta << std::endl - << " Iterations: " << iterations << std::endl - << " Warmup: " << warmup << std::endl - << " Validation: " << (validate ? 
"Yes" : "No") << std::endl; - - // Create stream - gpuStream_t stream; - FI_GPU_CALL(gpuStreamCreate(&stream)); - - // Allocate device memory using gpuMalloc instead of Thrust - half *q_dev, *k_dev, *v_dev, *o_dev, *tmp_dev; - float* lse_dev; - - size_t q_size = qo_len * num_qo_heads * head_dim; - size_t k_size = kv_len * num_kv_heads * head_dim; - size_t v_size = kv_len * num_kv_heads * head_dim; - size_t o_size = qo_len * num_qo_heads * head_dim; - size_t lse_size = qo_len * num_qo_heads; - - FI_GPU_CALL(gpuMalloc(&q_dev, q_size * sizeof(half))); - FI_GPU_CALL(gpuMalloc(&k_dev, k_size * sizeof(half))); - FI_GPU_CALL(gpuMalloc(&v_dev, v_size * sizeof(half))); - FI_GPU_CALL(gpuMalloc(&o_dev, o_size * sizeof(half))); - FI_GPU_CALL(gpuMalloc(&tmp_dev, o_size * sizeof(half))); - FI_GPU_CALL(gpuMalloc(&lse_dev, lse_size * sizeof(float))); - - // Initialize data - generate_random_data(q_dev, q_size); - generate_random_data(k_dev, k_size); - generate_random_data(v_dev, v_size); - - // Zero out output arrays - FI_GPU_CALL(gpuMemset(o_dev, 0, o_size * sizeof(half))); - FI_GPU_CALL(gpuMemset(tmp_dev, 0, o_size * sizeof(half))); - FI_GPU_CALL(gpuMemset(lse_dev, 0, lse_size * sizeof(float))); - - // Calculate SM scale - float sm_scale = 1.0f / std::sqrt(static_cast(head_dim)); - - // Warmup runs - for (int i = 0; i < warmup; ++i) { - gpuError_t status = dispatch_single_prefill( - q_dev, k_dev, v_dev, o_dev, tmp_dev, lse_dev, num_qo_heads, num_kv_heads, qo_len, kv_len, - head_dim, kv_layout, pos_encoding_mode, causal, use_fp16_qk_reduction, sm_scale, - window_left, rope_scale, rope_theta, stream); - - if (status != gpuSuccess) { - std::cerr << "Error during warmup: " << gpuGetErrorString(status) << std::endl; - return 1; - } - } - - // Timing runs - gpuEvent_t start, stop; - FI_GPU_CALL(gpuEventCreate(&start)); - FI_GPU_CALL(gpuEventCreate(&stop)); - - FI_GPU_CALL(gpuEventRecord(start, stream)); - - for (int i = 0; i < iterations; ++i) { - gpuError_t status = 
dispatch_single_prefill( - q_dev, k_dev, v_dev, o_dev, tmp_dev, lse_dev, num_qo_heads, num_kv_heads, qo_len, kv_len, - head_dim, kv_layout, pos_encoding_mode, causal, use_fp16_qk_reduction, sm_scale, - window_left, rope_scale, rope_theta, stream); - - if (status != gpuSuccess) { - std::cerr << "Error during benchmark: " << gpuGetErrorString(status) << std::endl; - return 1; - } - } - - FI_GPU_CALL(gpuEventRecord(stop, stream)); - FI_GPU_CALL(gpuEventSynchronize(stop)); - - float elapsed_ms; - FI_GPU_CALL(gpuEventElapsedTime(&elapsed_ms, start, stop)); - float avg_ms = elapsed_ms / iterations; - - // Calculate and report performance - double flops = calculate_flops(qo_len, kv_len, num_qo_heads, head_dim, causal); - double tflops = flops / (avg_ms * 1e-3) / 1e12; - - // Report results - std::cout << std::fixed << std::setprecision(4); - std::cout << "Performance Results:" << std::endl; - std::cout << " Average time: " << avg_ms << " ms" << std::endl; - std::cout << " Performance: " << tflops << " TFLOPS" << std::endl; - - // Run validation if requested - if (validate) { - std::cout << "\nRunning validation..." << std::endl; - - // Copy input data to host for CPU reference - std::vector h_q(q_size), h_k(k_size), h_v(v_size); - FI_GPU_CALL(gpuMemcpy(h_q.data(), q_dev, q_size * sizeof(half), gpuMemcpyHostToDevice)); - FI_GPU_CALL(gpuMemcpy(h_k.data(), k_dev, k_size * sizeof(half), gpuMemcpyHostToDevice)); - FI_GPU_CALL(gpuMemcpy(h_v.data(), v_dev, v_size * sizeof(half), gpuMemcpyHostToDevice)); - - // Compute reference output on CPU - std::vector ref_output = - reference::single_mha(h_q, h_k, h_v, qo_len, kv_len, num_qo_heads, num_kv_heads, head_dim, - causal, kv_layout, pos_encoding_mode, rope_scale, rope_theta); - - // Validate results - bool validation_passed = validate_results(o_dev, o_size, ref_output); - - // Report validation status - std::cout << "Validation " << (validation_passed ? 
"PASSED" : "FAILED") << std::endl; - } - - // Cleanup - FI_GPU_CALL(gpuEventDestroy(start)); - FI_GPU_CALL(gpuEventDestroy(stop)); - FI_GPU_CALL(gpuStreamDestroy(stream)); - FI_GPU_CALL(gpuFree(q_dev)); - FI_GPU_CALL(gpuFree(k_dev)); - FI_GPU_CALL(gpuFree(v_dev)); - FI_GPU_CALL(gpuFree(o_dev)); - FI_GPU_CALL(gpuFree(tmp_dev)); - FI_GPU_CALL(gpuFree(lse_dev)); - - return 0; -} From 893e43886cbda794d0e0118fffaf8a0dc39cfdd8 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Mon, 20 Oct 2025 17:12:20 +0000 Subject: [PATCH 105/109] Debug prints --- .../flashinfer/attention/generic/prefill.cuh | 28 ++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index 38c0ec5d5c..8d9b2ae8a0 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -1911,7 +1911,7 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice( // compute attention score compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); -#if Debug +#if Debug1 flashinfer::gpu_iface::debug_utils::hip::write_s_frag_to_lds< DTypeQKAccum, NUM_MMA_Q, NUM_MMA_KV, HALF_ELEMS_PER_THREAD>(s_frag, qk_scratch, CTA_TILE_KV, tid); @@ -1936,9 +1936,35 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice( chunk_start + (iter * NUM_WARPS_KV + get_warp_idx_kv(tid.z)) * NUM_MMA_KV * 16, qo_len, kv_len, chunk_end, group_size, s_frag, tid, kv_head_idx); } +#if Debug1 + flashinfer::gpu_iface::debug_utils::hip::write_s_frag_to_lds< + DTypeQKAccum, NUM_MMA_Q, NUM_MMA_KV, HALF_ELEMS_PER_THREAD>(s_frag, qk_scratch, + CTA_TILE_KV, tid); + + // a) Print thread 0's registers to see the source data. 
+ flashinfer::gpu_iface::debug_utils::hip::print_s_frag_register< + DTypeQKAccum, NUM_MMA_Q, NUM_MMA_KV, HALF_ELEMS_PER_THREAD>(s_frag, tid); + // b) Print the materialized LDS array to see the final result for this iteration. + flashinfer::gpu_iface::debug_utils::hip::print_lds_array(qk_scratch, CTA_TILE_Q, CTA_TILE_KV); + +#endif // compute m,d states in online softmax update_mdo_states(variant, s_frag, o_frag, m, d, warp_idx, lane_idx); + +#if Debug + flashinfer::gpu_iface::debug_utils::hip::write_s_frag_to_lds< + DTypeQKAccum, NUM_MMA_Q, NUM_MMA_KV, HALF_ELEMS_PER_THREAD>(s_frag, qk_scratch, + CTA_TILE_KV, tid); + + // a) Print thread 0's registers to see the source data. + flashinfer::gpu_iface::debug_utils::hip::print_s_frag_register< + DTypeQKAccum, NUM_MMA_Q, NUM_MMA_KV, HALF_ELEMS_PER_THREAD>(s_frag, tid); + + // b) Print the materialized LDS array to see the final result for this iteration. + flashinfer::gpu_iface::debug_utils::hip::print_lds_array(qk_scratch, CTA_TILE_Q, CTA_TILE_KV); + +#endif block.sync(); produce_kv( k_smem, &k_smem_offset_w, &k_ptr, k_stride_n, (iter + 1) * CTA_TILE_KV, chunk_size, tid); From ed30bf68bc0c55ee5ccd038a9ffad7b2dcf835e5 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Tue, 21 Oct 2025 14:22:31 +0000 Subject: [PATCH 106/109] Validated s_frag and m value calcs in online softmax --- .../flashinfer/attention/generic/prefill.cuh | 22 ++-- .../backend/hip/mma_debug_utils_hip.h | 114 ++++++++++++++---- 2 files changed, 101 insertions(+), 35 deletions(-) diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index 8d9b2ae8a0..b1fd47e397 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -1825,6 +1825,7 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice( printf("KTraits::NUM_MMA_D_QK : %d\n", KTraits::NUM_MMA_D_QK); printf("NUM_MMA_KV 
: %d\n", NUM_MMA_KV); printf("NUM_MMA_Q : %d\n", NUM_MMA_Q); + printf("sm_scale : %f\n", variant.sm_scale_log2); #if 0 DTypeKV *k_ptr_tmp = k + (chunk_start + warp_idx * KV_THR_LAYOUT_ROW + @@ -1936,17 +1937,18 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice( chunk_start + (iter * NUM_WARPS_KV + get_warp_idx_kv(tid.z)) * NUM_MMA_KV * 16, qo_len, kv_len, chunk_end, group_size, s_frag, tid, kv_head_idx); } -#if Debug1 +#if Debug flashinfer::gpu_iface::debug_utils::hip::write_s_frag_to_lds< DTypeQKAccum, NUM_MMA_Q, NUM_MMA_KV, HALF_ELEMS_PER_THREAD>(s_frag, qk_scratch, CTA_TILE_KV, tid); - // a) Print thread 0's registers to see the source data. - flashinfer::gpu_iface::debug_utils::hip::print_s_frag_register< - DTypeQKAccum, NUM_MMA_Q, NUM_MMA_KV, HALF_ELEMS_PER_THREAD>(s_frag, tid); + // // a) Print thread 0's registers to see the source data. + // flashinfer::gpu_iface::debug_utils::hip::print_s_frag_register< + // DTypeQKAccum, NUM_MMA_Q, NUM_MMA_KV, HALF_ELEMS_PER_THREAD>(s_frag, tid); // b) Print the materialized LDS array to see the final result for this iteration. - flashinfer::gpu_iface::debug_utils::hip::print_lds_array(qk_scratch, CTA_TILE_Q, CTA_TILE_KV); + flashinfer::gpu_iface::debug_utils::hip::print_lds_array( + qk_scratch, CTA_TILE_Q, CTA_TILE_KV, ("S frag before update_mdo for iteration\n")); #endif // compute m,d states in online softmax @@ -1957,13 +1959,13 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice( DTypeQKAccum, NUM_MMA_Q, NUM_MMA_KV, HALF_ELEMS_PER_THREAD>(s_frag, qk_scratch, CTA_TILE_KV, tid); - // a) Print thread 0's registers to see the source data. - flashinfer::gpu_iface::debug_utils::hip::print_s_frag_register< - DTypeQKAccum, NUM_MMA_Q, NUM_MMA_KV, HALF_ELEMS_PER_THREAD>(s_frag, tid); + // // a) Print thread 0's registers to see the source data. 
+ // flashinfer::gpu_iface::debug_utils::hip::print_s_frag_register< + // DTypeQKAccum, NUM_MMA_Q, NUM_MMA_KV, HALF_ELEMS_PER_THREAD>(s_frag, tid); // b) Print the materialized LDS array to see the final result for this iteration. - flashinfer::gpu_iface::debug_utils::hip::print_lds_array(qk_scratch, CTA_TILE_Q, CTA_TILE_KV); - + flashinfer::gpu_iface::debug_utils::hip::print_lds_array( + qk_scratch, CTA_TILE_Q, CTA_TILE_KV, ("S frag after update_mdo for iteration\n")); #endif block.sync(); produce_kv( diff --git a/libflashinfer/include/gpu_iface/backend/hip/mma_debug_utils_hip.h b/libflashinfer/include/gpu_iface/backend/hip/mma_debug_utils_hip.h index f8bc7dd1d2..16a9d610ec 100644 --- a/libflashinfer/include/gpu_iface/backend/hip/mma_debug_utils_hip.h +++ b/libflashinfer/include/gpu_iface/backend/hip/mma_debug_utils_hip.h @@ -125,7 +125,7 @@ __device__ void print_lds_array(float* lds_array, uint32_t dimY, uint32_t dimX, printf("%s (%dx%d):\n", title, dimX, dimY); for (int y = 0; y < dimY; ++y) { for (int x = 0; x < dimX; ++x) { - printf("%8.3f ", lds_array[y * dimX + x]); + printf("%10.6f ", lds_array[y * dimX + x]); } printf("\n"); } @@ -134,36 +134,52 @@ __device__ void print_lds_array(float* lds_array, uint32_t dimY, uint32_t dimX, __syncthreads(); } -/// @brief Materializes a 2D array of accumulator fragments from each thread's registers into a -/// 2D shared memory array. -/// @details This function is the inverse of the hardware's distribution of accumulator results. -/// It reconstructs a logical tile of the S = Q * K^T matrix in shared memory, -/// accounting for the partitioning of work across multiple warps. +/// @brief Prints a 1D LDS array of floats to the console from a single thread. +/// @details Useful for printing row-wise statistics like m or d values. 
+__device__ void print_lds_array_1d(float* lds_array, uint32_t dim, + const char* title = "LDS Array 1D (float)") { + if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) { + printf("%s (%d elements):\n", title, dim); + for (int i = 0; i < dim; ++i) { + printf("%10.6f ", lds_array[i]); + if ((i + 1) % 16 == 0) printf("\n"); // Line break every 16 elements + } + if (dim % 16 != 0) printf("\n"); + printf("\n"); + } + __syncthreads(); +} + +/// @brief Generic function to materialize 2D fragment arrays into shared memory. +/// @details Works for both s_frag (attention scores) and o_frag (output accumulator). +/// Reconstructs a logical tile from distributed register fragments. /// @tparam T The data type of the fragments and LDS array (e.g., float or half). -/// @tparam NUM_MMA_Q The number of fragments along the Q dimension (rows) per thread. -/// @tparam NUM_MMA_KV The number of fragments along the KV dimension (columns) per thread. -/// @tparam ELEMS_PER_FRAGMENT The number of elements per fragment (typically 4 for float/half). -/// @param s_frag The 3D fragment array from the thread's registers. +/// @tparam NUM_MMA_ROW The number of fragments along the rows dimension per thread. +/// @tparam NUM_MMA_COL The number of fragments along the column dimension per thread. +/// For s_frag: NUM_MMA_KV (KV sequence length) +/// For o_frag: NUM_MMA_D_VO (head dimension) +/// @tparam ELEMS_PER_FRAGMENT The number of elements per fragment (typically 4). +/// @param frag The 3D fragment array from the thread's registers. /// @param lds_scratchpad Pointer to the shared memory array. -/// @param lds_stride The width/stride of the lds_scratchpad (e.g., CTA_TILE_KV). +/// @param lds_stride The width/stride of the lds_scratchpad. /// @param tid The thread's index within the block (threadIdx). 
-template -__device__ void write_s_frag_to_lds(const T (*s_frag)[NUM_MMA_KV][ELEMS_PER_FRAGMENT], - T* lds_scratchpad, const uint32_t lds_stride, - const dim3 tid = threadIdx) { +template +__device__ void write_frag_to_lds(const T (*frag)[NUM_MMA_COL][ELEMS_PER_FRAGMENT], + T* lds_scratchpad, const uint32_t lds_stride, + const dim3 tid = threadIdx) { const int lane_id = tid.x % 64; const int warp_idx_q = tid.y; // Calculate the starting row in the LDS tile for this entire warp. - const uint32_t warp_base_row = warp_idx_q * NUM_MMA_Q * MMA_COLS; + const uint32_t warp_base_row = warp_idx_q * NUM_MMA_ROW * MMA_COLS; #pragma unroll - for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { + for (uint32_t mma_q = 0; mma_q < NUM_MMA_ROW; ++mma_q) { #pragma unroll - for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) { + for (uint32_t mma_col = 0; mma_col < NUM_MMA_COL; ++mma_col) { // -- Calculate the top-left corner of the 16x16 fragment this thread contributes to -- const uint32_t frag_row_offset = mma_q * MMA_COLS; - const uint32_t frag_col_offset = mma_kv * MMA_COLS; + const uint32_t frag_col_offset = mma_col * MMA_COLS; // -- Calculate the specific 4x1 element strip this thread writes within that fragment -- // This logic correctly materializes a B-layout fragment (column strip). @@ -174,7 +190,7 @@ __device__ void write_s_frag_to_lds(const T (*s_frag)[NUM_MMA_KV][ELEMS_PER_FRAG const uint32_t thread_col_in_frag = (lane_id % MMA_COLS); // -- Combine all offsets and write the 4x1 column strip to LDS -- - const T* values = s_frag[mma_q][mma_kv]; + const T* values = frag[mma_q][mma_col]; for (int i = 0; i < MMA_ROWS_PER_THREAD; ++i) { // The row for this element is the thread's starting row + the element's index in the strip. 
const uint32_t final_row = warp_base_row + frag_row_offset + thread_start_row_in_frag + i; @@ -189,13 +205,40 @@ __device__ void write_s_frag_to_lds(const T (*s_frag)[NUM_MMA_KV][ELEMS_PER_FRAG } } +/// @brief Convenience wrapper for s_frag (attention scores). +template +__device__ void write_s_frag_to_lds(const T (*s_frag)[NUM_MMA_KV][ELEMS_PER_FRAGMENT], + T* lds_scratchpad, const uint32_t lds_stride, + const dim3 tid = threadIdx) { + write_frag_to_lds(s_frag, lds_scratchpad, + lds_stride, tid); +} + +/// @brief Convenience wrapper for o_frag (output accumulator). +template +__device__ void write_o_frag_to_lds(const T (*o_frag)[NUM_MMA_D_VO][ELEMS_PER_FRAGMENT], + T* lds_scratchpad, const uint32_t lds_stride, + const dim3 tid = threadIdx) { + write_frag_to_lds(o_frag, lds_scratchpad, + lds_stride, tid); +} + +/// @brief Generic function to materialize 1D row-wise values (m or d) into shared memory. +/// @details Writes row-wise statistics (like max or denominator) from register arrays +/// to a 1D shared memory array, with one value per row. +/// @tparam T The data type (typically float). +/// @tparam NUM_MMA_Q The number of fragments along the Q dimension per thread. +/// @tparam NUM_ACCUM_ROWS_PER_THREAD The number of accumulator rows per thread (typically 4). +/// @param values The 2D array from registers [NUM_MMA_Q][NUM_ACCUM_ROWS_PER_THREAD]. +/// @param lds_scratchpad Pointer to the 1D shared memory array. +/// @param tid The thread's index within the block (threadIdx). template -__device__ void write_m_new_to_lds(const T (*m)[NUM_ACCUM_ROWS_PER_THREAD], T* lds_scratchpad, - const dim3 tid = threadIdx) { +__device__ void write_row_values_to_lds(const T (*values)[NUM_ACCUM_ROWS_PER_THREAD], + T* lds_scratchpad, const dim3 tid = threadIdx) { const int lane_idx = tid.x; const int warp_idx_q = tid.y; - // Each group of 16 threads (a "row group") computes the max for 4 rows. + // Each group of 16 threads (a "row group") handles 4 rows. 
 // We only need one thread from each group to write the results. if (lane_idx % MMA_COLS == 0) { // Base row index for this warp's Q tile @@ -212,13 +255,34 @@ __device__ void write_m_new_to_lds(const T (*m)[NUM_ACCUM_ROWS_PER_THREAD], T* l // e.g., lane 0 is in group 0, lane 16 is in group 1, etc. const uint32_t row_group_offset = (lane_idx / MMA_COLS) * NUM_ACCUM_ROWS_PER_THREAD; - // The final row index in the logical S matrix + // The final row index in the logical matrix const uint32_t final_row_idx = warp_base_row + mma_base_row + row_group_offset + j; - lds_scratchpad[final_row_idx] = m[mma_q][j]; + lds_scratchpad[final_row_idx] = values[mma_q][j]; } } } } +/// @brief Convenience wrapper for m (row-wise max) values. +template +__device__ void write_m_to_lds(const T (*m)[NUM_ACCUM_ROWS_PER_THREAD], T* lds_scratchpad, + const dim3 tid = threadIdx) { + write_row_values_to_lds(m, lds_scratchpad, tid); +} + +/// @brief Convenience wrapper for d (denominator) values. +template +__device__ void write_d_to_lds(const T (*d)[NUM_ACCUM_ROWS_PER_THREAD], T* lds_scratchpad, + const dim3 tid = threadIdx) { + write_row_values_to_lds(d, lds_scratchpad, tid); +} + +// Legacy alias for backward compatibility +template +__device__ void write_m_new_to_lds(const T (*m)[NUM_ACCUM_ROWS_PER_THREAD], T* lds_scratchpad, + const dim3 tid = threadIdx) { + write_m_to_lds(m, lds_scratchpad, tid); +} + } // namespace flashinfer::gpu_iface::debug_utils::hip From 75bed47ed70816a88320b90851bb2b0d70e1bf23 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Tue, 21 Oct 2025 15:17:12 +0000 Subject: [PATCH 107/109] Remove temporary scripts --- compile_test.sh | 13 ------------- test_prefill_sfrag.sh | 13 ------------- 2 files changed, 26 deletions(-) delete mode 100644 compile_test.sh delete mode 100644 test_prefill_sfrag.sh diff --git a/compile_test.sh b/compile_test.sh deleted file mode 100644 index 60ae60e96d..0000000000 --- a/compile_test.sh +++ /dev/null @@ -1,13 +0,0 @@ -amdclang++ -x hip \ - 
-std=c++17 \ - -I/home/AMD/diptodeb/devel/flashinfer/libflashinfer/include \ - -I/home/AMD/diptodeb/devel/flashinfer/libflashinfer \ - -I${CONDA_PREFIX}/include \ - -Wall \ - -DHIP_ENABLE_WARP_SYNC_BUILTINS=1 \ - -L${CONDA_PREFIX}/lib \ - -lgtest \ - -DDebug \ - -Wl,-rpath=${CONDA_PREFIX}/lib \ - libflashinfer/tests/hip/test_single_prefill.cpp \ - --offload-arch=gfx942 diff --git a/test_prefill_sfrag.sh b/test_prefill_sfrag.sh deleted file mode 100644 index f0bd5d0c3e..0000000000 --- a/test_prefill_sfrag.sh +++ /dev/null @@ -1,13 +0,0 @@ -#!/bin/bash - -OUTPUT_FILE="sfrag_full.log" -> $OUTPUT_FILE # Clear the file - -for warp in $(seq 0 3); do - for thread in $(seq 0 63); do - echo "Running thread ${thread}... warp ${warp}" - ./a.out --thread $thread --warp $warp >> $OUTPUT_FILE 2>&1 - done -done - -echo "All threads complete. Output in $OUTPUT_FILE" From d4b3e1368f797809814f1993bd047ba0bbe6c4e2 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Tue, 21 Oct 2025 22:07:20 -0500 Subject: [PATCH 108/109] Add instrumentation to validate compute_sfm_v --- .../flashinfer/attention/generic/prefill.cuh | 50 ++- .../include/gpu_iface/backend/hip/mma_hip.h | 19 +- libflashinfer/include/gpu_iface/mma_ops.hpp | 10 + sfrag_tester_script.py | 262 ---------------- validate_online_softmax_stateful.py | 289 ++++++++++++++++++ 5 files changed, 362 insertions(+), 268 deletions(-) delete mode 100644 sfrag_tester_script.py create mode 100755 validate_online_softmax_stateful.py diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh index b1fd47e397..b7e00555c8 100644 --- a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -1166,7 +1166,8 @@ __device__ __forceinline__ void compute_sfm_v( uint32_t* v_smem_offset_r, typename KTraits::DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD], float 
(*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD], - float (*d)[KTraits::NUM_ACCUM_ROWS_PER_THREAD], const dim3 tid = threadIdx) { + float (*d)[KTraits::NUM_ACCUM_ROWS_PER_THREAD], const dim3 tid = threadIdx, + typename KTraits::DTypeQKAccum* qk_scratch = nullptr) { constexpr uint32_t UPCAST_STRIDE_V = KTraits::UPCAST_STRIDE_V; constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; constexpr uint32_t INT32_ELEMS_PER_THREAD = KTraits::INT32_ELEMS_PER_THREAD; @@ -1174,6 +1175,29 @@ __device__ __forceinline__ void compute_sfm_v( typename KTraits::DTypeQ s_frag_f16[KTraits::NUM_MMA_Q][KTraits::NUM_MMA_KV] [HALF_ELEMS_PER_THREAD]; +#if defined(PLATFORM_HIP_DEVICE) +#if Debug + // Print S fragment BEFORE transpose (in B/C/D layout: 128x64) + flashinfer::gpu_iface::debug_utils::hip::write_s_frag_to_lds< + typename KTraits::DTypeQKAccum, KTraits::NUM_MMA_Q, KTraits::NUM_MMA_KV, + KTraits::NUM_ACCUM_ROWS_PER_THREAD>(s_frag, qk_scratch, tid); + flashinfer::gpu_iface::debug_utils::hip::print_lds_array( + qk_scratch, KTraits::CTA_TILE_Q, KTraits::CTA_TILE_KV, + "S frag BEFORE transpose (B/C/D layout, 128x64)"); +#endif + // In-place transposition of the s_frag MMA tile to get the data into CDNA3 A-matrix layout. 
+ mma::transpose_mma_tile(reinterpret_cast(s_frag)); +#if Debug + // Print S fragment AFTER transpose (in A-matrix layout: 64x128) + flashinfer::gpu_iface::debug_utils::hip::write_s_frag_to_lds< + typename KTraits::DTypeQKAccum, KTraits::NUM_MMA_Q, KTraits::NUM_MMA_KV, + KTraits::NUM_ACCUM_ROWS_PER_THREAD>(s_frag, qk_scratch, tid); + flashinfer::gpu_iface::debug_utils::hip::print_lds_array( + qk_scratch, KTraits::CTA_TILE_KV, KTraits::CTA_TILE_Q, + "S frag AFTER transpose (A-matrix layout, 64x128)"); +#endif +#endif + if constexpr (std::is_same_v) { #pragma unroll for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { @@ -1204,6 +1228,15 @@ __device__ __forceinline__ void compute_sfm_v( } } +#if Debug1 + // Print d values after update_mdo_states + flashinfer::gpu_iface::debug_utils::hip::write_d_to_lds( + d, qk_scratch, tid); + flashinfer::gpu_iface::debug_utils::hip::print_lds_array_1d( + qk_scratch, KTraits::CTA_TILE_Q, "--- d values after rowsum inside compute_sfm_v ---"); +#endif + #pragma unroll for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { #pragma unroll @@ -1923,8 +1956,8 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice( // b) Print the materialized LDS array to see the final result for this iteration. 
flashinfer::gpu_iface::debug_utils::hip::print_lds_array(qk_scratch, CTA_TILE_Q, CTA_TILE_KV); - #endif + logits_transform( params, variant, /*batch_idx=*/0, qo_packed_idx_base, chunk_start + (iter * NUM_WARPS_KV + get_warp_idx_kv(tid.z)) * NUM_MMA_KV * 16, @@ -1937,7 +1970,7 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice( chunk_start + (iter * NUM_WARPS_KV + get_warp_idx_kv(tid.z)) * NUM_MMA_KV * 16, qo_len, kv_len, chunk_end, group_size, s_frag, tid, kv_head_idx); } -#if Debug +#if Debug1 flashinfer::gpu_iface::debug_utils::hip::write_s_frag_to_lds< DTypeQKAccum, NUM_MMA_Q, NUM_MMA_KV, HALF_ELEMS_PER_THREAD>(s_frag, qk_scratch, CTA_TILE_KV, tid); @@ -1954,7 +1987,7 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice( // compute m,d states in online softmax update_mdo_states(variant, s_frag, o_frag, m, d, warp_idx, lane_idx); -#if Debug +#if Debug1 flashinfer::gpu_iface::debug_utils::hip::write_s_frag_to_lds< DTypeQKAccum, NUM_MMA_Q, NUM_MMA_KV, HALF_ELEMS_PER_THREAD>(s_frag, qk_scratch, CTA_TILE_KV, tid); @@ -1966,6 +1999,13 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice( // b) Print the materialized LDS array to see the final result for this iteration. 
flashinfer::gpu_iface::debug_utils::hip::print_lds_array( qk_scratch, CTA_TILE_Q, CTA_TILE_KV, ("S frag after update_mdo for iteration\n")); + + // c) Print d values after update_mdo_states + flashinfer::gpu_iface::debug_utils::hip::write_d_to_lds( + d, qk_scratch, tid); + flashinfer::gpu_iface::debug_utils::hip::print_lds_array_1d( + qk_scratch, CTA_TILE_Q, "--- d values after update_mdo_states ---"); #endif block.sync(); produce_kv( @@ -1975,7 +2015,7 @@ __device__ __forceinline__ void SinglePrefillWithKVCacheDevice( block.sync(); // compute sfm*v - compute_sfm_v(&v_smem, &v_smem_offset_r, s_frag, o_frag, d, tid); + compute_sfm_v(&v_smem, &v_smem_offset_r, s_frag, o_frag, d, tid, qk_scratch); block.sync(); produce_kv( v_smem, &v_smem_offset_w, &v_ptr, v_stride_n, (iter + 1) * CTA_TILE_KV, chunk_size, tid); diff --git a/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h b/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h index 1316e454ce..1cc9ea8863 100644 --- a/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h +++ b/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h @@ -111,6 +111,24 @@ __device__ __forceinline__ void transpose_inter_quad_fragments(uint32_t* R) { R[1] = __shfl_xor(R[1], xor_mask, 64); } +/// @brief Performs a full 16x16 in-register matrix transpose by combining intra-quad and +/// inter-quad fragment transpositions. +/// @details This function converts between A-matrix layout (row-major) and B/C/D-matrix layout +/// (column-major) for CDNA3 MFMA operations. It applies both +/// transpose_intra_quad_fragments and transpose_inter_quad_fragments to fully transpose a +/// 16x16 tile distributed across 64 threads. 
+/// +/// Use cases: +/// - B→A layout: Convert column slices to row slices (e.g., for rowsum where S must be +/// A-matrix) +/// - A→B layout: Convert row slices to column slices (if needed for other operations) +/// +/// @param R Pointer to 2 uint32_t registers containing the fragment data +__device__ __forceinline__ void transpose_mma_tile(uint32_t* R) { + transpose_intra_quad_fragments(R); + transpose_inter_quad_fragments(R); +} + // Single unified load function for all fragment types /// @param R [in] pointer to the register file to load the fragment into /// @param smem_ptr [in] pointer to the shared memory to load the fragment from @@ -191,7 +209,6 @@ __device__ __forceinline__ void load_quad_transposed_fragment(uint32_t* R, const template __device__ __forceinline__ void m16k16_rowsum_f16f16f32(float* d, DType* s_frag) { static_assert(sizeof(DType) == 2, "DType must be 16-bit type"); - transpose_intra_quad_fragments(reinterpret_cast(s_frag)); f16x4 a = reinterpret_cast(s_frag)[0]; f16x4 b = {f16(1.0f), f16(1.0f), f16(1.0f), f16(1.0f)}; f32x4 c = {d[0], d[1], d[2], d[3]}; diff --git a/libflashinfer/include/gpu_iface/mma_ops.hpp b/libflashinfer/include/gpu_iface/mma_ops.hpp index b015b116a7..78264ac6e2 100644 --- a/libflashinfer/include/gpu_iface/mma_ops.hpp +++ b/libflashinfer/include/gpu_iface/mma_ops.hpp @@ -40,6 +40,16 @@ __device__ __forceinline__ void load_quad_transposed_fragment(uint32_t* R, const "Only __half is supported for load_quad_transposed_fragment"); mma_detail::load_quad_transposed_fragment(R, smem_ptr); } + +/*! + * \brief Performs a full 16x16 in-register matrix transpose for CDNA3 MFMA tiles + * \details Converts between A-matrix layout (row-major) and B/C/D-matrix layout (column-major) + * by combining intra-quad and inter-quad fragment transpositions. 
+ * \param R Pointer to 2 uint32_t registers containing the fragment data + */ +__device__ __forceinline__ void transpose_mma_tile(uint32_t* R) { + mma_detail::transpose_mma_tile(R); +} #endif /*! diff --git a/sfrag_tester_script.py b/sfrag_tester_script.py deleted file mode 100644 index 32c239cc39..0000000000 --- a/sfrag_tester_script.py +++ /dev/null @@ -1,262 +0,0 @@ -#!/usr/bin/env python3 -import re -import sys - -import numpy as np -import pandas as pd - - -def parse_sfrag_log( - log_file_path, num_threads=64, num_warps=4, num_mma_q=2, num_mma_kv=4 -): - """ - Parse s_frag debug output from multiple thread/warp runs into a 128x128 DataFrame. - - Args: - log_file_path: Path to the concatenated log file - num_threads: Number of threads per warp (default 64) - num_warps: Number of warps (default 4) - num_mma_q: NUM_MMA_Q value (default 2) - num_mma_kv: NUM_MMA_KV value (default 4) - - Returns: - DataFrame with shape (128, 128) containing the s_frag values - """ - - # Initialize the full result matrix (128x128) - matrix = np.zeros((128, 128)) - - # Read the log file - with open(log_file_path, "r") as f: - lines = f.readlines() - - # Track current thread, warp and value position - current_thread = -1 - current_warp = -1 - values = [] - - for line in lines: - line = line.strip() - - # Check if this is a thread ID line - if line.startswith("Debug thread ID set to:"): - if current_thread >= 0 and current_warp >= 0 and values: - # Process the previous thread's data - populate_matrix_with_warp(matrix, current_thread, current_warp, values) - - # Extract thread ID - current_thread = int(line.split(":")[-1].strip()) - values = [] - - # Check if this is a warp ID line - elif line.startswith("Debug warp ID set to:"): - current_warp = int(line.split(":")[-1].strip()) - - # Otherwise, it should be a line of float values - elif line and current_thread >= 0 and current_warp >= 0: - # Parse the float values from the line - try: - line_values = [float(x) for x in line.split()] - 
values.extend(line_values) - except ValueError: - # Skip lines that can't be parsed as floats - continue - - # Don't forget to process the last thread - if current_thread >= 0 and current_warp >= 0 and values: - populate_matrix_with_warp(matrix, current_thread, current_warp, values) - - # Create DataFrame with appropriate column and row labels - df = pd.DataFrame(matrix) - df.index = [f"Row_{i}" for i in range(128)] - df.columns = [f"Col_{i}" for i in range(128)] - - return df - - -def populate_matrix_with_warp(matrix, thread_id, warp_id, values): - """ - Populate the matrix with values from a specific thread and warp. - - Args: - matrix: The 128x128 numpy array to populate - thread_id: The thread ID (0-63) - warp_id: The warp ID (0-3) - values: List of 64 float values from this thread - """ - - if len(values) != 64: - print( - f"Warning: Thread {thread_id} Warp {warp_id} has {len(values)} values instead of 64" - ) - return - - # Calculate base row and column for this thread within its warp - # Each warp handles 32 rows (warp 0: rows 0-31, warp 1: rows 32-63, etc.) - warp_row_offset = warp_id * 32 - thread_row_base = (thread_id // 16) * 4 - row_base = warp_row_offset + thread_row_base - col_base = thread_id % 16 - - # Split values into two calls (32 values each) - first_call = values[:32] - second_call = values[32:] - - # Process first call (columns 0-63) - process_call_values(matrix, first_call, row_base, col_base, col_offset=0) - - # Process second call (columns 64-127) - process_call_values(matrix, second_call, row_base, col_base, col_offset=64) - - -def process_call_values(matrix, values, row_base, col_base, col_offset): - """ - Process 32 values from one call according to the nested loop pattern. 
- - Args: - matrix: The matrix to populate - values: 32 values from one call - row_base: Base row for this thread - col_base: Base column for this thread - col_offset: Column offset (0 for first call, 64 for second call) - """ - - value_idx = 0 - current_row = row_base - current_col = col_base + col_offset - - # Outer loop: 2 iterations (NUM_MMA_Q) - for mma_q in range(2): - # Middle loop: 4 iterations (NUM_MMA_KV) - for mma_kv in range(4): - # Inner loop: 4 values - for i in range(4): - if value_idx < len(values): - # Place values in consecutive rows, same column - matrix[current_row + i, current_col] = values[value_idx] - value_idx += 1 - - # After inner loop, move to next column set - current_col += 16 - - # After middle loop, reset column and move to next row set - current_col = col_base + col_offset - current_row += 16 - - -def print_matrix_info(df): - """Print summary information about the populated matrix.""" - print(f"Matrix shape: {df.shape}") - print(f"Non-zero elements: {(df != 0).sum().sum()}") - print(f"Matrix statistics:") - print(f" Min: {df.min().min():.6f}") - print(f" Max: {df.max().max():.6f}") - print(f" Mean: {df.mean().mean():.6f}") - print(f" Std: {df.values.std():.6f}") - - # Check which warps have been populated - print("\nWarp population check:") - for warp in range(4): - start_row = warp * 32 - end_row = (warp + 1) * 32 - warp_data = df.iloc[start_row:end_row, :] - non_zero = (warp_data != 0).sum().sum() - print( - f" Warp {warp} (rows {start_row}-{end_row-1}): {non_zero} non-zero elements" - ) - - -def save_results(df, output_prefix="sfrag_matrix"): - """Save the DataFrame in multiple formats.""" - # Save as CSV - csv_file = f"{output_prefix}.csv" - df.to_csv(csv_file) - print(f"Saved matrix to {csv_file}") - - # Save as pickle for exact preservation - pickle_file = f"{output_prefix}.pkl" - df.to_pickle(pickle_file) - print(f"Saved matrix to {pickle_file}") - - # Save a heatmap visualization - try: - import matplotlib.pyplot as plt - 
import seaborn as sns - - plt.figure(figsize=(20, 20)) - sns.heatmap( - df, - cmap="RdBu_r", - center=0, - cbar_kws={"label": "Value"}, - xticklabels=16, - yticklabels=16, # Show every 16th label - ) - plt.title("S_FRAG Matrix Heatmap (Full 128x128)") - plt.xlabel("Column") - plt.ylabel("Row") - - # Add grid lines to show warp boundaries - for i in range(1, 4): - plt.axhline(y=i * 32, color="black", linewidth=2, alpha=0.5) - - plt.tight_layout() - plt.savefig(f"{output_prefix}_heatmap.png", dpi=150) - plt.close() - print(f"Saved heatmap to {output_prefix}_heatmap.png") - - # Also save individual warp heatmaps - fig, axes = plt.subplots(2, 2, figsize=(20, 20)) - for warp in range(4): - ax = axes[warp // 2, warp % 2] - start_row = warp * 32 - end_row = (warp + 1) * 32 - warp_data = df.iloc[start_row:end_row, :] - sns.heatmap( - warp_data, cmap="RdBu_r", center=0, ax=ax, cbar_kws={"label": "Value"} - ) - ax.set_title(f"Warp {warp} (Rows {start_row}-{end_row-1})") - ax.set_xlabel("Column") - ax.set_ylabel("Row (relative to warp)") - - plt.tight_layout() - plt.savefig(f"{output_prefix}_warps_heatmap.png", dpi=150) - plt.close() - print(f"Saved per-warp heatmap to {output_prefix}_warps_heatmap.png") - - except ImportError: - print("Matplotlib/Seaborn not available, skipping heatmap") - - -def main(): - if len(sys.argv) < 2: - print("Usage: python parse_sfrag.py [output_prefix]") - sys.exit(1) - - log_file = sys.argv[1] - output_prefix = sys.argv[2] if len(sys.argv) > 2 else "sfrag_matrix_full" - - print(f"Parsing log file: {log_file}") - - # Parse the log file - df = parse_sfrag_log(log_file) - - # Print summary information - print_matrix_info(df) - - # Save results - save_results(df, output_prefix) - - # Print a sample of the matrix - print("\nSample of the matrix (first 8x8 block):") - print(df.iloc[:8, :8]) - - print("\nSample from each warp (first 4x4 block):") - for warp in range(4): - start_row = warp * 32 - print(f"\nWarp {warp}:") - print(df.iloc[start_row : 
start_row + 4, :4]) - - -if __name__ == "__main__": - main() diff --git a/validate_online_softmax_stateful.py b/validate_online_softmax_stateful.py new file mode 100755 index 0000000000..be282d6a7e --- /dev/null +++ b/validate_online_softmax_stateful.py @@ -0,0 +1,289 @@ +#!/usr/bin/env python3 +""" +Stateful validation of online softmax across multiple iterations. +Maintains running m (maximum) state across KV chunks. + +Usage: + ./validate_online_softmax_stateful.py [LOG_FILE] + +Examples: + ./validate_online_softmax_stateful.py prefill.log + ./validate_online_softmax_stateful.py my_debug.log + ./validate_online_softmax_stateful.py # defaults to prefill.log +""" + +import argparse +import re +import sys +from pathlib import Path + +import numpy as np + +# ============================================================================ +# PARSING FUNCTIONS +# ============================================================================ + + +def parse_sm_scale(lines): + """Extract sm_scale from log file.""" + for line in lines: + if "sm_scale" in line: + match = re.search(r"sm_scale\s*:\s*([\d.]+)", line) + if match: + return float(match.group(1)) + raise ValueError("Could not find sm_scale in log file") + + +def parse_matrix(lines, start_line): + """Parse matrix starting from start_line. Returns 128×64 matrix.""" + data = [] + for i in range(start_line, len(lines)): + line = lines[i] + if "frag" in line or "DEBUG" in line or line.strip().startswith("num_"): + break + if line.strip(): + nums = re.findall(r"-?\d+\.\d+", line) + if nums: + data.extend([float(x) for x in nums]) + + if len(data) == 0: + return None + + expected_size = 128 * 64 + if len(data) != expected_size: + print(f"Warning: Expected {expected_size} values, got {len(data)}") + return None + + return np.array(data).reshape(128, 64) + + +def find_iteration_data(lines, iteration_num): + """ + Find before and after matrices for a given iteration. + Returns (before_matrix, after_matrix). 
+ """ + # Find before data + iter_count = 0 + before_line = None + for i, line in enumerate(lines): + if "S frag before update_mdo for iteration" in line: + if iter_count == iteration_num: + before_line = i + 2 + iter_count += 1 + + # Find after data + iter_count = 0 + after_line = None + for i, line in enumerate(lines): + if "S frag after update_mdo for iteration" in line: + if iter_count == iteration_num: + after_line = i + 2 + break + iter_count += 1 + + if before_line is None or after_line is None: + return None, None + + before = parse_matrix(lines, before_line) + after = parse_matrix(lines, after_line) + + return before, after + + +def validate_with_state(before_row, after_row, m_prev, sm_scale): + """ + Validate online softmax transformation with stateful m. + + Args: + before_row: Raw scores for this chunk (64 values) + after_row: Transformed scores (64 values) + m_prev: Maximum from previous chunks (scalar) + sm_scale: Softmax scale factor (scalar) + + Returns: + (is_valid, m_new, max_error) + """ + # Step 1: Find maximum in current chunk + m_chunk = before_row.max() + + # Step 2: Update running maximum + m_new = max(m_prev, m_chunk) + + # Step 3: Apply transformation to each element + expected = np.exp2((before_row - m_new) * sm_scale) + + # Step 4: Compare with actual + diff = np.abs(after_row - expected) + max_error = diff.max() + + # Tolerance for floating point comparison + tolerance = 5e-3 # 0.005 + is_valid = max_error < tolerance + + return is_valid, m_new, max_error + + +# ============================================================================ +# VALIDATION ORCHESTRATION +# ============================================================================ + + +def validate_all_iterations(lines, sm_scale, num_iterations=2): + """ + Validate all iterations with proper state management. 
+ + Returns: + (total_passed, total_rows, max_error_overall) + """ + print(f"{'='*80}") + print(f"STATEFUL ONLINE SOFTMAX VALIDATION") + print(f"{'='*80}") + print(f"sm_scale: {sm_scale}") + print(f"Formula: s_after = exp2((s_before - m_new) * sm_scale)") + print(f" where m_new = max(m_prev, max(s_before_chunk))") + print(f"{'='*80}\n") + + total_passed = 0 + total_rows = 0 + max_error_overall = 0.0 + + # Initialize m_prev to -inf for first iteration + m_prev_per_row = np.full(128, -np.inf) + + for iteration in range(num_iterations): + print(f"\n{'─'*80}") + print(f"ITERATION {iteration}") + print(f"{'─'*80}") + + before, after = find_iteration_data(lines, iteration) + + if before is None or after is None: + print(f"❌ Could not find data for iteration {iteration}") + continue + + print(f"Matrix shape: {before.shape}") + + passed = 0 + failed = 0 + max_error_iter = 0.0 + + # Process each row with its own m_prev + for row_idx in range(128): + before_row = before[row_idx, :] + after_row = after[row_idx, :] + m_prev = m_prev_per_row[row_idx] + + is_valid, m_new, max_error = validate_with_state( + before_row, after_row, m_prev, sm_scale + ) + + # Update running m for this row + m_prev_per_row[row_idx] = m_new + + max_error_iter = max(max_error_iter, max_error) + max_error_overall = max(max_error_overall, max_error) + + if is_valid: + passed += 1 + else: + failed += 1 + if failed <= 3: # Show first 3 failures + print( + f" ❌ Row {row_idx}: m_prev={m_prev:.6f}, " + f"m_chunk={before_row.max():.6f}, " + f"m_new={m_new:.6f}, max_error={max_error:.6e}" + ) + + total_passed += passed + total_rows += 128 + + print(f"\nIteration {iteration} Results:") + print(f" ✓ Passed: {passed}/128 rows") + print(f" ✗ Failed: {failed}/128 rows") + print(f" 📊 Max error: {max_error_iter:.6e}") + + if failed == 0: + print(f" 🎉 ITERATION {iteration} VALIDATED SUCCESSFULLY!") + + # Debug: Show sample row state + sample_row = 0 + print(f"\n Sample (Row {sample_row}):") + print(f" m_prev: 
{m_prev_per_row[sample_row]:.6f}") + print(f" m_chunk: {before[sample_row, :].max():.6f}") + print(f" m_new: {m_prev_per_row[sample_row]:.6f}") + + print(f"\n{'='*80}") + print(f"OVERALL RESULTS") + print(f"{'='*80}") + print( + f"Total rows validated: {total_passed}/{total_rows} ({100*total_passed/total_rows:.1f}%)" + ) + print(f"Max error across all iterations: {max_error_overall:.6e}") + + if total_passed == total_rows: + print(f"\n🎉 ALL ROWS VALIDATED SUCCESSFULLY!") + print(f"✅ Online softmax is correctly implemented with stateful m propagation") + return True + else: + print(f"\n⚠️ VALIDATION INCOMPLETE: {total_rows - total_passed} rows failed") + return False + + +# ============================================================================ +# MAIN ENTRY POINT +# ============================================================================ + + +def main(): + parser = argparse.ArgumentParser( + description="Validate online softmax with stateful m propagation", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + %(prog)s prefill.log Validate using prefill.log + %(prog)s my_debug.log Validate using my_debug.log + %(prog)s Validate using prefill.log (default) + """, + ) + parser.add_argument( + "logfile", + nargs="?", + default="prefill.log", + help="Path to log file (default: prefill.log)", + ) + parser.add_argument( + "-n", + "--num-iterations", + type=int, + default=2, + help="Number of iterations to validate (default: 2)", + ) + + args = parser.parse_args() + + logfile = Path(args.logfile) + + if not logfile.exists(): + print(f"❌ Error: Log file '{logfile}' not found") + sys.exit(1) + + print(f"Reading log file: {logfile}") + + with open(logfile, "r") as f: + lines = f.readlines() + + print(f"Loaded {len(lines)} lines from {logfile}\n") + + try: + sm_scale = parse_sm_scale(lines) + except ValueError as e: + print(f"❌ Error: {e}") + sys.exit(1) + + success = validate_all_iterations(lines, sm_scale, args.num_iterations) + + 
sys.exit(0 if success else 1) + + +if __name__ == "__main__": + main() From 6be1fb72631549c0987526a6e5e639334ff09947 Mon Sep 17 00:00:00 2001 From: Diptorup Deb Date: Tue, 21 Oct 2025 22:10:06 -0500 Subject: [PATCH 109/109] Remove old test cases --- libflashinfer/tests/hip/test_compute_sfm.cpp | 189 -------------- .../hip/test_inplace_transpose_loads.cpp | 211 --------------- .../tests/hip/test_load_q_global_smem_v1.cpp | 170 ------------ .../tests/hip/test_load_q_global_smem_v2.cpp | 247 ------------------ libflashinfer/tests/hip/test_produce_kv.cpp | 204 --------------- 5 files changed, 1021 deletions(-) delete mode 100644 libflashinfer/tests/hip/test_compute_sfm.cpp delete mode 100644 libflashinfer/tests/hip/test_inplace_transpose_loads.cpp delete mode 100644 libflashinfer/tests/hip/test_load_q_global_smem_v1.cpp delete mode 100644 libflashinfer/tests/hip/test_load_q_global_smem_v2.cpp delete mode 100644 libflashinfer/tests/hip/test_produce_kv.cpp diff --git a/libflashinfer/tests/hip/test_compute_sfm.cpp b/libflashinfer/tests/hip/test_compute_sfm.cpp deleted file mode 100644 index 09ac6fddec..0000000000 --- a/libflashinfer/tests/hip/test_compute_sfm.cpp +++ /dev/null @@ -1,189 +0,0 @@ -// SPDX - FileCopyrightText : 2023 - 2025 Flashinfer team -// SPDX - FileCopyrightText : 2025 Advanced Micro Devices, Inc. 
-// -// SPDX - License - Identifier : Apache 2.0 - -#include - -#include - -#include "../../utils/flashinfer_prefill_ops.hip.h" -#include "../../utils/utils_hip.h" -#include "flashinfer/attention/generic/prefill.cuh" -#include "gpu_iface/gpu_runtime_compat.hpp" - -#define HIP_ENABLE_WARP_SYNC_BUILTINS 1 - -using namespace flashinfer; - -namespace { -template -std::vector test_compute_qk_and_softmax_cpu( - const std::vector& q, const std::vector& k, const std::vector& v, - size_t qo_len, size_t kv_len, size_t num_qo_heads, size_t num_kv_heads, size_t head_dim, - bool causal = true, QKVLayout kv_layout = QKVLayout::kHND, - PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, float rope_scale = 1.f, - float rope_theta = 1e4) { - assert(qo_len <= kv_len); - assert(num_qo_heads % num_kv_heads == 0); - float sm_scale = 1.f / std::sqrt(float(head_dim)); - std::vector o(qo_len * num_qo_heads * head_dim); - std::vector att(kv_len); - std::vector q_rotary_local(head_dim); - std::vector k_rotary_local(head_dim); - DISPATCH_head_dim(head_dim, HEAD_DIM, { - tensor_info_t info(qo_len, kv_len, num_qo_heads, num_kv_heads, kv_layout, HEAD_DIM); - - for (size_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) { - const size_t kv_head_idx = qo_head_idx / info.get_group_size(); - for (size_t q_idx = 0; q_idx < qo_len; ++q_idx) { - float max_val = -5e4; - - for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { - att[kv_idx] = 0.; - switch (pos_encoding_mode) { - case PosEncodingMode::kNone: { - for (size_t feat_idx = 0; feat_idx < head_dim; ++feat_idx) { - att[kv_idx] += fi::con::explicit_casting( - q[info.get_q_elem_offset(q_idx, qo_head_idx, feat_idx)]) * - fi::con::explicit_casting( - k[info.get_kv_elem_offset(kv_idx, kv_head_idx, feat_idx)]) * - sm_scale; - } - break; - } - default: { - std::ostringstream err_msg; - err_msg << "Unsupported rotary mode."; - FLASHINFER_ERROR(err_msg.str()); - } - } - max_val = std::max(max_val, att[kv_idx]); - } - // exp minus max - 
float denom = 0; - for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { - att[kv_idx] = std::exp(att[kv_idx] - max_val); - denom += att[kv_idx]; - } - - // divide by denom - for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { - att[kv_idx] /= denom; - } - } - } - }); - return std::move(att); -} -} // namespace - -template -void _TestComputeSFMCorrectness(size_t qo_len, size_t kv_len, size_t num_qo_heads, - size_t num_kv_heads, size_t head_dim, bool causal, - QKVLayout kv_layout, PosEncodingMode pos_encoding_mode, - bool use_fp16_qk_reduction, float rtol = 1e-3, float atol = 1e-3) { - std::vector q(qo_len * num_qo_heads * head_dim); - std::vector k(kv_len * num_kv_heads * head_dim); - std::vector v(kv_len * num_kv_heads * head_dim); - std::vector o(qo_len * num_qo_heads * head_dim); - - utils::generate_data(q); - utils::generate_data(k); - utils::generate_data(v); - utils::generate_data(o); - - DTypeQ* q_d; - FI_GPU_CALL(hipMalloc(&q_d, q.size() * sizeof(DTypeQ))); - FI_GPU_CALL(hipMemcpy(q_d, q.data(), q.size() * sizeof(DTypeQ), hipMemcpyHostToDevice)); - - DTypeKV* k_d; - FI_GPU_CALL(hipMalloc(&k_d, k.size() * sizeof(DTypeKV))); - FI_GPU_CALL(hipMemcpy(k_d, k.data(), k.size() * sizeof(DTypeKV), hipMemcpyHostToDevice)); - - DTypeKV* v_d; - FI_GPU_CALL(hipMalloc(&v_d, v.size() * sizeof(DTypeKV))); - FI_GPU_CALL(hipMemcpy(v_d, v.data(), v.size() * sizeof(DTypeKV), hipMemcpyHostToDevice)); - - DTypeO* o_d; - FI_GPU_CALL(hipMalloc(&o_d, o.size() * sizeof(DTypeO))); - FI_GPU_CALL(hipMemcpy(o_d, o.data(), o.size() * sizeof(DTypeO), hipMemcpyHostToDevice)); - - DTypeO* tmp_d; - FI_GPU_CALL(hipMalloc(&tmp_d, 16 * 1024 * 1024 * sizeof(DTypeO))); - - hipError_t status = flashinfer::SinglePrefillWithKVCache( - q_d, k_d, v_d, o_d, tmp_d, - /*lse=*/nullptr, num_qo_heads, num_kv_heads, qo_len, kv_len, head_dim, causal, kv_layout, - pos_encoding_mode, use_fp16_qk_reduction); - - EXPECT_EQ(status, hipSuccess) << "SinglePrefillWithKVCache kernel launch failed, error message: " - 
<< hipGetErrorString(status); - - std::vector o_h(o.size()); - FI_GPU_CALL(hipMemcpy(o_h.data(), o_d, o_h.size() * sizeof(DTypeO), hipMemcpyDeviceToHost)); - - // Print the first 10 elements of the output vector for debugging - // std::cout << "Output vector (first 10 elements):"; - // std::cout << "[" << std::endl; - // for (int i = 0; i < 10; ++i) { - // std::cout << fi::con::explicit_casting(o_h[i]) << " "; - // } - // std::cout << "]" << std::endl; - - bool isEmpty = o_h.empty(); - EXPECT_EQ(isEmpty, false) << "Output vector is empty"; - - std::vector o_ref = test_compute_qk_and_softmax_cpu( - q, k, v, qo_len, kv_len, num_qo_heads, num_kv_heads, head_dim, causal, kv_layout, - pos_encoding_mode); - size_t num_results_error_atol = 0; - bool nan_detected = false; - - for (size_t i = 0; i < o_ref.size(); ++i) { - float o_h_val = fi::con::explicit_casting(o_h[i]); - float o_ref_val = fi::con::explicit_casting(o_ref[i]); - - if (isnan(o_h_val)) { - nan_detected = true; - } - - num_results_error_atol += (!utils::isclose(o_ref_val, o_h_val, rtol, atol)); - if (!utils::isclose(o_ref_val, o_h_val, rtol, atol)) { - std::cout << "i=" << i << ", o_ref[i]=" << o_ref_val << ", o_h[i]=" << o_h_val << std::endl; - } - } - - float result_accuracy = 1. 
- float(num_results_error_atol) / float(o_ref.size()); - std::cout << "num_qo_heads=" << num_qo_heads << ", num_kv_heads=" << num_kv_heads - << ", qo_len=" << qo_len << ", kv_len=" << kv_len << ", head_dim=" << head_dim - << ", causal=" << causal << ", kv_layout=" << QKVLayoutToString(kv_layout) - << ", pos_encoding_mode=" << PosEncodingModeToString(pos_encoding_mode) - << ", result_accuracy=" << result_accuracy << std::endl; - - EXPECT_GT(result_accuracy, 0.90) << "Result correctness test failed."; - EXPECT_FALSE(nan_detected) << "Nan detected in the result."; - - FI_GPU_CALL(hipFree(q_d)); - FI_GPU_CALL(hipFree(k_d)); - FI_GPU_CALL(hipFree(v_d)); - FI_GPU_CALL(hipFree(o_d)); - FI_GPU_CALL(hipFree(tmp_d)); -} - -int main(int argc, char** argv) { - using DTypeIn = __half; - using DTypeO = __half; - bool use_fp16_qk_reduction = false; - size_t qo_len = 399; - size_t kv_len = 533; - size_t num_heads = 1; - size_t head_dim = 64; - bool causal = false; - size_t pos_encoding_mode = 0; - size_t kv_layout = 0; - - _TestComputeSFMCorrectness( - qo_len, kv_len, num_heads, num_heads, head_dim, causal, QKVLayout(kv_layout), - PosEncodingMode(pos_encoding_mode), use_fp16_qk_reduction); -} diff --git a/libflashinfer/tests/hip/test_inplace_transpose_loads.cpp b/libflashinfer/tests/hip/test_inplace_transpose_loads.cpp deleted file mode 100644 index 98ee9aedf0..0000000000 --- a/libflashinfer/tests/hip/test_inplace_transpose_loads.cpp +++ /dev/null @@ -1,211 +0,0 @@ -/// 1. Allocate a 128x64 memory on CPU and init lexicographically to represent a -/// 128X64 matrix. -/// 2. Copy CPU array to global memory. -// 3 Copy global memory into LDS using produce_kv function. The LDS should -// also of be 128x64 elements -/// 4. Call transpose kernel that inplace transposes the LDS 128x64 matrix into -// a 64x128 matrix. Each warp handles multiple blocks of 16x16 chunks -/// 5. Post transposition copy back the 128x64 LDS linear memory to global and -/// then back to CPU. -/// 6. 
Evaluate the output is same as the transpose of the original array. - -#include -#include - -#include -#include - -#include "flashinfer/attention/generic/permuted_smem.cuh" -#include "flashinfer/attention/generic/prefill.cuh" -#include "gpu_iface/backend/hip/mma_hip.h" -#include "gpu_iface/gpu_runtime_compat.hpp" - -using namespace flashinfer; - -namespace { - -// Define matrix dimensions for the test -constexpr int MATRIX_ROWS = 128; -constexpr int MATRIX_COLS = 64; -constexpr uint32_t KV_THR_LAYOUT_ROW = 4; -constexpr uint32_t KV_THR_LAYOUT_COL = 16; -constexpr uint32_t NUM_WARPS = 4; -constexpr uint32_t NUM_MMA_KV = MATRIX_ROWS / 16; -constexpr uint32_t NUM_WARPS_Q = MATRIX_COLS / 16; -constexpr uint32_t NUM_MMA_D = 4; -constexpr uint32_t UPCAST_STRIDE = 64; -constexpr uint32_t VECTOR_BIT_WIDTH = 64; -constexpr uint32_t CTA_TILE_KV = NUM_MMA_KV * 4 * 16; - -using DTypeKV = __half; - -template -__device__ __forceinline__ void load_matrix_global_to_smem(uint32_t warp_idx, uint32_t lane_idx, - smem_t smem, - uint32_t* smem_offset, DTypeKV** gptr, - const uint32_t stride_n, - const uint32_t kv_idx_base, - const uint32_t kv_len) { - static_assert(NUM_MMA_KV * 4 % NUM_WARPS_Q == 0); - - uint32_t kv_idx = kv_idx_base + warp_idx * 4 + lane_idx / KV_THR_LAYOUT_ROW; - -#pragma unroll - for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) { -#pragma unroll - for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(DTypeKV)); ++j) { - smem.template load_vector_async(*smem_offset, *gptr, kv_idx < kv_len); - *smem_offset = smem.template advance_offset_by_column<16>(*smem_offset, j); - *gptr += 16 * upcast_size(); - } - kv_idx += NUM_WARPS * 4; - *smem_offset = smem.template advance_offset_by_row(*smem_offset) - - (sizeof(DTypeKV) * NUM_MMA_D * 2); - *gptr += NUM_WARPS * 4 * stride_n - - sizeof(DTypeKV) * NUM_MMA_D * 2 * upcast_size(); - } - *smem_offset -= CTA_TILE_KV * UPCAST_STRIDE; -} - -} // namespace - -// Helper to initialize matrix with lexicographic values -void 
initMatrixLexicographic(half* matrix, int rows, int cols) { - for (int i = 0; i < rows; ++i) { - for (int j = 0; j < cols; ++j) { - matrix[i * cols + j] = static_cast(i * cols + j); - } - } -} - -// Helper to transpose a matrix on CPU (for verification) -void transposeMatrixCPU(half* input, half* output, int rows, int cols) { - for (int i = 0; i < rows; ++i) { - for (int j = 0; j < cols; ++j) { - output[j * rows + i] = input[i * cols + j]; - } - } -} - -// Helper to print a matrix section (for debugging) -void printMatrixSection(half* matrix, int rows, int cols, const char* name) { - std::cout << "Matrix " << name << " (" << rows << "x" << cols << "):" << std::endl; - for (int i = 0; i < std::min(rows, 8); ++i) { - for (int j = 0; j < std::min(cols, 8); ++j) { - std::cout << static_cast(matrix[i * cols + j]) << " "; - } - std::cout << (cols > 8 ? "..." : "") << std::endl; - } - if (rows > 8) std::cout << "..." << std::endl; -} - -// Kernel to load the matrix from global to shared memory using produce_kv -__device__ __forceinline__ void loadGlobalToSharedKernel(__half* input, - smem_t v_smem, - int rows, int cols) { - const uint32_t tid = threadIdx.x; - const uint32_t lane_idx = tid % 64; - const uint32_t warp_idx = tid / 64; - - uint32_t smem_offset = - v_smem.template get_permuted_offset<64>(warp_idx * 4 + lane_idx / 16, lane_idx % 16); - - DTypeKV* input_ptr = input + (warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL) * 64 + - +(lane_idx % KV_THR_LAYOUT_COL) * upcast_size(); - - // Load global memory to shared memory collaboratively - load_matrix_global_to_smem(warp_idx, lane_idx, v_smem, &smem_offset, - &input_ptr, cols, 0, rows); - - __syncthreads(); - - if (tid == 0) { - printf("\n DEBUG LDS loaded from global\n"); - auto hMem = reinterpret_cast<__half*>(v_smem.base); - uint32_t offset_r_debug; - // for (auto i = 0; i < rows; ++i) { - for (auto j = 0; j < 256; ++j) { - printf("%f ", float(hMem[j])); - } - printf("\n"); - //} - } - - // TODO: Store 
shared memory back to global memory for verification -} - -// Kernel to transpose shared memory in-place -__global__ void transposeSharedMemoryKernel(half* input, half* output, int rows, int cols) { - // Define shared memory for the matrix - extern __shared__ half shared_mem[]; - smem_t v_smem(shared_mem); - - // TODO: Load data from global to shared memory - loadGlobalToSharedKernel(input, v_smem, rows, cols); - - __syncthreads(); - - // TODO: Call transpose_4x4_half_registers to transpose in-place - - __syncthreads(); - - // TODO: Copy transposed data back to global memory -} - -TEST(InplaceTransposeTest, TestTransposeLDS) { - // 1. Allocate a 128x64 memory on CPU and init lexicographically - std::vector h_input(MATRIX_ROWS * MATRIX_COLS); - std::vector h_output(MATRIX_COLS * MATRIX_ROWS); - std::vector h_expected(MATRIX_COLS * MATRIX_ROWS); - - initMatrixLexicographic(h_input.data(), MATRIX_ROWS, MATRIX_COLS); - - for (auto i = 0; i < 32; ++i) { - std::cout << float(h_input[i]) << " "; - } - std::cout << std::endl; - - transposeMatrixCPU(h_input.data(), h_expected.data(), MATRIX_ROWS, MATRIX_COLS); - - // 2. Copy CPU array to global memory - half *d_input, *d_output; - FI_GPU_CALL(hipMalloc(&d_input, h_input.size() * sizeof(half))); - FI_GPU_CALL(hipMalloc(&d_output, h_output.size() * sizeof(half))); - FI_GPU_CALL( - hipMemcpy(d_input, h_input.data(), h_input.size() * sizeof(half), hipMemcpyHostToDevice)); - - // 3 & 4. Load into shared memory and transpose in-place - const int blockSize = 256; - const int gridSize = 1; - size_t sharedMemSize = MATRIX_ROWS * MATRIX_COLS * sizeof(half); - - // Single wave of four wavefronts - transposeSharedMemoryKernel<<>>(d_input, d_output, - MATRIX_ROWS, MATRIX_COLS); - - // 5. Copy back to CPU - FI_GPU_CALL( - hipMemcpy(h_output.data(), d_output, h_output.size() * sizeof(half), hipMemcpyDeviceToHost)); - - // 6. 
Verify the output matches the transpose of the original array - bool all_match = true; - for (int i = 0; i < MATRIX_COLS * MATRIX_ROWS; ++i) { - if (static_cast(h_output[i]) != static_cast(h_expected[i])) { - std::cout << "Mismatch at index " << i << ": " << static_cast(h_output[i]) << " vs " - << static_cast(h_expected[i]) << std::endl; - all_match = false; - if (i > 10) break; // Limit output - } - } - - EXPECT_TRUE(all_match) << "Transposed matrix doesn't match expected result"; - - // Clean up - FI_GPU_CALL(hipFree(d_input)); - FI_GPU_CALL(hipFree(d_output)); -} - -int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/libflashinfer/tests/hip/test_load_q_global_smem_v1.cpp b/libflashinfer/tests/hip/test_load_q_global_smem_v1.cpp deleted file mode 100644 index ffad3a5a47..0000000000 --- a/libflashinfer/tests/hip/test_load_q_global_smem_v1.cpp +++ /dev/null @@ -1,170 +0,0 @@ -#include - -#include -#include -#include -#include - -// Constants for MI300 -constexpr uint32_t WARP_STEP_SIZE = 16; // 16 threads per warp row -constexpr uint32_t QUERY_ELEMS_PER_THREAD = 4; // Each thread loads 4 fp16 elements -constexpr uint32_t WARP_THREAD_ROWS = 4; // 4 rows of threads in a warp - -// Simplified linear shared memory operations (CPU implementation) -template -uint32_t get_permuted_offset_linear(uint32_t row, uint32_t col) { - return row * stride + col; -} - -template -uint32_t advance_offset_by_column_linear(uint32_t offset, uint32_t step_idx) { - return offset + step_size; -} - -template -uint32_t advance_offset_by_row_linear(uint32_t offset) { - return offset + step_size * row_stride; -} - -// CPU-based offset pattern verification with configurable NUM_MMA_Q -template -void SimulateOffsetPattern(std::vector& thread_ids_at_offsets) { - // Constants derived from HEAD_DIM - constexpr uint32_t UPCAST_STRIDE_Q = HEAD_DIM / QUERY_ELEMS_PER_THREAD; - constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM / 16; - constexpr 
uint32_t COLUMN_RESET_OFFSET = (NUM_MMA_D_QK / 4) * WARP_STEP_SIZE; - constexpr uint32_t grid_width = (HEAD_DIM / QUERY_ELEMS_PER_THREAD); // 16 for 64, 32 for 128 - constexpr uint32_t grid_height = 16 * NUM_MMA_Q; // 16 for NUM_MMA_Q=1, 32 for NUM_MMA_Q=2 - - // Initialize with -1 (unwritten) - thread_ids_at_offsets.assign(grid_height * grid_width, -1); - - // Simulate each thread - for (uint32_t tid = 0; tid < 64; tid++) { - uint32_t row = tid / WARP_STEP_SIZE; // 0-3 for 64 threads - uint32_t col = tid % WARP_STEP_SIZE; // 0-15 - - // Calculate initial offset using linear addressing - uint32_t q_smem_offset_w = get_permuted_offset_linear(row, col); - - // Main loop structure from load_q_global_smem - for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { - for (uint32_t j = 0; j < 4; ++j) { - // Calculate sequence index - const uint32_t seq_idx = row + mma_q * 16 + j; - - for (uint32_t mma_do = 0; mma_do < NUM_MMA_D_QK / 4; ++mma_do) { - // Record which thread wrote to this offset - if (q_smem_offset_w < grid_height * grid_width) { // Safety check - thread_ids_at_offsets[q_smem_offset_w] = tid; - } else { - printf("ERROR by tid: %d, offset: %d\n", tid, q_smem_offset_w); - } - - // Advance to next column within same row - q_smem_offset_w = - advance_offset_by_column_linear(q_smem_offset_w, mma_do); - } - - // Advance to next sequence (row) with adjustment back to first - // column - q_smem_offset_w = - advance_offset_by_row_linear(q_smem_offset_w) - - COLUMN_RESET_OFFSET; - } - } - } -} - -// Helper function to run the test with configurable NUM_MMA_Q -template -void RunOffsetTest() { - constexpr uint32_t grid_width = (HEAD_DIM / QUERY_ELEMS_PER_THREAD); // 16 for 64, 32 for 128 - constexpr uint32_t grid_height = 16 * NUM_MMA_Q; // 16 for NUM_MMA_Q=1, 32 for NUM_MMA_Q=2 - - printf( - "\n=== Testing offset calculations with HEAD_DIM = %u, NUM_MMA_Q = " - "%u ===\n", - HEAD_DIM, NUM_MMA_Q); - - // Host array to store thread IDs at each offset - std::vector 
thread_ids(grid_height * grid_width, -1); - - // Run CPU simulation of offset pattern - SimulateOffsetPattern(thread_ids); - - // Print the grid of thread IDs (potentially truncated for readability) - printf("Thread IDs writing to each offset (%dx%d grid):\n", grid_height, grid_width); - - // Column headers - printf(" "); - for (int c = 0; c < grid_width; c++) { - printf("%3d ", c); - if (c == 15 && grid_width > 16) printf("| "); // Divider between first and second half - } - printf("\n +"); - for (int c = 0; c < grid_width; c++) { - printf("----"); - if (c == 15 && grid_width > 16) printf("+"); // Divider between first and second half - } - printf("\n"); - - // Print quadrants with clear separation - for (int r = 0; r < grid_height; r++) { - printf("%2d | ", r); - for (int c = 0; c < grid_width; c++) { - int thread_id = thread_ids[r * grid_width + c]; - if (thread_id >= 0) { - printf("%3d ", thread_id); - } else { - printf(" . "); // Dot for unwritten positions - } - if (c == 15 && grid_width > 16) printf("| "); // Divider between first and second half - } - printf("\n"); - - // Add horizontal divider between first and second block of sequences - if (r == 15 && NUM_MMA_Q > 1) { - printf(" +"); - for (int c = 0; c < grid_width; c++) { - printf("----"); - if (c == 15 && grid_width > 16) printf("+"); // Intersection divider - } - printf("\n"); - } - } - - // Check for unwritten positions - int unwritten = 0; - for (int i = 0; i < grid_height * grid_width; i++) { - if (thread_ids[i] == -1) { - unwritten++; - } - } - - // Print statistics - printf("\nStatistics:\n"); - printf("- Positions written: %d/%d (%.1f%%)\n", grid_height * grid_width - unwritten, - grid_height * grid_width, - 100.0f * (grid_height * grid_width - unwritten) / (grid_height * grid_width)); - printf("- Unwritten positions: %d/%d (%.1f%%)\n", unwritten, grid_height * grid_width, - 100.0f * unwritten / (grid_height * grid_width)); - - // Validate full coverage - EXPECT_EQ(unwritten, 0) << "Not all 
positions were written"; -} - -// Original tests with NUM_MMA_Q = 1 -TEST(MI300OffsetTest, HeadDim64_NumMmaQ1) { RunOffsetTest<64, 1>(); } - -TEST(MI300OffsetTest, HeadDim128_NumMmaQ1) { RunOffsetTest<128, 1>(); } - -// New tests with NUM_MMA_Q = 2 -TEST(MI300OffsetTest, HeadDim64_NumMmaQ2) { RunOffsetTest<64, 2>(); } - -TEST(MI300OffsetTest, HeadDim128_NumMmaQ2) { RunOffsetTest<128, 2>(); } - -int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} diff --git a/libflashinfer/tests/hip/test_load_q_global_smem_v2.cpp b/libflashinfer/tests/hip/test_load_q_global_smem_v2.cpp deleted file mode 100644 index dc05a86b4c..0000000000 --- a/libflashinfer/tests/hip/test_load_q_global_smem_v2.cpp +++ /dev/null @@ -1,247 +0,0 @@ -// test_load_q_global_smem.cpp -#include -#include - -#include -#include -#include -#include - -// Include necessary headers -#include "flashinfer/attention/generic/default_prefill_params.cuh" -#include "flashinfer/attention/generic/prefill.cuh" -#include "flashinfer/attention/generic/variants.cuh" -#include "utils/cpu_reference_hip.h" -#include "utils/utils_hip.h" - -using namespace flashinfer; - -// CPU Reference Implementation for Q Loading -template -std::vector cpu_reference_q_smem_layout(const std::vector& q_global, size_t qo_len, - size_t num_qo_heads, size_t head_dim, - size_t q_stride_n, size_t q_stride_h, - size_t qo_packed_idx_base, uint32_t group_size, - size_t smem_height, size_t smem_width) { - std::vector q_smem_expected(smem_height * smem_width, DTypeQ(0)); - - // Simulate the loading pattern that load_q_global_smem should follow - for (size_t smem_row = 0; smem_row < smem_height; ++smem_row) { - uint32_t q_packed_idx = qo_packed_idx_base + smem_row; - uint32_t q_idx = q_packed_idx / group_size; // Sequence position - uint32_t r = q_packed_idx % group_size; // Head offset within group - - if (q_idx < qo_len) { - for (size_t feat_idx = 0; feat_idx < head_dim; ++feat_idx) { - // Calculate 
global memory offset - size_t global_offset = q_idx * q_stride_n + r * q_stride_h + feat_idx; - - // Place in shared memory layout (assuming linear layout for - // test) - size_t smem_offset = smem_row * smem_width + feat_idx; - if (global_offset < q_global.size()) { - q_smem_expected[smem_offset] = q_global[global_offset]; - } - } - } - } - - return q_smem_expected; -} - -uint_fastdiv create_group_size_div(uint32_t group_size) { return uint_fastdiv(group_size); } - -// Test kernel for Q loading -template -__global__ void test_q_loading_kernel(typename KTraits::DTypeQ* q_global, - typename KTraits::DTypeQ* q_smem_output, - uint32_t qo_packed_idx_base, uint32_t qo_len, - uint32_t q_stride_n, uint32_t q_stride_h, - uint_fastdiv group_size_div) { - // Set up shared memory - extern __shared__ uint8_t smem[]; - typename KTraits::SharedStorage& smem_storage = - reinterpret_cast(smem); - - smem_t q_smem(smem_storage.q_smem); - - // Call the function we're testing - load_q_global_smem(qo_packed_idx_base, qo_len, q_global, q_stride_n, q_stride_h, - group_size_div, &q_smem, threadIdx); - - // Synchronize to ensure loading is complete - __syncthreads(); - - if (threadIdx.y == 0 && threadIdx.z == 0) { - const uint32_t lane_idx = threadIdx.x; - constexpr uint32_t smem_height = KTraits::CTA_TILE_Q; // 16 - constexpr uint32_t smem_width = KTraits::HEAD_DIM_QK; // 64 - constexpr uint32_t total_elements = smem_height * smem_width; - - // Each thread copies using proper swizzled access - for (uint32_t linear_idx = lane_idx; linear_idx < total_elements; - linear_idx += KTraits::NUM_THREADS) { - if (linear_idx < total_elements) { - uint32_t row = linear_idx / smem_width; - uint32_t col = linear_idx % smem_width; - uint32_t swizzled_offset = q_smem.template get_permuted_offset< - smem_width / upcast_size()>( - row, col / upcast_size()); - uint32_t element_idx = - col % upcast_size(); - typename KTraits::DTypeQ* smem_ptr = - reinterpret_cast(q_smem.base + swizzled_offset); - 
q_smem_output[linear_idx] = smem_ptr[element_idx]; - } - } - } -} - -// Main test function -template -bool test_q_loading_correctness() { - std::cout << "Testing Q loading correctness with " << sizeof(DTypeQ) * 8 << "-bit precision..." - << std::endl; - - // Test parameters - small sizes for initial validation - constexpr size_t qo_len = 8; - constexpr size_t num_qo_heads = 8; - constexpr size_t num_kv_heads = 2; - constexpr size_t head_dim = 64; - constexpr uint32_t group_size = num_qo_heads / num_kv_heads; - - // Create test data with known pattern for easier debugging - const size_t q_size = qo_len * num_qo_heads * head_dim; - std::vector q_host(q_size); - - // Fill with simple pattern: row*1000 + col for easier validation - for (size_t i = 0; i < q_size; ++i) { - float val = float(i % 100) / 10.0f; // Values 0.0, 0.1, 0.2, ... 9.9 - q_host[i] = fi::con::explicit_casting(val); - } - - // GPU memory allocation - DTypeQ *q_device, *q_smem_output; - const size_t smem_elements = 16 * head_dim; // Single MMA block - FI_GPU_CALL(hipMalloc(&q_device, q_size * sizeof(DTypeQ))); - FI_GPU_CALL(hipMalloc(&q_smem_output, smem_elements * sizeof(DTypeQ))); - - FI_GPU_CALL(hipMemcpy(q_device, q_host.data(), q_size * sizeof(DTypeQ), hipMemcpyHostToDevice)); - - // Define kernel traits for CDNA3 - using KTraits = - KernelTraits>; - - // Launch parameters - dim3 block_size(64, 1, 1); // CDNA3: 64 threads per wavefront - dim3 grid_size(1, 1, 1); - size_t shared_mem_size = sizeof(typename KTraits::SharedStorage); - - // Test parameters - const uint32_t qo_packed_idx_base = 0; // Start from beginning - const uint32_t q_stride_n = num_qo_heads * head_dim; - const uint32_t q_stride_h = head_dim; - - std::cout << "Launching kernel with:" << std::endl; - std::cout << " Block size: " << block_size.x << "x" << block_size.y << "x" << block_size.z - << std::endl; - std::cout << " Shared memory: " << shared_mem_size << " bytes" << std::endl; - std::cout << " Q size: " << q_size << " 
elements" << std::endl; - - uint_fastdiv group_size_div = create_group_size_div(group_size); - - // Launch test kernel - test_q_loading_kernel<<>>( - q_device, q_smem_output, qo_packed_idx_base, qo_len, q_stride_n, q_stride_h, group_size_div); - - FI_GPU_CALL(hipDeviceSynchronize()); - - // Get results back - std::vector q_smem_actual(smem_elements); - FI_GPU_CALL(hipMemcpy(q_smem_actual.data(), q_smem_output, smem_elements * sizeof(DTypeQ), - hipMemcpyDeviceToHost)); - - // Generate CPU reference - std::vector q_smem_expected = - cpu_reference_q_smem_layout(q_host, qo_len, num_qo_heads, head_dim, q_stride_n, q_stride_h, - qo_packed_idx_base, group_size, 16, head_dim); - - // Compare results - bool passed = true; - float max_diff = 0.0f; - size_t mismatch_count = 0; - - std::cout << "\nValidation results:" << std::endl; - std::cout << "Comparing " << q_smem_actual.size() << " elements..." << std::endl; - - for (size_t i = 0; i < std::min(q_smem_actual.size(), q_smem_expected.size()); ++i) { - float actual = fi::con::explicit_casting(q_smem_actual[i]); - float expected = fi::con::explicit_casting(q_smem_expected[i]); - float diff = std::abs(actual - expected); - max_diff = std::max(max_diff, diff); - - if (!utils::isclose(q_smem_actual[i], q_smem_expected[i], 1e-3f, 1e-4f)) { - if (mismatch_count < 10) { // Show first 10 mismatches - size_t row = i / head_dim; - size_t col = i % head_dim; - std::cout << "Mismatch at [" << row << "][" << col << "] (index " << i << "): " - << "expected " << expected << ", got " << actual << ", diff " << diff - << std::endl; - } - mismatch_count++; - passed = false; - } - } - - std::cout << "Max difference: " << max_diff << std::endl; - std::cout << "Total mismatches: " << mismatch_count << " / " << q_smem_actual.size() << std::endl; - std::cout << "Q loading test: " << (passed ? 
"PASSED" : "FAILED") << std::endl; - - // Show some sample values for debugging - if (!passed) { - std::cout << "\nFirst 10 expected vs actual values:" << std::endl; - for (size_t i = 0; i < std::min(size_t(10), q_smem_actual.size()); ++i) { - float actual = fi::con::explicit_casting(q_smem_actual[i]); - float expected = fi::con::explicit_casting(q_smem_expected[i]); - std::cout << "[" << i << "] expected: " << expected << ", actual: " << actual << std::endl; - } - } - - // Cleanup - FI_GPU_CALL(hipFree(q_device)); - FI_GPU_CALL(hipFree(q_smem_output)); - - return passed; -} - -// Main function -int main() { - std::cout << "=== FlashInfer Q Loading Component Test ===" << std::endl; - std::cout << "Testing load_q_global_smem function for CDNA3 architecture" << std::endl; - - // Initialize HIP - hipError_t err = hipSetDevice(0); - if (err != hipSuccess) { - std::cout << "Failed to set HIP device: " << hipGetErrorString(err) << std::endl; - return 1; - } - - hipDeviceProp_t prop; - FI_GPU_CALL(hipGetDeviceProperties(&prop, 0)); - std::cout << "Running on: " << prop.name << std::endl; - - bool all_passed = true; - - // Test with half precision - std::cout << "\n--- Testing with FP16 ---" << std::endl; - all_passed &= test_q_loading_correctness<__half>(); - - if (all_passed) { - std::cout << "\n✅ All Q loading tests PASSED!" << std::endl; - return 0; - } else { - std::cout << "\n❌ Some Q loading tests FAILED!" 
// CPU-side simulation tests for the produce_kv shared-memory write pattern
// used by the AMD (MI300) prefill path with linear (non-swizzled) offset
// addressing.
//
// The simulation models one 64-lane wavefront laid out as 4 rows x 16
// columns of lanes, each lane storing 4 fp16 elements per step, and records
// which lane writes each linear shared-memory offset so the pattern can be
// printed and checked for full, non-overlapping coverage.
//
// NOTE(review): the original patch stripped every angle-bracketed token
// (include targets, template parameter lists); those were reconstructed from
// the call sites and surrounding comments — confirm against prefill.cuh and
// permuted_smem.cuh.

// The TEST cases and main() are only compiled when GoogleTest is available,
// so the simulation helpers remain usable from other harnesses.
#if __has_include(<gtest/gtest.h>)
#include <gtest/gtest.h>
#define PRODUCE_KV_TEST_HAS_GTEST 1
#else
#define PRODUCE_KV_TEST_HAS_GTEST 0
#endif

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <iostream>
#include <vector>

// Warp geometry constants. Only the AMD layout (64 lanes arranged as
// 4 rows x 16 columns) is exercised by this file; the NVIDIA value is kept
// for reference.
constexpr uint32_t WARP_SIZE_NV = 32;
constexpr uint32_t WARP_SIZE_AMD = 64;
constexpr uint32_t WARP_STEP_SIZE = 16;   // 16 threads per warp row for AMD
constexpr uint32_t WARP_THREAD_ROWS = 4;  // 4 rows of threads in a warp for AMD

// SwizzleMode enum to match the original code.
enum class SwizzleMode {
  k64B = 0U,    // Original NVIDIA mode (32 threads, 8 rows x 4 columns)
  k128B = 1U,   // Original pseudo-128B mode (32 threads, 4 rows x 8 columns)
  kLinear = 2U  // New AMD-specific mode (64 threads, 4 rows x 16 columns)
};

// Simplified linear shared memory operations (CPU implementations mirroring
// the device-side smem offset helpers).

// Row-major linear offset: row * stride + col.
template <uint32_t stride>
uint32_t get_permuted_offset_linear(uint32_t row, uint32_t col) {
  return row * stride + col;
}

// Advances the offset by a fixed column step. The step index is accepted for
// signature parity with the device helper but does not affect the result.
template <uint32_t step_size>
uint32_t advance_offset_by_column_linear(uint32_t offset, uint32_t /*step_idx*/) {
  return offset + step_size;
}

// Advances the offset by step_size whole rows of row_stride elements each.
template <uint32_t step_size, uint32_t row_stride>
uint32_t advance_offset_by_row_linear(uint32_t offset) {
  return offset + step_size * row_stride;
}

// CPU-based simulation of produce_kv for AMD MI300 with linear offset
// addressing.
//
// Fills thread_ids_at_offsets (resized to grid_height * grid_width, i.e.
// (16 * NUM_MMA_KV) x (HEAD_DIM / 4)) with the lane id that writes each
// linear offset, or -1 where no lane writes.
template <uint32_t HEAD_DIM, uint32_t NUM_MMA_KV>
void SimulateProduceKV(std::vector<int>& thread_ids_at_offsets) {
  // Constants for MI300 (64-thread warp, 4x16 thread layout). Named
  // distinctly from the file-scope constants to avoid shadowing.
  constexpr uint32_t WARP_SIZE = 64;
  constexpr uint32_t THREAD_ROWS = 4;       // rows of threads in the warp
  constexpr uint32_t THREADS_PER_ROW = 16;  // threads per warp row
  constexpr uint32_t ELEMS_PER_THREAD = 4;  // fp16 elements loaded per thread

  // Derived constants.
  constexpr uint32_t UPCAST_STRIDE = HEAD_DIM / ELEMS_PER_THREAD;
  constexpr uint32_t NUM_MMA_D = HEAD_DIM / 16;
  constexpr uint32_t grid_width = HEAD_DIM / ELEMS_PER_THREAD;
  constexpr uint32_t grid_height = 16 * NUM_MMA_KV;
  constexpr uint32_t NUM_WARPS_Q = 1;
  // After the inner column loop the write offset has advanced by this many
  // elements; subtracting it rewinds the offset back to column 0 of the row.
  constexpr uint32_t COLUMN_RESET_OFFSET = (NUM_MMA_D / 4) * THREADS_PER_ROW;

  // Initialize with -1 (unwritten).
  thread_ids_at_offsets.assign(grid_height * grid_width, -1);

  // Simulate each lane's write pattern independently.
  for (uint32_t tid = 0; tid < WARP_SIZE; tid++) {
    const uint32_t warp_idx = 0;  // always 0 for the single-warp simulation
    const uint32_t lane_idx = tid;

    // Lane position within the 4x16 thread grid.
    const uint32_t row = lane_idx / THREADS_PER_ROW;
    const uint32_t col = lane_idx % THREADS_PER_ROW;

    // Initial linear write offset for this lane.
    uint32_t kv_smem_offset_w = get_permuted_offset_linear<UPCAST_STRIDE>(
        warp_idx * THREAD_ROWS + row, col);

    // kv_idx tracks the logical KV row this lane is currently writing.
    uint32_t kv_idx = warp_idx * THREAD_ROWS + row;

    // Handle all blocks of rows assigned to this lane.
    for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) {
      // Walk the columns of the current row; each step covers the 16 lanes
      // of a warp row (16 x 4 fp16 elements).
      for (uint32_t j = 0; j < NUM_MMA_D / 4; ++j) {
        if (kv_smem_offset_w < grid_height * grid_width && kv_idx < grid_height) {
          thread_ids_at_offsets[kv_smem_offset_w] = static_cast<int>(tid);
        } else {
          std::cerr << "ERROR: Out of bound offset (" << kv_smem_offset_w << ") at " << tid << '\n';
        }
        // Advance to next column by 16 (number of threads per row).
        kv_smem_offset_w =
            advance_offset_by_column_linear<THREADS_PER_ROW>(kv_smem_offset_w, j);
      }

      // Move to the next block of rows handled by this lane: rewind the
      // columns walked above, then jump THREAD_ROWS rows down.
      kv_idx += THREAD_ROWS;
      kv_smem_offset_w =
          advance_offset_by_row_linear<THREAD_ROWS, UPCAST_STRIDE>(kv_smem_offset_w) -
          COLUMN_RESET_OFFSET;
    }
    // FIX: the original rewound kv_smem_offset_w by
    // 16 * NUM_MMA_KV * UPCAST_STRIDE here (flagged FIXME). The offset is
    // re-derived at the top of the loop for every lane, so that adjustment
    // had no observable effect and is omitted.
  }
}

// Runs the simulation for one (HEAD_DIM, NUM_MMA_KV) configuration, prints
// the lane-id grid (first 32 columns at most), and reports write-coverage
// statistics.
template <uint32_t HEAD_DIM, uint32_t NUM_MMA_KV>
void RunProduceKVTest() {
  constexpr uint32_t grid_width = HEAD_DIM / 4;     // 16 for 64, 32 for 128
  constexpr uint32_t grid_height = 16 * NUM_MMA_KV; // 16 per NUM_MMA_KV block

  printf("\n=== Testing produce_kv with HEAD_DIM = %u, NUM_MMA_KV = %u ===\n", HEAD_DIM,
         NUM_MMA_KV);

  // Host array to store thread IDs at each offset.
  std::vector<int> thread_ids(grid_height * grid_width, -1);

  // Run CPU simulation of produce_kv.
  SimulateProduceKV<HEAD_DIM, NUM_MMA_KV>(thread_ids);

  // FIX: %u for the unsigned grid dimensions (the original used %d).
  printf("Thread IDs writing to each offset (%ux%u grid):\n", grid_height, grid_width);

  const int shown_cols = std::min(32, (int)grid_width);

  // Column headers.
  printf("    ");
  for (int c = 0; c < shown_cols; c++) {
    printf("%3d ", c);
    if (c == 15 && grid_width > 16) printf("| ");
  }
  printf("\n   +");
  for (int c = 0; c < shown_cols; c++) {
    printf("----");
    if (c == 15 && grid_width > 16) printf("+");
  }
  printf("\n");

  // Print grid with clear separation.
  for (int r = 0; r < (int)grid_height; r++) {
    printf("%2d | ", r);
    for (int c = 0; c < shown_cols; c++) {
      const int thread_id = thread_ids[r * grid_width + c];
      if (thread_id >= 0) {
        printf("%3d ", thread_id);
      } else {
        printf("  . ");
      }
      if (c == 15 && grid_width > 16) printf("| ");
    }
    printf("\n");

    // Horizontal divider between 16-row blocks.
    if (r == 15 && NUM_MMA_KV > 1) {
      printf("   +");
      for (int c = 0; c < shown_cols; c++) {
        printf("----");
        if (c == 15 && grid_width > 16) printf("+");
      }
      printf("\n");
    }
  }

  // Count positions no lane wrote (should be 0 for a correct pattern).
  int unwritten = 0;
  for (uint32_t i = 0; i < grid_height * grid_width; i++) {
    if (thread_ids[i] == -1) unwritten++;
  }

  const int total = (int)(grid_height * grid_width);
  printf("\nStatistics:\n");
  printf("- Positions written: %d/%d (%.1f%%)\n", total - unwritten, total,
         100.0f * (total - unwritten) / total);
  printf("- Unwritten positions: %d/%d (%.1f%%)\n", unwritten, total,
         100.0f * unwritten / total);
}

#if PRODUCE_KV_TEST_HAS_GTEST
TEST(KVCacheWritePatternTest, HeadDim64_AMD_kLinear) { RunProduceKVTest<64, 1>(); }

TEST(KVCacheWritePatternTest, HeadDim128_AMD_kLinear) { RunProduceKVTest<128, 1>(); }

int main(int argc, char** argv) {
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}
#endif  // PRODUCE_KV_TEST_HAS_GTEST