diff --git a/libflashinfer/include/flashinfer/attention/generic/default_prefill_params.cuh b/libflashinfer/include/flashinfer/attention/generic/default_prefill_params.cuh
new file mode 100644
index 0000000000..5fdcdf52c8
--- /dev/null
+++ b/libflashinfer/include/flashinfer/attention/generic/default_prefill_params.cuh
@@ -0,0 +1,395 @@
+/*
+ * Copyright (c) 2024 by FlashInfer team.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef FLASHINFER_PREFILL_PARAMS_CUH_
+#define FLASHINFER_PREFILL_PARAMS_CUH_
+
+#include <cstddef>  // NOTE(review): angle-bracket include targets were stripped in transit; <cstddef>/<cstdint> reconstructed — verify against upstream
+#include <cstdint>
+
+#include "gpu_iface/gpu_runtime_compat.hpp"
+#include "page.cuh"
+
+namespace flashinfer {
+
+// Parameters for single-request (non-batched) prefill attention.
+// NOTE(review): the template parameter list was stripped in transit; it is
+// reconstructed from the `using DTypeQ = DTypeQ_;` aliases below.
+template <typename DTypeQ_, typename DTypeKV_, typename DTypeO_>
+struct SinglePrefillParams {
+  using DTypeQ = DTypeQ_;
+  using DTypeKV = DTypeKV_;
+  using DTypeO = DTypeO_;
+  using IdType = int32_t;
+  DTypeQ* q;
+  DTypeKV* k;
+  DTypeKV* v;
+  uint8_t* maybe_custom_mask;
+  DTypeO* o;
+  float* lse;
+  float* maybe_alibi_slopes;
+  uint_fastdiv group_size;
+  uint32_t num_qo_heads;
+  uint32_t num_kv_heads;
+  uint32_t qo_len;
+  uint32_t kv_len;
+  uint32_t q_stride_n;
+  uint32_t q_stride_h;
+  uint32_t k_stride_n;
+  uint32_t k_stride_h;
+  uint32_t v_stride_n;
+  uint32_t v_stride_h;
+  uint32_t head_dim;
+  int32_t window_left;
+  float logits_soft_cap;
+  float sm_scale;
+  float rope_rcp_scale;
+  float rope_rcp_theta;
+  uint32_t debug_thread_id;
+  uint32_t debug_warp_id;
+
+  uint32_t partition_kv;  // NOTE(review): sibling structs use bool; verify whether uint32_t is intentional
+
+  __host__ SinglePrefillParams()
+      : q(nullptr),
+        k(nullptr),
+        v(nullptr),
+
maybe_custom_mask(nullptr),
+        o(nullptr),
+        lse(nullptr),
+        maybe_alibi_slopes(nullptr),
+        group_size(),
+        num_qo_heads(0),  // fix: initializer order now matches declaration order (-Wreorder)
+        num_kv_heads(0),
+        qo_len(0),
+        kv_len(0),
+        q_stride_n(0),
+        q_stride_h(0),
+        k_stride_n(0),
+        k_stride_h(0),
+        v_stride_n(0),
+        v_stride_h(0),
+        head_dim(0),
+        window_left(0),
+        logits_soft_cap(0.0f),
+        sm_scale(0.0f),
+        rope_rcp_scale(0.0f),
+        rope_rcp_theta(0.0f),
+        debug_thread_id(0), debug_warp_id(0), partition_kv(false) {}  // fix: debug ids were left uninitialized
+
+  // Full constructor; k/v share the same strides (kv_stride_n/kv_stride_h).
+  // rope_scale/rope_theta are stored as reciprocals for cheaper device math.
+  __host__ SinglePrefillParams(DTypeQ* q, DTypeKV* k, DTypeKV* v, uint8_t* maybe_custom_mask,
+                               DTypeO* o, float* lse, float* maybe_alibi_slopes,
+                               uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t qo_len,
+                               uint32_t kv_len, uint32_t q_stride_n, uint32_t q_stride_h,
+                               uint32_t kv_stride_n, uint32_t kv_stride_h, uint32_t head_dim,
+                               int32_t window_left, float logits_soft_cap, float sm_scale,
+                               float rope_scale, float rope_theta, uint32_t debug_thread_id,
+                               uint32_t debug_warp_id)
+      : q(q),
+        k(k),
+        v(v),
+        maybe_custom_mask(maybe_custom_mask),
+        group_size(num_qo_heads / num_kv_heads),
+        o(o),
+        lse(lse),
+        maybe_alibi_slopes(maybe_alibi_slopes),
+        num_qo_heads(num_qo_heads),
+        num_kv_heads(num_kv_heads),
+        qo_len(qo_len),
+        kv_len(kv_len),
+        q_stride_n(q_stride_n),
+        q_stride_h(q_stride_h),
+        k_stride_n(kv_stride_n),
+        k_stride_h(kv_stride_h),
+        v_stride_n(kv_stride_n),
+        v_stride_h(kv_stride_h),
+        head_dim(head_dim),
+        window_left(window_left),
+        logits_soft_cap(logits_soft_cap),
+        sm_scale(sm_scale),
+        rope_rcp_scale(1.f / rope_scale),
+        rope_rcp_theta(1.f
/ rope_theta), + debug_thread_id(debug_thread_id), + debug_warp_id(debug_warp_id), + partition_kv(false) {} + + __host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t batch_idx) const { + return qo_len; + } + + __host__ __device__ __forceinline__ uint32_t get_kv_len(uint32_t batch_idx) const { + return kv_len; + } +}; + +template +struct BatchPrefillRaggedParams { + using DTypeQ = DTypeQ_; + using DTypeKV = DTypeKV_; + using DTypeO = DTypeO_; + using IdType = IdType_; + + DTypeQ* q; + DTypeKV* k; + DTypeKV* v; + uint8_t* maybe_custom_mask; + IdType* q_indptr; + IdType* kv_indptr; + IdType* maybe_mask_indptr; + IdType* maybe_q_rope_offset; // maybe_q_rope_offset is only used for + // fused-rope attention + IdType* maybe_k_rope_offset; // maybe_k_rope_offset is only used for + // fused-rope attention + DTypeO* o; + float* lse; + float* maybe_alibi_slopes; + uint_fastdiv group_size; + uint32_t num_qo_heads; + uint32_t num_kv_heads; + uint32_t q_stride_n; + uint32_t q_stride_h; + uint32_t k_stride_n; + uint32_t k_stride_h; + uint32_t v_stride_n; + uint32_t v_stride_h; + int32_t window_left; + float logits_soft_cap; + float sm_scale; + float rope_rcp_scale; + float rope_rcp_theta; + + IdType* request_indices; + IdType* qo_tile_indices; + IdType* kv_tile_indices; + IdType* merge_indptr; + IdType* o_indptr; + IdType* kv_chunk_size_ptr; + bool* block_valid_mask; + uint32_t max_total_num_rows; + uint32_t* total_num_rows; + uint32_t padded_batch_size; + bool partition_kv; + + __host__ BatchPrefillRaggedParams() + : q(nullptr), + k(nullptr), + v(nullptr), + maybe_custom_mask(nullptr), + q_indptr(nullptr), + kv_indptr(nullptr), + maybe_mask_indptr(nullptr), + maybe_q_rope_offset(nullptr), + maybe_k_rope_offset(nullptr), + o(nullptr), + lse(nullptr), + maybe_alibi_slopes(nullptr), + group_size(), + num_qo_heads(0), + num_kv_heads(0), + q_stride_n(0), + q_stride_h(0), + k_stride_n(0), + k_stride_h(0), + v_stride_n(0), + v_stride_h(0), + window_left(0), + 
logits_soft_cap(0.0f), + sm_scale(0.0f), + rope_rcp_scale(0.0f), + rope_rcp_theta(0.0f), + request_indices(nullptr), + qo_tile_indices(nullptr), + kv_tile_indices(nullptr), + merge_indptr(nullptr), + o_indptr(nullptr), + kv_chunk_size_ptr(nullptr), + block_valid_mask(nullptr), + max_total_num_rows(0), + total_num_rows(nullptr), + padded_batch_size(0), + partition_kv(false) {} + + __host__ BatchPrefillRaggedParams(DTypeQ* q, DTypeKV* k, DTypeKV* v, uint8_t* maybe_custom_mask, + IdType* q_indptr, IdType* kv_indptr, IdType* maybe_mask_indptr, + IdType* maybe_q_rope_offset, IdType* maybe_k_rope_offset, + DTypeO* o, float* lse, float* maybe_alibi_slopes, + uint32_t num_qo_heads, uint32_t num_kv_heads, + uint32_t q_stride_n, uint32_t q_stride_h, uint32_t kv_stride_n, + uint32_t kv_stride_h, int32_t window_left, + float logits_soft_cap, float sm_scale, float rope_scale, + float rope_theta) + : q(q), + k(k), + v(v), + maybe_custom_mask(maybe_custom_mask), + q_indptr(q_indptr), + kv_indptr(kv_indptr), + maybe_mask_indptr(maybe_mask_indptr), + maybe_q_rope_offset(maybe_q_rope_offset), + maybe_k_rope_offset(maybe_k_rope_offset), + o(o), + lse(lse), + maybe_alibi_slopes(maybe_alibi_slopes), + group_size(num_qo_heads / num_kv_heads), + num_qo_heads(num_qo_heads), + num_kv_heads(num_kv_heads), + q_stride_n(q_stride_n), + q_stride_h(q_stride_h), + k_stride_n(kv_stride_n), + k_stride_h(kv_stride_h), + v_stride_n(kv_stride_n), + v_stride_h(kv_stride_h), + window_left(window_left), + logits_soft_cap(logits_soft_cap), + sm_scale(sm_scale), + rope_rcp_scale(1.f / rope_scale), + rope_rcp_theta(1.f / rope_theta), + request_indices(nullptr), + qo_tile_indices(nullptr), + kv_tile_indices(nullptr), + merge_indptr(nullptr), + o_indptr(nullptr), + kv_chunk_size_ptr(nullptr), + block_valid_mask(nullptr), + max_total_num_rows(0), + total_num_rows(nullptr), + padded_batch_size(0), + partition_kv(false) {} + + __host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t batch_idx) const { + 
return q_indptr[batch_idx + 1] - q_indptr[batch_idx]; + } + + __host__ __device__ __forceinline__ uint32_t get_kv_len(uint32_t batch_idx) const { + return kv_indptr[batch_idx + 1] - kv_indptr[batch_idx]; + } +}; + +template +struct BatchPrefillPagedParams { + using DTypeQ = DTypeQ_; + using DTypeKV = DTypeKV_; + using DTypeO = DTypeO_; + using IdType = IdType_; + + DTypeQ* q; + paged_kv_t paged_kv; + uint8_t* maybe_custom_mask; + IdType* q_indptr; + IdType* maybe_mask_indptr; + IdType* maybe_q_rope_offset; // maybe_q_rope_offset is only used for + // fused-rope attention + DTypeO* o; + float* lse; + float* maybe_alibi_slopes; + uint_fastdiv group_size; + uint32_t num_qo_heads; + IdType q_stride_n; + IdType q_stride_h; + int32_t window_left; + float logits_soft_cap; + float sm_scale; + float rope_rcp_scale; + float rope_rcp_theta; + + IdType* request_indices; + IdType* qo_tile_indices; + IdType* kv_tile_indices; + IdType* merge_indptr; + IdType* o_indptr; + bool* block_valid_mask; + IdType* kv_chunk_size_ptr; + uint32_t max_total_num_rows; + uint32_t* total_num_rows; + uint32_t padded_batch_size; + bool partition_kv; + + __host__ BatchPrefillPagedParams() + : q(nullptr), + paged_kv(), + maybe_custom_mask(nullptr), + q_indptr(nullptr), + maybe_mask_indptr(nullptr), + maybe_q_rope_offset(nullptr), + o(nullptr), + lse(nullptr), + maybe_alibi_slopes(nullptr), + group_size(), + num_qo_heads(0), + q_stride_n(0), + q_stride_h(0), + window_left(0), + logits_soft_cap(0.0f), + sm_scale(0.0f), + rope_rcp_scale(0.0f), + rope_rcp_theta(0.0f), + request_indices(nullptr), + qo_tile_indices(nullptr), + kv_tile_indices(nullptr), + merge_indptr(nullptr), + o_indptr(nullptr), + block_valid_mask(nullptr), + kv_chunk_size_ptr(nullptr), + max_total_num_rows(0), + total_num_rows(nullptr), + padded_batch_size(0), + partition_kv(false) {} + + __host__ BatchPrefillPagedParams(DTypeQ* q, paged_kv_t paged_kv, + uint8_t* maybe_custom_mask, IdType* q_indptr, + IdType* maybe_mask_indptr, IdType* 
maybe_q_rope_offset,
+                                   DTypeO* o, float* lse, float* maybe_alibi_slopes,
+                                   uint32_t num_qo_heads, IdType q_stride_n, IdType q_stride_h,
+                                   int32_t window_left, float logits_soft_cap, float sm_scale,
+                                   float rope_scale, float rope_theta)
+      : q(q),
+        paged_kv(paged_kv),
+        maybe_custom_mask(maybe_custom_mask),
+        q_indptr(q_indptr),
+        maybe_mask_indptr(maybe_mask_indptr),
+        maybe_q_rope_offset(maybe_q_rope_offset),
+        o(o),
+        lse(lse),
+        maybe_alibi_slopes(maybe_alibi_slopes),
+        group_size(num_qo_heads / paged_kv.num_heads),
+        num_qo_heads(num_qo_heads),
+        q_stride_n(q_stride_n),
+        q_stride_h(q_stride_h),
+        window_left(window_left),
+        logits_soft_cap(logits_soft_cap),
+        sm_scale(sm_scale),
+        rope_rcp_scale(1.f / rope_scale),
+        rope_rcp_theta(1.f / rope_theta),
+        request_indices(nullptr),
+        qo_tile_indices(nullptr),
+        kv_tile_indices(nullptr),
+        merge_indptr(nullptr),
+        o_indptr(nullptr),
+        block_valid_mask(nullptr),
+        kv_chunk_size_ptr(nullptr),
+        max_total_num_rows(0),
+        total_num_rows(nullptr),
+        padded_batch_size(0),
+        partition_kv(false) {}
+
+  __host__ __device__ __forceinline__ uint32_t get_qo_len(uint32_t batch_idx) const {
+    return q_indptr[batch_idx + 1] - q_indptr[batch_idx];
+  }
+
+  __host__ __device__ __forceinline__ uint32_t get_kv_len(uint32_t batch_idx) const {
+    return paged_kv.get_length(batch_idx);
+  }
+};
+
+} // namespace flashinfer
+
+#endif // FLASHINFER_PREFILL_PARAMS_CUH_
diff --git a/libflashinfer/include/flashinfer/attention/generic/dispatch.cuh b/libflashinfer/include/flashinfer/attention/generic/dispatch.cuh
new file mode 100644
index 0000000000..abe0a3020e
--- /dev/null
+++ b/libflashinfer/include/flashinfer/attention/generic/dispatch.cuh
@@ -0,0 +1,216 @@
+// SPDX-FileCopyrightText: 2023-2025 FlashInfer team.
+// SPDX-FileCopyrightText: 2025 Advanced Micro Devices, Inc.
+// +// SPDX - License - Identifier : Apache 2.0 + +#pragma once + +#include "gpu_iface/enums.hpp" +#include "gpu_iface/exception.h" + +#define DISPATCH_USE_FP16_QK_REDUCTION(use_fp16_qk_reduction, USE_FP16_QK_REDUCTION, ...) \ + if (use_fp16_qk_reduction) { \ + FLASHINFER_ERROR("FP16_QK_REDUCTION disabled at compile time"); \ + } else { \ + constexpr bool USE_FP16_QK_REDUCTION = false; \ + __VA_ARGS__ \ + } + +#define DISPATCH_NUM_MMA_Q(num_mma_q, NUM_MMA_Q, ...) \ + if (num_mma_q == 1) { \ + constexpr size_t NUM_MMA_Q = 1; \ + __VA_ARGS__ \ + } else if (num_mma_q == 2) { \ + constexpr size_t NUM_MMA_Q = 2; \ + __VA_ARGS__ \ + } else { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported num_mma_q: " << num_mma_q; \ + FLASHINFER_ERROR(err_msg.str()); \ + } + +#define DISPATCH_NUM_MMA_KV(max_mma_kv, NUM_MMA_KV, ...) \ + if (max_mma_kv >= 8) { \ + constexpr size_t NUM_MMA_KV = 8; \ + __VA_ARGS__ \ + } else if (max_mma_kv >= 4) { \ + constexpr size_t NUM_MMA_KV = 4; \ + __VA_ARGS__ \ + } else if (max_mma_kv >= 2) { \ + constexpr size_t NUM_MMA_KV = 2; \ + __VA_ARGS__ \ + } else if (max_mma_kv >= 1) { \ + constexpr size_t NUM_MMA_KV = 1; \ + __VA_ARGS__ \ + } else { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported max_mma_kv: " << max_mma_kv; \ + FLASHINFER_ERROR(err_msg.str()); \ + } + +#define DISPATCH_CTA_TILE_Q(cta_tile_q, CTA_TILE_Q, ...) \ + switch (cta_tile_q) { \ + case 128: { \ + constexpr uint32_t CTA_TILE_Q = 128; \ + __VA_ARGS__ \ + break; \ + } \ + case 64: { \ + constexpr uint32_t CTA_TILE_Q = 64; \ + __VA_ARGS__ \ + break; \ + } \ + case 16: { \ + constexpr uint32_t CTA_TILE_Q = 16; \ + __VA_ARGS__ \ + break; \ + } \ + default: { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported cta_tile_q: " << cta_tile_q; \ + FLASHINFER_ERROR(err_msg.str()); \ + } \ + } + +#define DISPATCH_GQA_GROUP_SIZE(group_size, GROUP_SIZE, ...) 
\ + if (group_size == 1) { \ + constexpr size_t GROUP_SIZE = 1; \ + __VA_ARGS__ \ + } else if (group_size == 2) { \ + constexpr size_t GROUP_SIZE = 2; \ + __VA_ARGS__ \ + } else if (group_size == 3) { \ + constexpr size_t GROUP_SIZE = 3; \ + __VA_ARGS__ \ + } else if (group_size == 4) { \ + constexpr size_t GROUP_SIZE = 4; \ + __VA_ARGS__ \ + } else if (group_size == 8) { \ + constexpr size_t GROUP_SIZE = 8; \ + __VA_ARGS__ \ + } else { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported group_size: " << group_size; \ + FLASHINFER_ERROR(err_msg.str()); \ + } + +#define DISPATCH_MASK_MODE(mask_mode, MASK_MODE, ...) \ + switch (mask_mode) { \ + case MaskMode::kNone: { \ + constexpr MaskMode MASK_MODE = MaskMode::kNone; \ + __VA_ARGS__ \ + break; \ + } \ + case MaskMode::kCausal: { \ + constexpr MaskMode MASK_MODE = MaskMode::kCausal; \ + __VA_ARGS__ \ + break; \ + } \ + case MaskMode::kCustom: { \ + constexpr MaskMode MASK_MODE = MaskMode::kCustom; \ + __VA_ARGS__ \ + break; \ + } \ + default: { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported mask_mode: " << int(mask_mode); \ + FLASHINFER_ERROR(err_msg.str()); \ + } \ + } + +// convert head_dim to compile-time constant +#define DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, ...) \ + switch (head_dim) { \ + case 64: { \ + constexpr size_t HEAD_DIM = 64; \ + __VA_ARGS__ \ + break; \ + } \ + case 128: { \ + constexpr size_t HEAD_DIM = 128; \ + __VA_ARGS__ \ + break; \ + } \ + case 256: { \ + constexpr size_t HEAD_DIM = 256; \ + __VA_ARGS__ \ + break; \ + } \ + case 512: { \ + constexpr size_t HEAD_DIM = 512; \ + __VA_ARGS__ \ + break; \ + } \ + default: { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported head_dim: " << head_dim; \ + FLASHINFER_ERROR(err_msg.str()); \ + } \ + } + +#define DISPATCH_POS_ENCODING_MODE(pos_encoding_mode, POS_ENCODING_MODE, ...) 
\ + switch (pos_encoding_mode) { \ + case PosEncodingMode::kNone: { \ + constexpr PosEncodingMode POS_ENCODING_MODE = PosEncodingMode::kNone; \ + __VA_ARGS__ \ + break; \ + } \ + case PosEncodingMode::kRoPELlama: { \ + constexpr PosEncodingMode POS_ENCODING_MODE = PosEncodingMode::kRoPELlama; \ + __VA_ARGS__ \ + break; \ + } \ + case PosEncodingMode::kALiBi: { \ + constexpr PosEncodingMode POS_ENCODING_MODE = PosEncodingMode::kALiBi; \ + __VA_ARGS__ \ + break; \ + } \ + default: { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported pos_encoding_mode: " << int(pos_encoding_mode); \ + FLASHINFER_ERROR(err_msg.str()); \ + } \ + } + +#define DISPATCH_ALIGNED_VEC_SIZE(aligned_vec_size, ALIGNED_VEC_SIZE, ...) \ + switch (aligned_vec_size) { \ + case 16: { \ + constexpr size_t ALIGNED_VEC_SIZE = 16; \ + __VA_ARGS__ \ + break; \ + } \ + case 8: { \ + constexpr size_t ALIGNED_VEC_SIZE = 8; \ + __VA_ARGS__ \ + break; \ + } \ + case 4: { \ + constexpr size_t ALIGNED_VEC_SIZE = 4; \ + __VA_ARGS__ \ + break; \ + } \ + case 2: { \ + constexpr size_t ALIGNED_VEC_SIZE = 2; \ + __VA_ARGS__ \ + break; \ + } \ + case 1: { \ + constexpr size_t ALIGNED_VEC_SIZE = 1; \ + __VA_ARGS__ \ + break; \ + } \ + default: { \ + std::ostringstream err_msg; \ + err_msg << "Unsupported aligned_vec_size: " << aligned_vec_size; \ + FLASHINFER_ERROR(err_msg.str()); \ + } \ + } + +#define DISPATCH_COMPUTE_CAP_DECODE_NUM_STAGES_SMEM(compute_capacity, NUM_STAGES_SMEM, ...) 
\ + if (compute_capacity.first >= 8) { \ + constexpr uint32_t NUM_STAGES_SMEM = 2; \ + __VA_ARGS__ \ + } else { \ + constexpr uint32_t NUM_STAGES_SMEM = 1; \ + __VA_ARGS__ \ + } diff --git a/libflashinfer/include/flashinfer/attention/generic/page.cuh b/libflashinfer/include/flashinfer/attention/generic/page.cuh index 28ed38bebc..14df54be80 100644 --- a/libflashinfer/include/flashinfer/attention/generic/page.cuh +++ b/libflashinfer/include/flashinfer/attention/generic/page.cuh @@ -262,7 +262,6 @@ __global__ void AppendPagedKVCacheKernel(paged_kv_t paged_kv, size_t append_k_stride_n, size_t append_k_stride_h, size_t append_v_stride_n, size_t append_v_stride_h) { uint32_t tx = threadIdx.x, ty = threadIdx.y; - uint32_t num_heads = paged_kv.num_heads; uint32_t head_idx = ty; uint32_t cta_id = blockIdx.x; uint32_t num_ctas = gridDim.x; diff --git a/libflashinfer/include/flashinfer/attention/generic/prefill.cuh b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh new file mode 100644 index 0000000000..b7e00555c8 --- /dev/null +++ b/libflashinfer/include/flashinfer/attention/generic/prefill.cuh @@ -0,0 +1,3011 @@ +// SPDX - FileCopyrightText : 2023-2025 FlashInfer team. +// SPDX - FileCopyrightText : 2025 Advanced Micro Devices, Inc. 
+// +// SPDX - License - Identifier : Apache - 2.0 +#ifndef FLASHINFER_PREFILL_CUH_ +#define FLASHINFER_PREFILL_CUH_ + +#include "gpu_iface/cooperative_groups.h" +#include "gpu_iface/fastdiv.cuh" +#include "gpu_iface/math_ops.hpp" +#include "gpu_iface/memory_ops.hpp" +#include "gpu_iface/mma_ops.hpp" +#include "gpu_iface/platform.hpp" +#include "gpu_iface/utils.cuh" + +#ifdef FP16_QK_REDUCTION_SUPPORTED +#include "../../fp16.h" +#endif +#include + +#include "cascade.cuh" +#include "dispatch.cuh" +#include "frag_layout_swizzle.cuh" +#include "page.cuh" +#include "permuted_smem.cuh" +#include "pos_enc.cuh" +#include "variants.cuh" + +#if Debug +#include "gpu_iface/backend/hip/mma_debug_utils_hip.h" +#endif + +namespace flashinfer { + +DEFINE_HAS_MEMBER(maybe_q_rope_offset) +DEFINE_HAS_MEMBER(maybe_k_rope_offset) + +namespace cg = flashinfer::gpu_iface::cg; +namespace memory = flashinfer::gpu_iface::memory; +namespace mma = gpu_iface::mma; + +using gpu_iface::vec_dtypes::vec_cast; +using mma::MMAMode; + +constexpr uint32_t WARP_SIZE = gpu_iface::kWarpSize; + +constexpr uint32_t get_num_warps_q(const uint32_t cta_tile_q) { + if (cta_tile_q > 16) { + return 4; + } else { + return 1; + } +} + +constexpr uint32_t get_num_warps_kv(const uint32_t cta_tile_kv) { + return 4 / get_num_warps_q(cta_tile_kv); +} + +constexpr uint32_t get_num_mma_q(const uint32_t cta_tile_q) { + if (cta_tile_q > 64) { + return 2; + } else { + return 1; + } +} + +template +struct SharedStorageQKVO { + union { + struct { + alignas(16) DTypeQ q_smem[CTA_TILE_Q * HEAD_DIM_QK]; + alignas(16) DTypeKV k_smem[CTA_TILE_KV * HEAD_DIM_QK]; + alignas(16) DTypeKV v_smem[CTA_TILE_KV * HEAD_DIM_VO]; + }; + struct { // NOTE(Zihao): synchronize attention states across warps + alignas( + 16) std::conditional_t cta_sync_o_smem; + alignas(16) std::conditional_t cta_sync_md_smem; + }; + alignas(16) DTypeO smem_o[CTA_TILE_Q * HEAD_DIM_VO]; + }; +}; + +template +struct KernelTraits { + static constexpr MaskMode 
MASK_MODE = MASK_MODE_; + static constexpr uint32_t NUM_MMA_Q = NUM_MMA_Q_; + static constexpr uint32_t NUM_MMA_KV = NUM_MMA_KV_; + static constexpr uint32_t NUM_MMA_D_QK = NUM_MMA_D_QK_; + static constexpr uint32_t NUM_MMA_D_VO = NUM_MMA_D_VO_; + static constexpr uint32_t NUM_WARPS_Q = NUM_WARPS_Q_; + static constexpr uint32_t NUM_WARPS_KV = NUM_WARPS_KV_; + static constexpr uint32_t NUM_WARPS = NUM_WARPS_Q * NUM_WARPS_KV; + static constexpr uint32_t HEAD_DIM_QK = NUM_MMA_D_QK * 16; + static constexpr uint32_t HEAD_DIM_VO = NUM_MMA_D_VO * 16; + static constexpr uint32_t CTA_TILE_Q = CTA_TILE_Q_; + static constexpr uint32_t CTA_TILE_KV = NUM_MMA_KV * NUM_WARPS_KV * 16; + static constexpr PosEncodingMode POS_ENCODING_MODE = POS_ENCODING_MODE_; + + using DTypeQ = DTypeQ_; + using DTypeKV = DTypeKV_; + using DTypeO = DTypeO_; + using DTypeQKAccum = DTypeQKAccum_; + using IdType = IdType_; + using AttentionVariant = AttentionVariant_; + +#if defined(PLATFORM_HIP_DEVICE) + static_assert(sizeof(DTypeKV_) != 1, "8-bit types not supported for CDNA3"); + + using SmemBasePtrTy = uint2; + static constexpr uint32_t NUM_THREADS = NUM_WARPS_Q * NUM_WARPS_KV * 64; + static constexpr uint32_t WARP_THREAD_ROWS = 4; + static constexpr uint32_t WARP_THREAD_COLS = 16; + static constexpr uint32_t HALF_ELEMS_PER_THREAD = 4; + static constexpr uint32_t INT32_ELEMS_PER_THREAD = 2; + static constexpr uint32_t VECTOR_BIT_WIDTH = HALF_ELEMS_PER_THREAD * 16; + // FIXME: Update with a proper swizzle pattern. Linear is used primarily + // for intial testing. + static constexpr SwizzleMode SWIZZLE_MODE_Q = SwizzleMode::kLinear; + static constexpr SwizzleMode SWIZZLE_MODE_KV = SwizzleMode::kLinear; + + // Presently we use 16x4 thread layout for all cases. 
+  static constexpr uint32_t KV_THR_LAYOUT_ROW = WARP_THREAD_ROWS;
+  static constexpr uint32_t KV_THR_LAYOUT_COL = WARP_THREAD_COLS;
+  // FIXME: [The comment is not correct] The constant is defined based on the
+  // matrix layout of the "D/C" accumulator matrix in a D = A*B+C computation.
+  // On CDNA3 the D/C matrices are distributed as four 4x16 bands across the
+  // 64 threads. Each thread owns one element from four different rows.
+  static constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD = 4;
+  // Number of threads that collaboratively handle the same set of matrix rows
+  // in attention score computation and cross-warp synchronization.
+  // CUDA: 4 threads (each thread handles 2 elements from same row group)
+  // CDNA3: 16 threads (each thread handles 1 element from same row group)
+  static constexpr uint32_t THREADS_PER_BMATRIX_ROW_SET = 16;
+  // controls the indexing stride used in logits-related functions
+  // (logits_transform, logits_mask, and LSE writing).
+  static constexpr uint32_t LOGITS_INDEX_STRIDE = 4;
+#else
+  using SmemBasePtrTy = uint4;
+  static constexpr uint32_t NUM_THREADS = NUM_WARPS_Q * NUM_WARPS_KV * 32;
+  static constexpr uint32_t WARP_THREAD_ROWS = 4;  // fix: constexpr data members must be static
+  static constexpr uint32_t WARP_THREAD_COLS = 8;
+  static constexpr uint32_t HALF_ELEMS_PER_THREAD = 8;
+  static constexpr uint32_t INT32_ELEMS_PER_THREAD = 4;
+  static constexpr uint32_t VECTOR_BIT_WIDTH = HALF_ELEMS_PER_THREAD * 16;
+
+  static constexpr SwizzleMode SWIZZLE_MODE_Q = SwizzleMode::k128B;
+  static constexpr SwizzleMode SWIZZLE_MODE_KV =
+      (sizeof(DTypeKV_) == 1 && HEAD_DIM_VO == 64) ? SwizzleMode::k64B : SwizzleMode::k128B;
+  static constexpr uint32_t KV_THR_LAYOUT_ROW =
+      SWIZZLE_MODE_KV == SwizzleMode::k128B ? WARP_THREAD_ROWS : WARP_THREAD_COLS;
+  static constexpr uint32_t KV_THR_LAYOUT_COL = SWIZZLE_MODE_KV == SwizzleMode::k128B ? 8 : 4;
+
+  // The constant is defined based on the matrix layout of the "D/C"
+  // accumulator matrix in a D = A*B+C computation.
On CUDA for + // m16n8k16 mma ops the D/C matrix is distributed as 4 8x8 block and each + // thread stores eight elements from two different rows. + // Refer: + // https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-i8-f8 + static constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD = 2; + static constexpr uint32_t THREADS_PER_BMATRIX_ROW_SET = 4; + static constexpr uint32_t LOGITS_INDEX_STRIDE = 8; +#endif + static constexpr uint32_t UPCAST_STRIDE_Q = + HEAD_DIM_QK / upcast_size(); + static constexpr uint32_t UPCAST_STRIDE_K = + HEAD_DIM_QK / upcast_size(); + static constexpr uint32_t UPCAST_STRIDE_V = + HEAD_DIM_VO / upcast_size(); + static constexpr uint32_t UPCAST_STRIDE_O = + HEAD_DIM_VO / upcast_size(); + + static constexpr bool IsInvalid() { + return ((NUM_MMA_D_VO < 4) || (NUM_MMA_D_VO == 4 && NUM_MMA_KV % 2 == 1) || + (POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && NUM_MMA_D_VO > 4 && + NUM_MMA_D_VO % (2 * NUM_WARPS_Q) != 0) || + (NUM_MMA_Q * (8 * NUM_MMA_D_VO + 2 * sizeof(DTypeQKAccum) * NUM_MMA_KV) >= 256) || + (sizeof(DTypeKV) == 1 && NUM_MMA_KV * 2 % NUM_WARPS_Q != 0) || + (sizeof(DTypeKV) == 1 && POS_ENCODING_MODE == PosEncodingMode::kRoPELlama)); + } + + using SharedStorage = SharedStorageQKVO; +#ifdef FP16_QK_REDUCTION_SUPPORTED + template + static constexpr DT getNegInf() { + if constexpr (std::is_same::value) { + return std::bit_cast(fp16_ieee_from_fp32_value(-gpu_iface::math::inf)); + } else { + return static_cast(-gpu_iface::math::inf); + } + } + + static constexpr DTypeQKAccum MaskFillValue = + AttentionVariant::use_softmax ? getNegInf() : DTypeQKAccum(0.f); +#else + static_assert(!std::is_same::value, + "Set -DFP16_QK_REDUCTION_SUPPORTED and install boost_math " + "then recompile to support fp16 reduction"); + static constexpr DTypeQKAccum MaskFillValue = + AttentionVariant::use_softmax ? 
DTypeQKAccum(-gpu_iface::math::inf) : DTypeQKAccum(0.f); +#endif +}; + +namespace { + +template +__device__ __forceinline__ uint32_t get_warp_idx_q(const uint32_t tid_y = threadIdx.y) { + if constexpr (KTraits::NUM_WARPS_Q == 1) { + return 0; + } else { + return tid_y; + } +} + +template +__device__ __forceinline__ uint32_t get_warp_idx_kv(const uint32_t tid_z = threadIdx.z) { + if constexpr (KTraits::NUM_WARPS_KV == 1) { + return 0; + } else { + return tid_z; + } +} + +template +__device__ __forceinline__ uint32_t get_warp_idx(const uint32_t tid_y = threadIdx.y, + const uint32_t tid_z = threadIdx.z) { + return get_warp_idx_kv(tid_z) * KTraits::NUM_WARPS_Q + get_warp_idx_q(tid_y); +} + +/*! + * \brief Apply Llama style rotary embedding to two 16x16 fragments. + * \tparam T The data type of the input fragments. + * \param x_first_half First fragment x[offset:offset+16, j*16:(j+1)*16] + * \param x_second_half Second fragment x[offset:offset*16, + * j*16+d/2:(j+1)*16+d/2] + * \param rope_freq Rope frequency + * \param offset The offset of the first row in both fragments. + * \note The sin/cos computation is slow, especially for A100 GPUs which has low + * non tensor-ops flops, will optimize in the future. 
+ */ +template +__device__ __forceinline__ void k_frag_apply_llama_rope(T* x_first_half, T* x_second_half, + const float* rope_freq, + const uint32_t kv_offset) { + static_assert(sizeof(T) == 2); +#pragma unroll + for (uint32_t reg_id = 0; reg_id < HALF_ELEMS_PER_THREAD; ++reg_id) { + float cos, sin, tmp; + // 0 1 | 2 3 + // --------- + // 4 5 | 6 7 + +#if defined(PLATFORM_HIP_DEVICE) + uint32_t i = reg_id / 2, j = reg_id % 2; +#else + uint32_t i = reg_id / 4, j = (reg_id % 4) / 2; +#endif + __sincosf(float(kv_offset + 8 * i) * rope_freq[2 * j + reg_id % 2], &sin, &cos); + tmp = x_first_half[reg_id]; + x_first_half[reg_id] = (tmp * cos - (float)x_second_half[reg_id] * sin); + x_second_half[reg_id] = ((float)x_second_half[reg_id] * cos + tmp * sin); + } +} + +template +__device__ __forceinline__ void q_frag_apply_llama_rope(T* x_first_half, T* x_second_half, + const float* rope_freq, + const uint32_t qo_packed_offset, + const uint_fastdiv group_size) { +#pragma unroll + for (uint32_t reg_id = 0; reg_id < HALF_ELEMS_PER_THREAD; ++reg_id) { + float cos, sin, tmp; +#if defined(PLATFORM_HIP_DEVICE) + uint32_t freq_idx = reg_id; + uint32_t position = qo_packed_offset; +#else + // 0 1 | 4 5 + // --------- + // 2 3 | 6 7 + uint32_t i = ((reg_id % 4) / 2), j = (reg_id / 4); + uint32_t freq_idx = 2 * j + reg_id % 2; + uint32_t position = qo_packed_offset + 8 * i; +#endif + __sincosf(float(position / group_size) * rope_freq[freq_idx], &sin, &cos); + tmp = x_first_half[reg_id]; + x_first_half[reg_id] = (tmp * cos - (float)x_second_half[reg_id] * sin); + x_second_half[reg_id] = ((float)x_second_half[reg_id] * cos + tmp * sin); + } +} + +template +__device__ __forceinline__ void q_frag_apply_llama_rope_with_pos(T* x_first_half, T* x_second_half, + const float* rope_freq, + const uint32_t qo_packed_offset, + const uint_fastdiv group_size, + const IdType* q_rope_offset) { + float pos[2] = {static_cast(q_rope_offset[qo_packed_offset / group_size]), + 
static_cast(q_rope_offset[(qo_packed_offset + 8) / group_size])}; +#pragma unroll + for (uint32_t reg_id = 0; reg_id < HALF_ELEMS_PER_THREAD; ++reg_id) { + float cos, sin, tmp; + // 0 1 | 4 5 + // --------- + // 2 3 | 6 7 +#if defined(PLATFORM_HIP_DEVICE) + const uint32_t i = reg_id / 2; + const uint32_t j = reg_id % 2; +#else + const uint32_t i = (reg_id % 4) / 2; + const uint32_t j = reg_id / 4; +#endif + __sincosf(pos[i] * rope_freq[2 * j + reg_id % 2], &sin, &cos); + tmp = x_first_half[reg_id]; + x_first_half[reg_id] = (tmp * cos - (float)x_second_half[reg_id] * sin); + x_second_half[reg_id] = ((float)x_second_half[reg_id] * cos + tmp * sin); + } +} + +template +__device__ __forceinline__ void produce_kv_impl_cuda_( + uint32_t warp_idx, uint32_t lane_idx, + smem_t smem, uint32_t* smem_offset, + typename KTraits::DTypeKV** gptr, const uint32_t stride_n, const uint32_t kv_idx_base, + const uint32_t kv_len) { + using DTypeKV = typename KTraits::DTypeKV; + constexpr uint32_t NUM_WARPS = KTraits::NUM_WARPS; + constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; + constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q; + constexpr uint32_t NUM_MMA_D = produce_v ? KTraits::NUM_MMA_D_VO : KTraits::NUM_MMA_D_QK; + constexpr uint32_t UPCAST_STRIDE = + produce_v ? 
KTraits::UPCAST_STRIDE_V : KTraits::UPCAST_STRIDE_K; + constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH; + + if constexpr (KTraits::SWIZZLE_MODE_KV == SwizzleMode::k128B) { + uint32_t kv_idx = kv_idx_base + warp_idx * 4 + lane_idx / 8; + // NOTE: NUM_MMA_KV * 4 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 4 / num_warps + static_assert(NUM_MMA_KV * 4 % NUM_WARPS_Q == 0); +#pragma unroll + for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) { +#pragma unroll + for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(DTypeKV)); ++j) { + smem.template load_128b_async(*smem_offset, *gptr, kv_idx < kv_len); + *smem_offset = smem.template advance_offset_by_column<8>(*smem_offset, j); + *gptr += 8 * upcast_size(); + } + kv_idx += NUM_WARPS * 4; + *smem_offset = + smem.template advance_offset_by_row(*smem_offset) - + sizeof(DTypeKV) * NUM_MMA_D; + *gptr += NUM_WARPS * 4 * stride_n - + sizeof(DTypeKV) * NUM_MMA_D * upcast_size(); + } + *smem_offset -= KTraits::CTA_TILE_KV * UPCAST_STRIDE; + } else { + uint32_t kv_idx = kv_idx_base + warp_idx * 8 + lane_idx / 4; + // NOTE: NUM_MMA_KV * 2 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 2 / + // num_warps + static_assert(NUM_MMA_KV * 2 % NUM_WARPS_Q == 0); +#pragma unroll + for (uint32_t i = 0; i < NUM_MMA_KV * 2 / NUM_WARPS_Q; ++i) { + smem.template load_128b_async(*smem_offset, *gptr, kv_idx < kv_len); + *smem_offset = + smem.template advance_offset_by_row(*smem_offset); + kv_idx += NUM_WARPS * 8; + *gptr += NUM_WARPS * 8 * stride_n; + } + *smem_offset -= KTraits::CTA_TILE_KV * UPCAST_STRIDE; + } +} + +template +__device__ __forceinline__ void produce_kv_impl_cdna3_( + uint32_t warp_idx, uint32_t lane_idx, + smem_t smem, uint32_t* smem_offset, + typename KTraits::DTypeKV** gptr, const uint32_t stride_n, const uint32_t kv_idx_base, + const uint32_t kv_len) { + static_assert(KTraits::SWIZZLE_MODE_KV == SwizzleMode::kLinear); + using DTypeKV = typename KTraits::DTypeKV; + constexpr uint32_t KV_THR_LAYOUT_COL = 
KTraits::KV_THR_LAYOUT_COL; // 16 + constexpr uint32_t NUM_WARPS = KTraits::NUM_WARPS; + constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; + constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q; + constexpr uint32_t NUM_MMA_D = produce_v ? KTraits::NUM_MMA_D_VO : KTraits::NUM_MMA_D_QK; + constexpr uint32_t UPCAST_STRIDE = + produce_v ? KTraits::UPCAST_STRIDE_V : KTraits::UPCAST_STRIDE_K; + constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH; + + // NOTE: NUM_MMA_KV*4/NUM_WARPS_Q = NUM_WARPS_KV*NUM_MMA_KV*4/num_warps + static_assert(NUM_MMA_KV * 4 % NUM_WARPS_Q == 0); + uint32_t kv_idx = kv_idx_base + warp_idx * 4 + lane_idx / KV_THR_LAYOUT_COL; + +#pragma unroll + for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) { +#pragma unroll + for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(DTypeKV)); ++j) { + smem.template load_vector_async(*smem_offset, *gptr, kv_idx < kv_len); + *smem_offset = smem.template advance_offset_by_column<16>(*smem_offset, j); + *gptr += 16 * upcast_size(); + } + kv_idx += NUM_WARPS * 4; + *smem_offset = smem.template advance_offset_by_row(*smem_offset) - + (sizeof(DTypeKV) * NUM_MMA_D * 2); + *gptr += NUM_WARPS * 4 * stride_n - + sizeof(DTypeKV) * NUM_MMA_D * 2 * upcast_size(); + } + *smem_offset -= KTraits::CTA_TILE_KV * UPCAST_STRIDE; +} + +/*! + * \brief Produce k/v fragments from global memory to shared memory. + * \tparam fill_mode The fill mode of the shared memory. + * \tparam NUM_MMA_D_VO The number of fragments in y dimension. + * \tparam NUM_MMA_KV The number of fragments in z dimension. + * \tparam num_warps The number of warps in the threadblock. + * \tparam T The data type of the input tensor. + * \param smem The shared memory to store kv fragments. + * \param gptr The global memory pointer. + * \param kv_idx_base The base kv index. + * \param kv_len The length of kv tensor. 
+ */ +template +__device__ __forceinline__ void produce_kv( + smem_t smem, uint32_t* smem_offset, + typename KTraits::DTypeKV** gptr, const uint32_t stride_n, const uint32_t kv_idx_base, + const uint32_t kv_len, const dim3 tid = threadIdx) { + // NOTE: for fp8, this function doesn't work for head_dim = 64 at the moment + const uint32_t warp_idx = get_warp_idx(tid.y, tid.z), lane_idx = tid.x; + +#if defined(PLATFORM_HIP_DEVICE) + produce_kv_impl_cdna3_(warp_idx, lane_idx, smem, smem_offset, gptr, + stride_n, kv_idx_base, kv_len); +#elif defined(PLATFORM_CUDA_DEVICE) + produce_kv_impl_cuda_(warp_idx, lane_idx, smem, smem_offset, gptr, + stride_n, kv_idx_base, kv_len); +#endif +} + +template +__device__ __forceinline__ void page_produce_kv( + smem_t smem, uint32_t* smem_offset, + const paged_kv_t& paged_kv, + const uint32_t kv_idx_base, const size_t* thr_local_kv_offset, const uint32_t kv_len, + const dim3 tid = threadIdx) { + // NOTE: for fp8, this function doesn't work for head_dim = 64 at the moment + using DType = typename KTraits::DTypeKV; + constexpr SharedMemFillMode fill_mode = + produce_v ? SharedMemFillMode::kFillZero : SharedMemFillMode::kNoFill; + constexpr uint32_t NUM_WARPS = KTraits::NUM_WARPS; + constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q; + constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; + constexpr uint32_t NUM_MMA_D = produce_v ? KTraits::NUM_MMA_D_VO : KTraits::NUM_MMA_D_QK; + constexpr uint32_t UPCAST_STRIDE = + produce_v ? 
KTraits::UPCAST_STRIDE_V : KTraits::UPCAST_STRIDE_K; + constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH; + + const uint32_t warp_idx = get_warp_idx(tid.y, tid.z), lane_idx = tid.x; + if constexpr (KTraits::SWIZZLE_MODE_KV == SwizzleMode::k128B) { + uint32_t kv_idx = kv_idx_base + warp_idx * 4 + lane_idx / 8; + // NOTE: NUM_MMA_KV * 4/NUM_WARPS_Q=NUM_WARPS_KV*NUM_MMA_KV*4/num_warps + static_assert(NUM_MMA_KV * 4 % NUM_WARPS_Q == 0); +#pragma unroll + for (uint32_t i = 0; i < NUM_MMA_KV * 4 / NUM_WARPS_Q; ++i) { + DType* gptr = produce_v ? paged_kv.v_data + thr_local_kv_offset[i] + : paged_kv.k_data + thr_local_kv_offset[i]; +#pragma unroll + for (uint32_t j = 0; j < NUM_MMA_D / (8 / sizeof(DType)); ++j) { + smem.template load_vector_async(*smem_offset, gptr, kv_idx < kv_len); + *smem_offset = smem.template advance_offset_by_column<8>(*smem_offset, j); + gptr += 8 * upcast_size(); + } + kv_idx += NUM_WARPS * 4; + *smem_offset = + smem.template advance_offset_by_row(*smem_offset) - + sizeof(DType) * NUM_MMA_D; + } + *smem_offset -= KTraits::CTA_TILE_KV * UPCAST_STRIDE; + } else { + uint32_t kv_idx = kv_idx_base + warp_idx * 8 + lane_idx / 4; + // NOTE: NUM_MMA_KV * 2 / NUM_WARPS_Q = NUM_WARPS_KV * NUM_MMA_KV * 2 / num_warps + static_assert(NUM_MMA_KV * 2 % NUM_WARPS_Q == 0); +#pragma unroll + for (uint32_t i = 0; i < NUM_MMA_KV * 2 / NUM_WARPS_Q; ++i) { + DType* gptr = produce_v ? 
paged_kv.v_data + thr_local_kv_offset[i] + : paged_kv.k_data + thr_local_kv_offset[i]; + smem.template load_vector_async(*smem_offset, gptr, kv_idx < kv_len); + kv_idx += NUM_WARPS * 8; + *smem_offset = + smem.template advance_offset_by_row(*smem_offset); + } + *smem_offset -= KTraits::CTA_TILE_KV * UPCAST_STRIDE; + } +} + +template +__device__ __forceinline__ uint32_t get_feature_index(uint32_t mma_d, uint32_t lane_idx, + uint32_t j) { +#if defined(PLATFORM_HIP_DEVICE) + // CDNA3 A-matrix MMA tile to thread mapping for a 64-thread wavefront: + // Each group of 16 threads handles the same four consecutive features for + // different sequences: + // T0-T15: Features [0,1,2,3] for sequences 0-15 respectively + // T16-T31: Features [4,5,6,7] for sequences 0-15 respectively + // T32-T47: Features [8,9,10,11] for sequences 0-15 respectively + // T48-T63: Features [12,13,14,15] for sequences 0-15 respectively + // + uint32_t feature_index = (mma_d * 16 + (lane_idx / 4) + j) % (HEAD_DIM / 2); +#elif defined(PLATFORM_CUDA_DEVICE) + // CUDA A-matrix MMA tile to thread mapping for a 32 thread warp: + // Each group of four consecutive threads map four different features for + // the same sequence. + // T0: {0,1,8,9}, T1: {2,3,10,11}, T2: {4,5,12,13}, T3: {6,7,14,15} + // + // The pattern repeats across 8 rows with each row mapped to a set of four + // consecutive threads. + // row 0 --> T0, T1, T2, T3 + // row 1 --> T4, T5, T6, T7 + // ... + // row 7 --> T28, T29, T30, T31 + // The full data to thread mapping repeats again for the next set of 16 + // rows. Thereby, forming a 16x16 MMA tile dubdivided into four 8x8 + // quadrants. 
+ uint32_t feature_index = + ((mma_d * 16 + (j / 2) * 8 + (lane_idx % 4) * 2 + (j % 2)) % (HEAD_DIM / 2)); +#endif + return feature_index; +} + +template +__device__ __forceinline__ void init_rope_freq(float (*rope_freq)[4], const float rope_rcp_scale, + const float rope_rcp_theta, + const uint32_t tid_x = threadIdx.x) { + constexpr uint32_t HEAD_DIM = KTraits::NUM_MMA_D_QK * 16; + const uint32_t lane_idx = tid_x; + +#pragma unroll + for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO / 2; ++mma_d) { +#pragma unroll + for (uint32_t j = 0; j < 4; ++j) { + uint32_t feature_index = get_feature_index(mma_d, lane_idx, j); + float freq_base = float(2 * feature_index) / float(HEAD_DIM); + rope_freq[mma_d][j] = rope_rcp_scale * __powf(rope_rcp_theta, freq_base); + } + } +} + +template +__device__ __forceinline__ void init_states( + typename KTraits::AttentionVariant variant, + float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD], + typename KTraits::DTypeQKAccum (*m)[KTraits::NUM_ACCUM_ROWS_PER_THREAD], + float (*d)[KTraits::NUM_ACCUM_ROWS_PER_THREAD]) { + constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD = KTraits::NUM_ACCUM_ROWS_PER_THREAD; +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { +#pragma unroll + for (uint32_t reg_id = 0; reg_id < KTraits::HALF_ELEMS_PER_THREAD; ++reg_id) { + o_frag[mma_q][mma_d][reg_id] = 0.f; + } + } + } + + if constexpr (variant.use_softmax) { +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t j = 0; j < NUM_ACCUM_ROWS_PER_THREAD; ++j) { + m[mma_q][j] = typename KTraits::DTypeQKAccum(-gpu_iface::math::inf); + d[mma_q][j] = 1.f; + } + } + } +} + +template +__device__ __forceinline__ void load_q_global_smem( + uint32_t packed_offset, const uint32_t qo_upper_bound, typename KTraits::DTypeQ* q_ptr_base, + const uint32_t q_stride_n, const uint32_t 
q_stride_h, const uint_fastdiv group_size, + smem_t* q_smem, + const dim3 tid = threadIdx) { + using DTypeQ = typename KTraits::DTypeQ; + constexpr uint32_t WARP_THREAD_COLS = KTraits::WARP_THREAD_COLS; + constexpr uint32_t WARP_THREAD_ROWS = KTraits::WARP_THREAD_ROWS; + constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; + constexpr uint32_t NUM_MMA_D_QK = KTraits::NUM_MMA_D_QK; + constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q; + constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH; + +#if defined(PLATFORM_HIP_DEVICE) + constexpr uint32_t COLUMN_RESET_OFFSET = (NUM_MMA_D_QK / 4) * WARP_THREAD_COLS; +#else + constexpr uint32_t COLUMN_RESET_OFFSET = 2 * KTraits::NUM_MMA_D_QK; +#endif + + const uint32_t lane_idx = tid.x, warp_idx_x = get_warp_idx_q(tid.y); + uint32_t row = lane_idx / WARP_THREAD_COLS; + uint32_t col = lane_idx % WARP_THREAD_COLS; + + if (get_warp_idx_kv(tid.z) == 0) { + uint32_t q_smem_offset_w = q_smem->template get_permuted_offset( + warp_idx_x * KTraits::NUM_MMA_Q * 16 + row, col); + +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t j = 0; j < 2 * 2; ++j) { + uint32_t q, r; + group_size.divmod(packed_offset + row + mma_q * 16 + j * 4, q, r); + const uint32_t q_idx = q; + DTypeQ* q_ptr = q_ptr_base + q * q_stride_n + r * q_stride_h + + col * upcast_size(); +#pragma unroll + for (uint32_t mma_do = 0; mma_do < KTraits::NUM_MMA_D_QK / 4; ++mma_do) { + // load q fragment from gmem to smem + q_smem->template load_vector_async(q_smem_offset_w, q_ptr, + q_idx < qo_upper_bound); + q_smem_offset_w = + q_smem->template advance_offset_by_column(q_smem_offset_w, mma_do); + q_ptr += HALF_ELEMS_PER_THREAD * upcast_size(); + } + q_smem_offset_w = q_smem->template advance_offset_by_row( + q_smem_offset_w) - + COLUMN_RESET_OFFSET; + } + } + } +} + +template +__device__ __forceinline__ void q_smem_inplace_apply_rotary( + const uint32_t q_packed_idx, const 
uint32_t qo_len, const uint32_t kv_len, + const uint_fastdiv group_size, + smem_t* q_smem, + uint32_t* q_smem_offset_r, float (*rope_freq)[4], const dim3 tid = threadIdx) { + if (get_warp_idx_kv(tid.z) != 0) return; + + constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q; + constexpr uint32_t COL_ADVANCE_TO_NEXT = 16 / KTraits::HALF_ELEMS_PER_THREAD; + +#if defined(PLATFORM_HIP_DEVICE) + constexpr uint32_t COL_ADVANCE_TO_LAST_HALF = KTraits::NUM_MMA_D_QK * 2; +#elif defined(PLATFORM_CUDA_DEVICE) + constexpr uint32_t COL_ADVANCE_TO_LAST_HALF = KTraits::NUM_MMA_D_QK; +#endif + + const uint32_t lane_idx = tid.x; + uint32_t q_frag_local[2][KTraits::INT32_ELEMS_PER_THREAD]; + static_assert(KTraits::NUM_MMA_D_QK % 4 == 0, "NUM_MMA_D_QK must be a multiple of 4"); +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { + uint32_t q_smem_offset_r_first_half = *q_smem_offset_r; +#if defined(PLATFORM_HIP_DEVICE) + const uint32_t seq_id = + q_packed_idx + kv_len * group_size - qo_len * group_size + mma_q * 16 + lane_idx % 16; +#elif defined(PLATFORM_CUDA_DEVICE) + const uint32_t seq_id = + q_packed_idx + kv_len * group_size - qo_len * group_size + mma_q * 16 + lane_idx / 4; +#endif +#pragma unroll + for (uint32_t mma_di = 0; mma_di < KTraits::NUM_MMA_D_QK / 2; ++mma_di) { + q_smem->template load_fragment(q_smem_offset_r_first_half, q_frag_local[0]); + uint32_t q_smem_offset_r_last_half = + q_smem->template advance_offset_by_column( + q_smem_offset_r_first_half, 0); + q_smem->template load_fragment(q_smem_offset_r_last_half, q_frag_local[1]); + q_frag_apply_llama_rope( + (typename KTraits::DTypeQ*)q_frag_local[0], (typename KTraits::DTypeQ*)q_frag_local[1], + rope_freq[mma_di], seq_id, group_size); + q_smem->template store_fragment(q_smem_offset_r_last_half, q_frag_local[1]); + q_smem->template store_fragment(q_smem_offset_r_first_half, q_frag_local[0]); + q_smem_offset_r_first_half = q_smem->template advance_offset_by_column( + 
q_smem_offset_r_first_half, mma_di); + } + *q_smem_offset_r += 16 * UPCAST_STRIDE_Q; + } + *q_smem_offset_r -= KTraits::NUM_MMA_Q * 16 * UPCAST_STRIDE_Q; +} + +template +__device__ __forceinline__ void q_smem_inplace_apply_rotary_with_pos( + const uint32_t q_packed_idx_base, const typename KTraits::IdType* q_rope_offset, + smem_t* q_smem, + const uint_fastdiv group_size, uint32_t* q_smem_offset_r, float (*rope_freq)[4], + const dim3 tid = threadIdx) { + if (get_warp_idx_kv(tid.z) == 0) { + constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q; + const uint32_t lane_idx = tid.x; + uint32_t q_frag_local[2][KTraits::INT32_ELEMS_PER_THREAD]; + static_assert(KTraits::NUM_MMA_D_QK % 4 == 0, "NUM_MMA_D_QK must be a multiple of 4"); +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { + uint32_t q_smem_offset_r_first_half = *q_smem_offset_r; +#pragma unroll + for (uint32_t mma_di = 0; mma_di < KTraits::NUM_MMA_D_QK / 2; ++mma_di) { + q_smem->load_fragment(q_smem_offset_r_first_half, q_frag_local[0]); + uint32_t q_smem_offset_r_last_half = + q_smem->template advance_offset_by_column( + q_smem_offset_r_first_half, 0); + q_smem->load_fragment(q_smem_offset_r_last_half, q_frag_local[1]); + q_frag_apply_llama_rope_with_pos( + (typename KTraits::DTypeQ*)q_frag_local[0], (typename KTraits::DTypeQ*)q_frag_local[1], + rope_freq[mma_di], + q_packed_idx_base + mma_q * 16 + lane_idx / KTraits::THREADS_PER_BMATRIX_ROW_SET, + group_size, q_rope_offset); + q_smem->store_fragment(q_smem_offset_r_last_half, q_frag_local[1]); + q_smem->store_fragment(q_smem_offset_r_first_half, q_frag_local[0]); + q_smem_offset_r_first_half = + q_smem->template advance_offset_by_column<2>(q_smem_offset_r_first_half, mma_di); + } + *q_smem_offset_r += 16 * UPCAST_STRIDE_Q; + } + *q_smem_offset_r -= KTraits::NUM_MMA_Q * 16 * UPCAST_STRIDE_Q; + } +} + +template +__device__ __forceinline__ void k_smem_inplace_apply_rotary( + const uint32_t kv_idx_base, + smem_t* k_smem, + 
uint32_t* k_smem_offset_r, float (*rope_freq)[4], const dim3 tid = threadIdx) { + using DTypeKV = typename KTraits::DTypeKV; + static_assert(sizeof(DTypeKV) == 2); + constexpr uint32_t UPCAST_STRIDE_K = KTraits::UPCAST_STRIDE_K; + constexpr uint32_t THREADS_PER_BMATRIX_ROW_SET = KTraits::THREADS_PER_BMATRIX_ROW_SET; + constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; + uint32_t k_frag_local[2][KTraits::INT32_ELEMS_PER_THREAD]; + const uint32_t lane_idx = tid.x; + if constexpr (KTraits::NUM_MMA_D_QK == 4 && KTraits::NUM_WARPS_Q == 4) { + static_assert(KTraits::NUM_WARPS_KV == 1); + const uint32_t warp_idx = get_warp_idx_q(tid.y); + // horizontal-axis: y + // vertical-axis: z + // | 1-16 | 16-32 | 32-48 | 48-64 | + // | 1-16 | warp_idx=0 | warp_idx=1 | warp_idx=0 | warp_idx=1 | + // | 16-32 | warp_idx=2 | warp_idx=3 | warp_idx=2 | warp_idx=3 | + static_assert(KTraits::NUM_MMA_KV % 2 == 0, + "when NUM_MMA_D_QK == 4, NUM_MMA_KV must be a multiple of 2"); + uint32_t kv_idx = kv_idx_base + (warp_idx / 2) * 16 + lane_idx / THREADS_PER_BMATRIX_ROW_SET; + *k_smem_offset_r = + (*k_smem_offset_r ^ (0x2 * (warp_idx % 2))) + (warp_idx / 2) * 16 * UPCAST_STRIDE_K; +#pragma unroll + for (uint32_t i = 0; i < KTraits::NUM_MMA_KV / 2; ++i) { + uint32_t k_smem_offset_r_first_half = *k_smem_offset_r; + uint32_t mma_di = (warp_idx % 2); + k_smem->load_fragment(k_smem_offset_r_first_half, k_frag_local[0]); + uint32_t k_smem_offset_r_last_half = + k_smem->template advance_offset_by_column<4>(k_smem_offset_r_first_half, 0); + k_smem->load_fragment(k_smem_offset_r_last_half, k_frag_local[1]); + k_frag_apply_llama_rope( + (DTypeKV*)k_frag_local[0], (DTypeKV*)k_frag_local[1], rope_freq[mma_di], kv_idx); + k_smem->store_fragment(k_smem_offset_r_last_half, k_frag_local[1]); + k_smem->store_fragment(k_smem_offset_r_first_half, k_frag_local[0]); + *k_smem_offset_r += 32 * UPCAST_STRIDE_K; + kv_idx += 32; + } + *k_smem_offset_r = (*k_smem_offset_r ^ (0x2 * (warp_idx % 
2))) - + ((warp_idx / 2) + KTraits::NUM_MMA_KV) * 16 * UPCAST_STRIDE_K; + } else { + const uint32_t warp_idx_x = get_warp_idx_q(tid.y), + warp_idx_z = get_warp_idx_kv(tid.z); + static_assert(KTraits::NUM_MMA_D_QK % (2 * KTraits::NUM_WARPS_Q) == 0); + // horizontal axis: y + // vertical axis: z + // | (warp_idx_z, warp_idx_x) | 1-16 | 16-32 | 32-48 | 48-64 + // | ... | 1-16*NUM_MMA_KV | (0, 0) | (0, 1) | (0, 2) | (0, 3) + // | ... | 16*NUM_MMA_KV-32*NUM_MMA_KV | (1, 0) | (1, 1) | (1, 2) | (1, 3) + // | ... ... + uint32_t kv_idx = kv_idx_base + (warp_idx_z * KTraits::NUM_MMA_KV * 16) + + lane_idx / THREADS_PER_BMATRIX_ROW_SET; + *k_smem_offset_r = *k_smem_offset_r ^ (0x2 * warp_idx_x); +#pragma unroll + for (uint32_t i = 0; i < KTraits::NUM_MMA_KV; ++i) { + uint32_t k_smem_offset_r_first_half = *k_smem_offset_r; +#pragma unroll + for (uint32_t j = 0; j < KTraits::NUM_MMA_D_QK / (2 * KTraits::NUM_WARPS_Q); ++j) { + uint32_t mma_di = warp_idx_x + j * KTraits::NUM_WARPS_Q; + k_smem->load_fragment(k_smem_offset_r_first_half, k_frag_local[0]); + uint32_t k_smem_offset_r_last_half = + k_smem->template advance_offset_by_column( + k_smem_offset_r_first_half, 0); + k_smem->load_fragment(k_smem_offset_r_last_half, k_frag_local[1]); + k_frag_apply_llama_rope( + (DTypeKV*)k_frag_local[0], (DTypeKV*)k_frag_local[1], rope_freq[mma_di], kv_idx); + k_smem->store_fragment(k_smem_offset_r_last_half, k_frag_local[1]); + k_smem->store_fragment(k_smem_offset_r_first_half, k_frag_local[0]); + k_smem_offset_r_first_half = + k_smem->template advance_offset_by_column<2 * KTraits::NUM_WARPS_Q>( + k_smem_offset_r_first_half, mma_di); + } + *k_smem_offset_r += 16 * UPCAST_STRIDE_K; + kv_idx += 16; + } + *k_smem_offset_r = + (*k_smem_offset_r ^ (0x2 * warp_idx_x)) - KTraits::NUM_MMA_KV * 16 * UPCAST_STRIDE_K; + } +} + +template +__device__ __forceinline__ void compute_qk( + smem_t* q_smem, + uint32_t* q_smem_offset_r, + smem_t* k_smem, + uint32_t* k_smem_offset_r, + typename 
KTraits::DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD]) { + constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q; + constexpr uint32_t UPCAST_STRIDE_K = KTraits::UPCAST_STRIDE_K; + constexpr uint32_t QK_SMEM_COLUMN_ADVANCE = 16 / KTraits::HALF_ELEMS_PER_THREAD; + + uint32_t a_frag[KTraits::NUM_MMA_Q][KTraits::INT32_ELEMS_PER_THREAD], + b_frag[KTraits::INT32_ELEMS_PER_THREAD]; + // compute q*k^T +#pragma unroll + for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_QK; ++mma_d) { +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { + q_smem->load_fragment(*q_smem_offset_r, a_frag[mma_q]); + *q_smem_offset_r = + q_smem->template advance_offset_by_row<16, UPCAST_STRIDE_Q>(*q_smem_offset_r); + } + + *q_smem_offset_r = + q_smem->template advance_offset_by_column(*q_smem_offset_r, mma_d) - + KTraits::NUM_MMA_Q * 16 * UPCAST_STRIDE_Q; + +#pragma unroll + for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { + if constexpr (sizeof(typename KTraits::DTypeKV) == 1) { +#if defined(PLATFORM_HIP_DEVICE) + static_assert(false, "FP8 support not yet implemented for CDNA3"); +#else + uint32_t b_frag_f8[2]; + if (mma_d % 2 == 0) { + k_smem->ldmatrix_m8n8x4_left_half(*k_smem_offset_r, b_frag_f8); + } else { + k_smem->ldmatrix_m8n8x4_right_half(*k_smem_offset_r, b_frag_f8); + } + b_frag_f8[0] = frag_layout_swizzle_16b_to_8b(b_frag_f8[0]); + b_frag_f8[1] = frag_layout_swizzle_16b_to_8b(b_frag_f8[1]); + vec_cast::template cast<8>( + (typename KTraits::DTypeQ*)b_frag, (typename KTraits::DTypeKV*)b_frag_f8); +#endif + } else { + k_smem->load_fragment(*k_smem_offset_r, b_frag); + } + + *k_smem_offset_r = + k_smem->template advance_offset_by_row<16, UPCAST_STRIDE_K>(*k_smem_offset_r); + +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { + if constexpr (std::is_same_v) { + if (mma_d == 0) { + mma::mma_sync_m16n16k16_row_col_f16f16f32( + s_frag[mma_q][mma_kv], a_frag[mma_q], 
b_frag); + } else { + mma::mma_sync_m16n16k16_row_col_f16f16f32( + s_frag[mma_q][mma_kv], a_frag[mma_q], b_frag); + } + } + + else if (std::is_same_v) { +#if defined(PLATFORM_HIP_DEVICE) + static_assert(false, "FP16 DTypeQKAccum not yet implemented for CDNA3"); +#else + if (mma_d == 0) { + mma::mma_sync_m16n16k16_row_col_f16f16f16( + (uint32_t*)s_frag[mma_q][mma_kv], a_frag[mma_q], b_frag); + } else { + mma::mma_sync_m16n16k16_row_col_f16f16f16((uint32_t*)s_frag[mma_q][mma_kv], + a_frag[mma_q], b_frag); + } +#endif + } + } + } + if constexpr (sizeof(typename KTraits::DTypeKV) == 1) { + if (mma_d % 2 == 1) { + *k_smem_offset_r = k_smem->template advance_offset_by_column( + *k_smem_offset_r, mma_d / 2); + } + *k_smem_offset_r -= KTraits::NUM_MMA_KV * 16 * UPCAST_STRIDE_K; + } else { + *k_smem_offset_r = k_smem->template advance_offset_by_column( + *k_smem_offset_r, mma_d) - + KTraits::NUM_MMA_KV * 16 * UPCAST_STRIDE_K; + } + } + *q_smem_offset_r -= KTraits::NUM_MMA_D_QK * QK_SMEM_COLUMN_ADVANCE; + +#if defined(PLATFORM_HIP_DEVICE) + *k_smem_offset_r -= KTraits::NUM_MMA_D_QK * (QK_SMEM_COLUMN_ADVANCE); +#elif defined(PLATFORM_CUDA_DEVICE) + *k_smem_offset_r -= KTraits::NUM_MMA_D_QK * sizeof(typename KTraits::DTypeKV); +#endif +} + +template +__device__ __forceinline__ void logits_transform( + const Params& params, typename KTraits::AttentionVariant variant, const uint32_t batch_idx, + const uint32_t qo_packed_idx_base, const uint32_t kv_idx_base, const uint32_t qo_len, + const uint32_t kv_len, const uint_fastdiv group_size, + DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD], + const dim3 tid = threadIdx, const uint32_t kv_head_idx = blockIdx.z) { + constexpr uint32_t TPR = KTraits::THREADS_PER_BMATRIX_ROW_SET; + constexpr uint32_t NAPTR = KTraits::NUM_ACCUM_ROWS_PER_THREAD; + constexpr uint32_t LIS = KTraits::LOGITS_INDEX_STRIDE; + + const uint32_t lane_idx = tid.x; + uint32_t q[KTraits::NUM_MMA_Q][NAPTR], r[KTraits::NUM_MMA_Q][NAPTR]; + 
float logits = 0., logitsTransformed = 0.; + +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t j = 0; j < NAPTR; ++j) { + group_size.divmod(qo_packed_idx_base + mma_q * 16 + lane_idx / TPR + LIS * j, q[mma_q][j], + r[mma_q][j]); + } + } + +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { +#pragma unroll + for (uint32_t reg_id = 0; reg_id < KTraits::HALF_ELEMS_PER_THREAD; ++reg_id) { +#if defined(PLATFORM_HIP_DEVICE) + const uint32_t q_idx = q[mma_q][reg_id % NAPTR]; + const uint32_t qo_head_idx = kv_head_idx * group_size + r[mma_q][reg_id % NAPTR]; + const uint32_t kv_idx = kv_idx_base + mma_kv * 16 + (lane_idx % TPR); +#else + const uint32_t q_idx = q[mma_q][(reg_id % 4) / 2], kv_idx = kv_idx_base + mma_kv * 16 + + 2 * (lane_idx % 4) + + 8 * (reg_id / 4) + reg_id % 2; + const uint32_t qo_head_idx = kv_head_idx * group_size + r[mma_q][(reg_id % 4) / 2]; +#endif + +#ifdef FP16_QK_REDUCTION_SUPPORTED + if constexpr (std::is_same::value) { + logits = std::bit_cast(fp16_ieee_to_fp32_value(s_frag[mma_q][mma_kv][reg_id])); + } else if constexpr (!std::is_same::value) { + logits = s_frag[mma_q][mma_kv][reg_id]; + } +#else + static_assert(!std::is_same::value, + "Set -DFP16_QK_REDUCTION_SUPPORTED and install boost_math " + "then recompile to support fp16 reduction"); + logits = s_frag[mma_q][mma_kv][reg_id]; +#endif + logitsTransformed = variant.LogitsTransform(params, logits, batch_idx, q_idx, kv_idx, + qo_head_idx, kv_head_idx); +#if Debug1 + const uint32_t lane_idx = tid.x, warp_idx = get_warp_idx(tid.y, tid.z); + + if (warp_idx == 0 && lane_idx == 0) { + printf("logits : %f logitsTransformed: %f\n", float(logits), float(logitsTransformed)); + } +#endif +#ifdef FP16_QK_REDUCTION_SUPPORTED + if constexpr (std::is_same::value) { + s_frag[mma_q][mma_kv][reg_id] = + 
std::bit_cast(fp16_ieee_from_fp32_value(logitsTransformed)); + } else if constexpr (!std::is_same::value) { + s_frag[mma_q][mma_kv][reg_id] = logitsTransformed; + } +#else + s_frag[mma_q][mma_kv][reg_id] = logitsTransformed; +#endif + } + } + } +} + +template +__device__ __forceinline__ void logits_mask( + const Params& params, typename KTraits::AttentionVariant variant, const uint32_t batch_idx, + const uint32_t qo_packed_idx_base, const uint32_t kv_idx_base, const uint32_t qo_len, + const uint32_t kv_len, const uint32_t chunk_end, const uint_fastdiv group_size, + typename KTraits::DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD], + const dim3 tid = threadIdx, const uint32_t kv_head_idx = blockIdx.z) { + const uint32_t lane_idx = tid.x; + constexpr uint32_t NUM_MMA_Q = KTraits::NUM_MMA_Q; + constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; + constexpr MaskMode MASK_MODE = KTraits::MASK_MODE; + constexpr uint32_t TPR = KTraits::THREADS_PER_BMATRIX_ROW_SET; + constexpr uint32_t NAPTR = KTraits::NUM_ACCUM_ROWS_PER_THREAD; + constexpr uint32_t LIS = KTraits::LOGITS_INDEX_STRIDE; + + uint32_t q[NUM_MMA_Q][NAPTR], r[NUM_MMA_Q][NAPTR]; +#pragma unroll + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t j = 0; j < NAPTR; ++j) { + group_size.divmod(qo_packed_idx_base + mma_q * 16 + lane_idx / TPR + LIS * j, q[mma_q][j], + r[mma_q][j]); + } + } + +#pragma unroll + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) { +#pragma unroll + for (uint32_t reg_id = 0; reg_id < KTraits::HALF_ELEMS_PER_THREAD; ++reg_id) { +#if defined(PLATFORM_HIP_DEVICE) + const uint32_t q_idx = q[mma_q][(reg_id % NAPTR)]; + const uint32_t kv_idx = kv_idx_base + mma_kv * 16 + (lane_idx % TPR); + const uint32_t qo_head_idx = kv_head_idx * group_size + r[mma_q][(reg_id % NAPTR)]; +#else + const uint32_t q_idx = q[mma_q][(reg_id % 4) / 2], kv_idx = kv_idx_base 
+ mma_kv * 16 + + 2 * (lane_idx % TPR) + + 8 * (reg_id / 4) + reg_id % 2; + const uint32_t qo_head_idx = kv_head_idx * group_size + r[mma_q][(reg_id % 4) / 2]; +#endif + const bool mask = + (!(MASK_MODE == MaskMode::kCausal + ? (kv_idx + qo_len > kv_len + q_idx || (kv_idx >= chunk_end)) + : kv_idx >= chunk_end)) && + variant.LogitsMask(params, batch_idx, q_idx, kv_idx, qo_head_idx, kv_head_idx); + s_frag[mma_q][mma_kv][reg_id] = + (mask) ? s_frag[mma_q][mma_kv][reg_id] : (KTraits::MaskFillValue); + } + } + } +} + +template +__device__ __forceinline__ void update_mdo_states( + typename KTraits::AttentionVariant variant, + typename KTraits::DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD], + float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD], + typename KTraits::DTypeQKAccum (*m)[KTraits::NUM_ACCUM_ROWS_PER_THREAD], + float (*d)[KTraits::NUM_ACCUM_ROWS_PER_THREAD], uint32_t warp_idx = 0, uint32_t lane_idx = 0) { + using DTypeQKAccum = typename KTraits::DTypeQKAccum; + using AttentionVariant = typename KTraits::AttentionVariant; + constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD = KTraits::NUM_ACCUM_ROWS_PER_THREAD; + constexpr bool use_softmax = AttentionVariant::use_softmax; + + if constexpr (use_softmax) { + const float sm_scale = variant.sm_scale_log2; + if constexpr (std::is_same_v) { +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t j = 0; j < NUM_ACCUM_ROWS_PER_THREAD; ++j) { + float m_prev = m[mma_q][j]; +#if defined(PLATFORM_HIP_DEVICE) +#pragma unroll + for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { + m[mma_q][j] = max(m[mma_q][j], s_frag[mma_q][mma_kv][j]); + } + // Butterfly reduction across all threads in the band + m[mma_q][j] = max(m[mma_q][j], gpu_iface::math::shfl_xor_sync(m[mma_q][j], 0x8)); + m[mma_q][j] = max(m[mma_q][j], gpu_iface::math::shfl_xor_sync(m[mma_q][j], 0x4)); + m[mma_q][j] = max(m[mma_q][j], 
gpu_iface::math::shfl_xor_sync(m[mma_q][j], 0x2)); + m[mma_q][j] = max(m[mma_q][j], gpu_iface::math::shfl_xor_sync(m[mma_q][j], 0x1)); + + float o_scale = gpu_iface::math::ptx_exp2(m_prev * sm_scale - m[mma_q][j] * sm_scale); + d[mma_q][j] *= o_scale; + +#pragma unroll + for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { + o_frag[mma_q][mma_d][j] *= o_scale; + } +#pragma unroll + for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { + s_frag[mma_q][mma_kv][j] = gpu_iface::math::ptx_exp2( + s_frag[mma_q][mma_kv][j] * sm_scale - m[mma_q][j] * sm_scale); + } +#elif (PLATFORM_CUDA_DEVICE) +#pragma unroll + for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { + float m_local = + max(max(s_frag[mma_q][mma_kv][j * 2 + 0], s_frag[mma_q][mma_kv][j * 2 + 1]), + max(s_frag[mma_q][mma_kv][j * 2 + 4], s_frag[mma_q][mma_kv][j * 2 + 5])); + m[mma_q][j] = max(m[mma_q][j], m_local); + } + m[mma_q][j] = max(m[mma_q][j], gpu_iface::math::shfl_xor_sync(m[mma_q][j], 0x2)); + m[mma_q][j] = max(m[mma_q][j], gpu_iface::math::shfl_xor_sync(m[mma_q][j], 0x1)); + + float o_scale = gpu_iface::math::ptx_exp2(m_prev * sm_scale - m[mma_q][j] * sm_scale); + d[mma_q][j] *= o_scale; +#pragma unroll + for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { + o_frag[mma_q][mma_d][j * 2 + 0] *= o_scale; + o_frag[mma_q][mma_d][j * 2 + 1] *= o_scale; + o_frag[mma_q][mma_d][j * 2 + 4] *= o_scale; + o_frag[mma_q][mma_d][j * 2 + 5] *= o_scale; + } +#pragma unroll + for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { + s_frag[mma_q][mma_kv][j * 2 + 0] = gpu_iface::math::ptx_exp2( + s_frag[mma_q][mma_kv][j * 2 + 0] * sm_scale - m[mma_q][j] * sm_scale); + s_frag[mma_q][mma_kv][j * 2 + 1] = gpu_iface::math::ptx_exp2( + s_frag[mma_q][mma_kv][j * 2 + 1] * sm_scale - m[mma_q][j] * sm_scale); + s_frag[mma_q][mma_kv][j * 2 + 4] = gpu_iface::math::ptx_exp2( + s_frag[mma_q][mma_kv][j * 2 + 4] * sm_scale - m[mma_q][j] * sm_scale); + 
s_frag[mma_q][mma_kv][j * 2 + 5] = gpu_iface::math::ptx_exp2( + s_frag[mma_q][mma_kv][j * 2 + 5] * sm_scale - m[mma_q][j] * sm_scale); + } +#endif + } + } + } else if constexpr (std::is_same_v) { +#if defined(PLATFORM_HIP_DEVICE) + static_assert(false, "Half precision accumulator not yet implemented for AMD"); +#else + const half2 sm_scale = __float2half2_rn(variant.sm_scale_log2); +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { + half m_prev[2]; +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + m_prev[j] = m[mma_q][j]; +#pragma unroll + for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { + half2 m_local = gpu_iface::math::hmax2(*(half2*)&s_frag[mma_q][mma_kv][j * 2], + *(half2*)&s_frag[mma_q][mma_kv][j * 2 + 4]); + m[mma_q][j] = __hmax(m[mma_q][j], __hmax(m_local.x, m_local.y)); + } + } + *(half2*)&m[mma_q] = gpu_iface::math::hmax2( + *(half2*)&m[mma_q], gpu_iface::math::shfl_xor_sync(*(half2*)&m[mma_q], 0x2)); + *(half2*)&m[mma_q] = gpu_iface::math::hmax2( + *(half2*)&m[mma_q], gpu_iface::math::shfl_xor_sync(*(half2*)&m[mma_q], 0x1)); +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + float o_scale = + gpu_iface::math::ptx_exp2(float(m_prev[j] * sm_scale.x - m[mma_q][j] * sm_scale.x)); + d[mma_q][j] *= o_scale; +#pragma unroll + for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { + o_frag[mma_q][mma_d][j * 2 + 0] *= o_scale; + o_frag[mma_q][mma_d][j * 2 + 1] *= o_scale; + o_frag[mma_q][mma_d][j * 2 + 4] *= o_scale; + o_frag[mma_q][mma_d][j * 2 + 5] *= o_scale; + } + half2 m2 = make_half2(m[mma_q][j], m[mma_q][j]); +#pragma unroll + for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { + *(half2*)&s_frag[mma_q][mma_kv][j * 2] = gpu_iface::math::ptx_exp2( + *(half2*)&s_frag[mma_q][mma_kv][j * 2] * sm_scale - m2 * sm_scale); + *(half2*)&s_frag[mma_q][mma_kv][j * 2 + 4] = gpu_iface::math::ptx_exp2( + *(half2*)&s_frag[mma_q][mma_kv][j * 2 + 4] * sm_scale - m2 * sm_scale); + } + } + 
} +#endif + } + } +} + +template +__device__ __forceinline__ void compute_sfm_v( + smem_t* v_smem, + uint32_t* v_smem_offset_r, + typename KTraits::DTypeQKAccum (*s_frag)[KTraits::NUM_MMA_KV][KTraits::HALF_ELEMS_PER_THREAD], + float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD], + float (*d)[KTraits::NUM_ACCUM_ROWS_PER_THREAD], const dim3 tid = threadIdx, + typename KTraits::DTypeQKAccum* qk_scratch = nullptr) { + constexpr uint32_t UPCAST_STRIDE_V = KTraits::UPCAST_STRIDE_V; + constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; + constexpr uint32_t INT32_ELEMS_PER_THREAD = KTraits::INT32_ELEMS_PER_THREAD; + constexpr uint32_t V_SMEM_COLUMN_ADVANCE = 16 / KTraits::HALF_ELEMS_PER_THREAD; + typename KTraits::DTypeQ s_frag_f16[KTraits::NUM_MMA_Q][KTraits::NUM_MMA_KV] + [HALF_ELEMS_PER_THREAD]; + +#if defined(PLATFORM_HIP_DEVICE) +#if Debug + // Print S fragment BEFORE transpose (in B/C/D layout: 128x64) + flashinfer::gpu_iface::debug_utils::hip::write_s_frag_to_lds< + typename KTraits::DTypeQKAccum, KTraits::NUM_MMA_Q, KTraits::NUM_MMA_KV, + KTraits::NUM_ACCUM_ROWS_PER_THREAD>(s_frag, qk_scratch, tid); + flashinfer::gpu_iface::debug_utils::hip::print_lds_array( + qk_scratch, KTraits::CTA_TILE_Q, KTraits::CTA_TILE_KV, + "S frag BEFORE transpose (B/C/D layout, 128x64)"); +#endif + // In-place transposition of the s_frag MMA tile to get the data into CDNA3 A-matrix layout. 
+ mma::transpose_mma_tile(reinterpret_cast(s_frag)); +#if Debug + // Print S fragment AFTER transpose (in A-matrix layout: 64x128) + flashinfer::gpu_iface::debug_utils::hip::write_s_frag_to_lds< + typename KTraits::DTypeQKAccum, KTraits::NUM_MMA_Q, KTraits::NUM_MMA_KV, + KTraits::NUM_ACCUM_ROWS_PER_THREAD>(s_frag, qk_scratch, tid); + flashinfer::gpu_iface::debug_utils::hip::print_lds_array( + qk_scratch, KTraits::CTA_TILE_KV, KTraits::CTA_TILE_Q, + "S frag AFTER transpose (A-matrix layout, 64x128)"); +#endif +#endif + + if constexpr (std::is_same_v) { +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { + vec_cast::template cast( + s_frag_f16[mma_q][mma_kv], s_frag[mma_q][mma_kv]); + } + } + } + + if constexpr (KTraits::AttentionVariant::use_softmax) { +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { + if constexpr (std::is_same_v) { + mma::m16k16_rowsum_f16f16f32(d[mma_q], s_frag_f16[mma_q][mma_kv]); + } else { +#if defined(PLATFORM_HIP_DEVICE) + static_assert(!std::is_same_v, + "FP16 reduction path not implemented for CDNA3"); +#else + mma::m16k16_rowsum_f16f16f32(d[mma_q], s_frag[mma_q][mma_kv]); +#endif + } + } + } + } + +#if Debug1 + // Print d values after update_mdo_states + flashinfer::gpu_iface::debug_utils::hip::write_d_to_lds( + d, qk_scratch, tid); + flashinfer::gpu_iface::debug_utils::hip::print_lds_array_1d( + qk_scratch, KTraits::CTA_TILE_Q, "--- d values after rowsum inside compute_sfm_v ---"); +#endif + +#pragma unroll + for (uint32_t mma_kv = 0; mma_kv < KTraits::NUM_MMA_KV; ++mma_kv) { +#pragma unroll + for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { + uint32_t b_frag[INT32_ELEMS_PER_THREAD]; + if constexpr (sizeof(typename KTraits::DTypeKV) == 1) { +#if defined(PLATFORM_HIP_DEVICE) + 
static_assert(false, "FP8 V path not implemented for CDNA3 yet"); +#else + uint32_t b_frag_f8[2]; + if (mma_d % 2 == 0) { + v_smem->ldmatrix_m8n8x4_trans_left_half(*v_smem_offset_r, b_frag_f8); + } else { + v_smem->ldmatrix_m8n8x4_trans_right_half(*v_smem_offset_r, b_frag_f8); + } + b_frag_f8[0] = frag_layout_swizzle_16b_to_8b_trans(b_frag_f8[0]); + b_frag_f8[1] = frag_layout_swizzle_16b_to_8b_trans(b_frag_f8[1]); + vec_cast::template cast<8>( + (typename KTraits::DTypeQ*)b_frag, (typename KTraits::DTypeKV*)b_frag_f8); + swap(b_frag[1], b_frag[2]); +#endif + } else { +#if defined(PLATFORM_HIP_DEVICE) + v_smem->load_fragment_and_quad_transpose(*v_smem_offset_r, b_frag); +#else + v_smem->ldmatrix_m8n8x4_trans(*v_smem_offset_r, b_frag); +#endif + } +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { + if constexpr (std::is_same_v) { + mma::mma_sync_m16n16k16_row_col_f16f16f32( + o_frag[mma_q][mma_d], (uint32_t*)s_frag_f16[mma_q][mma_kv], b_frag); + } else { + mma::mma_sync_m16n16k16_row_col_f16f16f32( + o_frag[mma_q][mma_d], (uint32_t*)s_frag[mma_q][mma_kv], b_frag); + } + } + if constexpr (sizeof(typename KTraits::DTypeKV) == 1) { + if (mma_d % 2 == 1) { + *v_smem_offset_r = v_smem->template advance_offset_by_column( + *v_smem_offset_r, mma_d / 2); + } + } else { + *v_smem_offset_r = v_smem->template advance_offset_by_column( + *v_smem_offset_r, mma_d); + } + } + *v_smem_offset_r = + v_smem->template advance_offset_by_row<16, UPCAST_STRIDE_V>(*v_smem_offset_r) - + sizeof(typename KTraits::DTypeKV) * KTraits::NUM_MMA_D_VO; + } + *v_smem_offset_r -= 16 * KTraits::NUM_MMA_KV * UPCAST_STRIDE_V; +} + +template +__device__ __forceinline__ void normalize_d( + float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD], + typename KTraits::DTypeQKAccum (*m)[KTraits::NUM_ACCUM_ROWS_PER_THREAD], + float (*d)[KTraits::NUM_ACCUM_ROWS_PER_THREAD]) { + using AttentionVariant = typename KTraits::AttentionVariant; + constexpr uint32_t NAPTR = 
KTraits::NUM_ACCUM_ROWS_PER_THREAD; + + if constexpr (AttentionVariant::use_softmax) { + float d_rcp[KTraits::NUM_MMA_Q][KTraits::NUM_ACCUM_ROWS_PER_THREAD]; + // compute reciprocal of d +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t j = 0; j < KTraits::NUM_ACCUM_ROWS_PER_THREAD; ++j) { + d_rcp[mma_q][j] = (m[mma_q][j] != typename KTraits::DTypeQKAccum(-gpu_iface::math::inf)) + ? gpu_iface::math::ptx_rcp(d[mma_q][j]) + : 0.f; + } + } + +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { +#pragma unroll + for (uint32_t reg_id = 0; reg_id < KTraits::HALF_ELEMS_PER_THREAD; ++reg_id) { +#if defined(PLATFORM_HIP_DEVICE) + o_frag[mma_q][mma_d][reg_id] = + o_frag[mma_q][mma_d][reg_id] * d_rcp[mma_q][reg_id % NAPTR]; +#else + o_frag[mma_q][mma_d][reg_id] = + o_frag[mma_q][mma_d][reg_id] * d_rcp[mma_q][(reg_id % 4) / 2]; +#endif + } + } + } + } +} + +template +__device__ __forceinline__ void finalize_m( + typename KTraits::AttentionVariant variant, + typename KTraits::DTypeQKAccum (*m)[KTraits::NUM_ACCUM_ROWS_PER_THREAD]) { + if constexpr (variant.use_softmax) { +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t j = 0; j < KTraits::NUM_ACCUM_ROWS_PER_THREAD; ++j) { + if (m[mma_q][j] != typename KTraits::DTypeQKAccum(-gpu_iface::math::inf)) { + m[mma_q][j] *= variant.sm_scale_log2; + } + } + } + } +} + +/*! + * \brief Synchronize the states of the MDO kernel across the threadblock along threadIdx.z. 
+ */ +template +__device__ __forceinline__ void threadblock_sync_mdo_states( + float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD], + typename KTraits::SharedStorage* smem_storage, + typename KTraits::DTypeQKAccum (*m)[KTraits::NUM_ACCUM_ROWS_PER_THREAD], + float (*d)[KTraits::NUM_ACCUM_ROWS_PER_THREAD], const uint32_t warp_idx, + const uint32_t lane_idx, const dim3 tid = threadIdx) { + constexpr uint32_t TPR = KTraits::THREADS_PER_BMATRIX_ROW_SET; + constexpr uint32_t NARPT = KTraits::NUM_ACCUM_ROWS_PER_THREAD; + + static_assert(WARP_SIZE % TPR == 0, "THREADS_PER_BMATRIX_ROW_SET must divide WARP_SIZE"); + constexpr uint32_t GROUPS_PER_WARP = WARP_SIZE / TPR; + const uint32_t lane_group_idx = lane_idx / TPR; + + // only necessary when blockDim.z > 1 + if constexpr (KTraits::NUM_WARPS_KV > 1) { + float* smem_o = smem_storage->cta_sync_o_smem; + float2* smem_md = smem_storage->cta_sync_md_smem; + // o: [num_warps, + // NUM_MMA_Q, + // NUM_MMA_D_VO, + // WARP_SIZE, + // HALF_ELEMS_PER_THREAD] + // md: [num_warps, NUM_MMA_Q, 16, 2 (m/d)] +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { + vec_t::memcpy( + smem_o + (((warp_idx * KTraits::NUM_MMA_Q + mma_q) * KTraits::NUM_MMA_D_VO + mma_d) * + WARP_SIZE + + lane_idx) * + KTraits::HALF_ELEMS_PER_THREAD, + o_frag[mma_q][mma_d]); + } + } + + if constexpr (KTraits::AttentionVariant::use_softmax) { +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t j = 0; j < NARPT; ++j) { + smem_md[((warp_idx * KTraits::NUM_MMA_Q + mma_q) * NARPT + j) * GROUPS_PER_WARP + + lane_group_idx] = make_float2(float(m[mma_q][j]), d[mma_q][j]); + } + } + + // synchronize m,d first + __syncthreads(); +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { + float o_scale[NARPT][KTraits::NUM_WARPS_KV]; +#pragma unroll + for 
(uint32_t j = 0; j < NARPT; ++j) { + float m_new = -gpu_iface::math::inf, d_new = 1.f; +#pragma unroll + for (uint32_t i = 0; i < KTraits::NUM_WARPS_KV; ++i) { + float2 md = smem_md[(((i * KTraits::NUM_WARPS_Q + get_warp_idx_q(tid.y)) * + KTraits::NUM_MMA_Q + + mma_q) * + NARPT + + j) * + GROUPS_PER_WARP + + lane_group_idx]; + float m_prev = m_new, d_prev = d_new; + m_new = max(m_new, md.x); + d_new = d_prev * gpu_iface::math::ptx_exp2(m_prev - m_new) + + md.y * gpu_iface::math::ptx_exp2(md.x - m_new); + } + +#pragma unroll + for (uint32_t i = 0; i < KTraits::NUM_WARPS_KV; ++i) { + float2 md = smem_md[(((i * KTraits::NUM_WARPS_Q + get_warp_idx_q(tid.y)) * + KTraits::NUM_MMA_Q + + mma_q) * + NARPT + + j) * + GROUPS_PER_WARP + + lane_group_idx]; + float mi = md.x; + o_scale[j][i] = gpu_iface::math::ptx_exp2(float(mi - m_new)); + } + m[mma_q][j] = typename KTraits::DTypeQKAccum(m_new); + d[mma_q][j] = d_new; + } + +#pragma unroll + for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { + vec_t o_new; + o_new.fill(0.f); +#pragma unroll + for (uint32_t i = 0; i < KTraits::NUM_WARPS_KV; ++i) { + vec_t oi; + oi.load(smem_o + ((((i * KTraits::NUM_WARPS_Q + get_warp_idx_q(tid.y)) * + KTraits::NUM_MMA_Q + + mma_q) * + KTraits::NUM_MMA_D_VO + + mma_d) * + WARP_SIZE + + lane_idx) * + KTraits::HALF_ELEMS_PER_THREAD); + +#pragma unroll + for (uint32_t reg_id = 0; reg_id < KTraits::HALF_ELEMS_PER_THREAD; ++reg_id) { +#if defined(PLATFORM_HIP_DEVICE) + // CDNA3: Direct mapping - each reg_id corresponds to one accumulator row + o_new[reg_id] += oi[reg_id] * o_scale[reg_id][i]; +#else + // CUDA: Grouped mapping - 2 elements per accumulator row + o_new[reg_id] += oi[reg_id] * o_scale[(reg_id % 4) / 2][i]; +#endif + } + } + o_new.store(o_frag[mma_q][mma_d]); + } + } + } else { + // synchronize m,d first + __syncthreads(); +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t mma_d = 0; mma_d < 
KTraits::NUM_MMA_D_VO; ++mma_d) { + vec_t o_new; + o_new.fill(0.f); +#pragma unroll + for (uint32_t i = 0; i < KTraits::NUM_WARPS_KV; ++i) { + vec_t oi; + oi.load(smem_o + ((((i * KTraits::NUM_WARPS_Q + get_warp_idx_q(tid.y)) * + KTraits::NUM_MMA_Q + + mma_q) * + KTraits::NUM_MMA_D_VO + + mma_d) * + WARP_SIZE + + lane_idx) * + KTraits::HALF_ELEMS_PER_THREAD); +#pragma unroll + for (uint32_t reg_id = 0; reg_id < KTraits::HALF_ELEMS_PER_THREAD; ++reg_id) { + o_new[reg_id] += oi[reg_id]; + } + } + o_new.store(o_frag[mma_q][mma_d]); + } + } + } + } +} + +template +__device__ __forceinline__ void write_o_reg_gmem( + float (*o_frag)[KTraits::NUM_MMA_D_VO][KTraits::HALF_ELEMS_PER_THREAD], + smem_t* o_smem, + typename KTraits::DTypeO* o_ptr_base, const uint32_t o_packed_idx_base, + const uint32_t qo_upper_bound, const uint32_t o_stride_n, const uint32_t o_stride_h, + const uint_fastdiv group_size, const dim3 tid = threadIdx) { + using DTypeO = typename KTraits::DTypeO; + constexpr uint32_t UPCAST_STRIDE_O = KTraits::UPCAST_STRIDE_O; + constexpr uint32_t TPR = KTraits::THREADS_PER_BMATRIX_ROW_SET; + constexpr uint32_t NAPTR = KTraits::NUM_ACCUM_ROWS_PER_THREAD; + constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; + constexpr uint32_t WARP_THREAD_COLS = KTraits::WARP_THREAD_COLS; + constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH; + + const uint32_t warp_idx_x = get_warp_idx_q(tid.y); + const uint32_t lane_idx = tid.x; + + if constexpr (sizeof(DTypeO) == 4) { +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t j = 0; j < NAPTR; ++j) { + uint32_t q, r; + group_size.divmod(o_packed_idx_base + lane_idx / TPR + mma_q * 16 + j * 8, q, r); + const uint32_t o_idx = q; +#pragma unroll + for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { + if (o_idx < qo_upper_bound) { + auto base_addr = o_ptr_base + q * o_stride_n + r * o_stride_h + mma_d * 16; + auto col_offset = 
lane_idx % 16; +#if defined(PLATFORM_HIP_DEVICE) + *(base_addr + col_offset) = o_frag[mma_q][mma_d][j]; +#else + *reinterpret_cast(base_addr + col_offset * 2) = + *reinterpret_cast(&o_frag[mma_q][mma_d][j * 2]); + + *reinterpret_cast(base_addr + 8 + col_offset * 2) = + *reinterpret_cast(&o_frag[mma_q][mma_d][4 + j * 2]); +#endif + } + } + } + } + } else { + if (get_warp_idx_kv(tid.z) == 0) { +#pragma unroll + for (uint32_t mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t mma_d = 0; mma_d < KTraits::NUM_MMA_D_VO; ++mma_d) { + uint32_t o_frag_f16[HALF_ELEMS_PER_THREAD / 2]; + vec_cast::template cast((DTypeO*)o_frag_f16, + o_frag[mma_q][mma_d]); + +#ifdef FLASHINFER_STMATRIX_M8N8X4_ENABLED + uint32_t o_smem_offset_w = o_smem->template get_permuted_offset( + (warp_idx_x * KTraits::NUM_MMA_Q + mma_q) * 16 + lane_idx % 16, + mma_d * 2 + lane_idx / 16); + o_smem->stmatrix_m8n8x4(o_smem_offset_w, o_frag_f16); +#else + uint32_t o_smem_offset_w = o_smem->template get_permuted_offset( + (warp_idx_x * KTraits::NUM_MMA_Q + mma_q) * 16 + lane_idx / TPR, mma_d * 2); +#if defined(PLATFORM_HIP_DEVICE) + ((uint32_t*)(o_smem->base + o_smem_offset_w))[lane_idx % TPR] = o_frag_f16[0]; + // Move 2 elements forward in the same row + uint32_t offset_2 = o_smem_offset_w + 2; + ((uint32_t*)(o_smem->base + offset_2))[lane_idx % 16] = o_frag_f16[1]; +#else + ((uint32_t*)(o_smem->base + o_smem_offset_w))[lane_idx % TPR] = o_frag_f16[0]; + ((uint32_t*)(o_smem->base + o_smem_offset_w + 8 * UPCAST_STRIDE_O))[lane_idx % 4] = + o_frag_f16[1]; + ((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1)))[lane_idx % TPR] = o_frag_f16[2]; + ((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1) + + 8 * UPCAST_STRIDE_O))[lane_idx % 4] = o_frag_f16[3]; +#endif +#endif + } + } + + uint32_t o_smem_offset_w = o_smem->template get_permuted_offset( + warp_idx_x * KTraits::NUM_MMA_Q * 16 + lane_idx / WARP_THREAD_COLS, + lane_idx % WARP_THREAD_COLS); + +#pragma unroll + for (uint32_t 
mma_q = 0; mma_q < KTraits::NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t j = 0; j < 2 * 2; ++j) { + uint32_t q, r; + group_size.divmod(o_packed_idx_base + lane_idx / WARP_THREAD_COLS + mma_q * 16 + j * 4, q, + r); + const uint32_t o_idx = q; + DTypeO* o_ptr = o_ptr_base + q * o_stride_n + r * o_stride_h + + (lane_idx % WARP_THREAD_COLS) * upcast_size(); +#pragma unroll + for (uint32_t mma_do = 0; mma_do < KTraits::NUM_MMA_D_VO / 4; ++mma_do) { + if (o_idx < qo_upper_bound) { + o_smem->store_vector(o_smem_offset_w, o_ptr); + } + o_ptr += WARP_THREAD_COLS * upcast_size(); + o_smem_offset_w = o_smem->template advance_offset_by_column( + o_smem_offset_w, mma_do); + } + o_smem_offset_w = + o_smem->template advance_offset_by_row<4, UPCAST_STRIDE_O>(o_smem_offset_w) - + 2 * KTraits::NUM_MMA_D_VO; + } + } + } + } +} + +} // namespace + +/*! + * \brief FlashAttention prefill CUDA kernel for a single request. + * \tparam partition_kv Whether to split kv_len into chunks. + * \tparam mask_mode The mask mode used in the attention operation. + * \tparam POS_ENCODING_MODE The positional encoding mode. + * \tparam NUM_MMA_Q The number of fragments in x dimension. + * \tparam NUM_MMA_D_VO The number of fragments in y dimension. + * \tparam NUM_MMA_KV The number of fragments in z dimension. + * \tparam num_warps The number of warps in the threadblock. + * \tparam DTypeQ The data type of the query tensor. + * \tparam DTypeKV The data type of the key/value tensor. + * \tparam DTypeO The data type of the output tensor. + * \param q The query tensor. + * \param k The key tensor. + * \param v The value tensor. + * \param o The output tensor. + * \param tmp The temporary buffer (used when partition_kv is true). + * \param lse The logsumexp value. + * \param rope_rcp_scale 1/(rope_scale), where rope_scale is the scaling + * factor used in RoPE interpolation. + * \param rope_rcp_theta 1/(rope_theta), where rope_theta is the theta + * used in RoPE. 
+ */ +template +__device__ __forceinline__ void SinglePrefillWithKVCacheDevice( + const Params params, typename KTraits::SharedStorage& smem_storage, const dim3 tid = threadIdx, + const uint32_t bx = blockIdx.x, const uint32_t chunk_idx = blockIdx.y, + const uint32_t kv_head_idx = blockIdx.z, const uint32_t num_chunks = gridDim.y, + const uint32_t num_kv_heads = gridDim.z) { + using DTypeQ = typename Params::DTypeQ; +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) + if constexpr (std::is_same_v) { + FLASHINFER_RUNTIME_ASSERT("Prefill kernels do not support bf16 on sm75."); + } else { +#endif + using DTypeKV = typename Params::DTypeKV; + using DTypeO = typename Params::DTypeO; + using DTypeQKAccum = typename KTraits::DTypeQKAccum; + using AttentionVariant = typename KTraits::AttentionVariant; + [[maybe_unused]] constexpr uint32_t NUM_MMA_Q = KTraits::NUM_MMA_Q; + [[maybe_unused]] constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; + [[maybe_unused]] constexpr uint32_t NUM_MMA_D_QK = KTraits::NUM_MMA_D_QK; + [[maybe_unused]] constexpr uint32_t NUM_MMA_D_VO = KTraits::NUM_MMA_D_VO; + [[maybe_unused]] constexpr uint32_t HEAD_DIM_QK = KTraits::HEAD_DIM_QK; + [[maybe_unused]] constexpr uint32_t HEAD_DIM_VO = KTraits::HEAD_DIM_VO; + [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q; + [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_K = KTraits::UPCAST_STRIDE_K; + [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_V = KTraits::UPCAST_STRIDE_V; + [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_O = KTraits::UPCAST_STRIDE_O; + [[maybe_unused]] constexpr uint32_t CTA_TILE_Q = KTraits::CTA_TILE_Q; + [[maybe_unused]] constexpr uint32_t CTA_TILE_KV = KTraits::CTA_TILE_KV; + [[maybe_unused]] constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q; + [[maybe_unused]] constexpr uint32_t NUM_WARPS_KV = KTraits::NUM_WARPS_KV; + [[maybe_unused]] constexpr SwizzleMode SWIZZLE_MODE_Q = KTraits::SWIZZLE_MODE_Q; + [[maybe_unused]] constexpr SwizzleMode 
SWIZZLE_MODE_KV = KTraits::SWIZZLE_MODE_KV; + [[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_ROW = KTraits::KV_THR_LAYOUT_ROW; + [[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_COL = KTraits::KV_THR_LAYOUT_COL; + [[maybe_unused]] constexpr MaskMode MASK_MODE = KTraits::MASK_MODE; + [[maybe_unused]] constexpr uint32_t HALF_ELEMS_PER_THREAD = KTraits::HALF_ELEMS_PER_THREAD; + [[maybe_unused]] constexpr uint32_t NUM_ACCUM_ROWS_PER_THREAD = + KTraits::NUM_ACCUM_ROWS_PER_THREAD; + [[maybe_unused]] constexpr uint32_t LOGITS_INDEX_STRIDE = KTraits::LOGITS_INDEX_STRIDE; + [[maybe_unused]] constexpr uint32_t THREADS_PER_BMATRIX_ROW_SET = + KTraits::THREADS_PER_BMATRIX_ROW_SET; + [[maybe_unused]] constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH; + + DTypeQ* q = params.q; + DTypeKV* k = params.k; + DTypeKV* v = params.v; + DTypeO* o = params.o; + float* lse = params.lse; + const uint32_t qo_len = params.qo_len; + const uint32_t kv_len = params.kv_len; + const bool partition_kv = params.partition_kv; + const uint32_t q_stride_n = params.q_stride_n; + const uint32_t q_stride_h = params.q_stride_h; + const uint32_t k_stride_n = params.k_stride_n; + const uint32_t k_stride_h = params.k_stride_h; + const uint32_t v_stride_n = params.v_stride_n; + const uint32_t v_stride_h = params.v_stride_h; + const uint_fastdiv& group_size = params.group_size; + + static_assert(sizeof(DTypeQ) == 2); + const uint32_t lane_idx = tid.x, warp_idx = get_warp_idx(tid.y, tid.z); + const uint32_t num_qo_heads = num_kv_heads * group_size; + + const uint32_t max_chunk_size = partition_kv ? ceil_div(kv_len, num_chunks) : kv_len; + const uint32_t chunk_start = partition_kv ? chunk_idx * max_chunk_size : 0; + const uint32_t chunk_end = + partition_kv ? 
min((chunk_idx + 1) * max_chunk_size, kv_len) : kv_len; + const uint32_t chunk_size = chunk_end - chunk_start; + + auto block = cg::this_thread_block(); + auto smem = reinterpret_cast(&smem_storage); + AttentionVariant variant(params, /*batch_idx=*/0, smem); + const uint32_t window_left = variant.window_left; + + DTypeQKAccum s_frag[NUM_MMA_Q][NUM_MMA_KV][HALF_ELEMS_PER_THREAD]; + alignas(16) float o_frag[NUM_MMA_Q][NUM_MMA_D_VO][HALF_ELEMS_PER_THREAD]; + DTypeQKAccum m[NUM_MMA_Q][NUM_ACCUM_ROWS_PER_THREAD]; + float d[NUM_MMA_Q][NUM_ACCUM_ROWS_PER_THREAD]; + float rope_freq[NUM_MMA_D_QK / 2][4]; + if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { + const float rope_rcp_scale = params.rope_rcp_scale; + const float rope_rcp_theta = params.rope_rcp_theta; + init_rope_freq(rope_freq, rope_rcp_scale, rope_rcp_theta, tid.x); + } + init_states(variant, o_frag, m, d); + +#if Debug + // Statically allocate a shared memory array specifically for debugging s_frag. + // This avoids modifying the main SharedStorage union. + __shared__ DTypeQKAccum qk_scratch[CTA_TILE_Q * CTA_TILE_KV]; +#endif + + // cooperative fetch q fragment from gmem to reg + const uint32_t qo_packed_idx_base = + (bx * NUM_WARPS_Q + get_warp_idx_q(tid.y)) * NUM_MMA_Q * 16; + smem_t qo_smem(smem_storage.q_smem); + const uint32_t o_stride_n = num_qo_heads * HEAD_DIM_VO, o_stride_h = HEAD_DIM_VO; + DTypeQ* q_ptr_base = q + (kv_head_idx * group_size) * q_stride_h; + DTypeO* o_ptr_base = partition_kv + ? 
o + chunk_idx * o_stride_n + (kv_head_idx * group_size) * o_stride_h + : o + (kv_head_idx * group_size) * o_stride_h; + + load_q_global_smem(qo_packed_idx_base, qo_len, q_ptr_base, q_stride_n, q_stride_h, + group_size, &qo_smem, tid); + + uint32_t q_smem_offset_r = qo_smem.template get_permuted_offset( + get_warp_idx_q(tid.y) * NUM_MMA_Q * 16 + lane_idx % 16, lane_idx / 16); + + memory::commit_group(); + if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { + memory::wait_group<0>(); + block.sync(); + q_smem_inplace_apply_rotary(qo_packed_idx_base, qo_len, kv_len, group_size, &qo_smem, + &q_smem_offset_r, rope_freq, tid); + block.sync(); + } + + smem_t k_smem(smem_storage.k_smem); + smem_t v_smem(smem_storage.v_smem); + + const uint32_t num_iterations = + ceil_div(MASK_MODE == MaskMode::kCausal + ? min(chunk_size, + sub_if_greater_or_zero( + kv_len - qo_len + ((bx + 1) * CTA_TILE_Q) / group_size, chunk_start)) + : chunk_size, + CTA_TILE_KV); + + const uint32_t window_iteration = + ceil_div(sub_if_greater_or_zero(kv_len + (bx + 1) * CTA_TILE_Q / group_size, + qo_len + window_left + chunk_start), + CTA_TILE_KV); + + const uint32_t mask_iteration = + (MASK_MODE == MaskMode::kCausal + ? 
min(chunk_size, sub_if_greater_or_zero( + kv_len + (bx * CTA_TILE_Q) / group_size - qo_len, chunk_start)) + : chunk_size) / + CTA_TILE_KV; + + DTypeKV* k_ptr = + k + + (chunk_start + warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL) * k_stride_n + + kv_head_idx * k_stride_h + + (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(); + DTypeKV* v_ptr = + v + + (chunk_start + warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL) * v_stride_n + + kv_head_idx * v_stride_h + + (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(); + +#if defined(PLATFORM_HIP_DEVICE) + uint32_t k_smem_offset_r = k_smem.template get_permuted_offset( + get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + lane_idx % 16, (lane_idx / 16)); + uint32_t v_smem_offset_r = v_smem.template get_permuted_offset( + get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + (lane_idx % 4) + 4 * (lane_idx / 16), + lane_idx / 4); +#elif defined(PLATFORM_CUDA_DEVICE) + uint32_t k_smem_offset_r = k_smem.template get_permuted_offset( + get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + 8 * (lane_idx / 16) + lane_idx % 8, + (lane_idx % 16) / 8); + uint32_t v_smem_offset_r = v_smem.template get_permuted_offset( + get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + (lane_idx % 4) + 4 * (lane_idx / 16), + lane_idx / 4); +#endif + uint32_t k_smem_offset_w = k_smem.template get_permuted_offset( + warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL, + lane_idx % KV_THR_LAYOUT_COL), + v_smem_offset_w = v_smem.template get_permuted_offset( + warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL, + lane_idx % KV_THR_LAYOUT_COL); + produce_kv(k_smem, &k_smem_offset_w, &k_ptr, + k_stride_n, 0, chunk_size, tid); + memory::commit_group(); + produce_kv(v_smem, &v_smem_offset_w, &v_ptr, + v_stride_n, 0, chunk_size, tid); + memory::commit_group(); +#if Debug + // if (warp_idx == 0 && lane_idx == 0) { + // printf("partition_kv : %d\n", partition_kv); + // printf("kv_len : %d\n", kv_len); + // printf("max_chunk_size : %d\n", max_chunk_size); + // 
printf("chunk_end : %d\n", chunk_end); + // printf("chunk_start : %d\n", chunk_start); + // } +#if 0 + // Test Q + if (warp_idx == 0 && lane_idx == 0) { + printf("\n DEBUG Q ORIGINAL (HIP):\n"); + uint32_t q_smem_offset_r_debug; + for (auto i = 0; i < NUM_MMA_Q * 16 * 4; ++i) { + for (auto j = 0; j < NUM_MMA_D_QK * 4; ++j) { + q_smem_offset_r_debug = + qo_smem.template get_permuted_offset( + i, j); + uint32_t a_frag[KTraits::INT32_ELEMS_PER_THREAD]; + qo_smem.load_fragment(q_smem_offset_r_debug, a_frag); + auto frag_T = reinterpret_cast<__half *>(a_frag); + for (auto i = 0ul; i < 4; ++i) { + printf("%f ", (float)(*(frag_T + i))); + } + } + printf("\n"); + qo_smem.template advance_offset_by_row< + 16, KTraits::UPCAST_STRIDE_Q>(q_smem_offset_r_debug); + } + } +#endif + // Test K Global values: + // Prints the (NUM_MMA_KV*16) x (NUM_MMA_D*16) matrix from global mem. + + if (warp_idx == 0 && lane_idx == 0) { + printf("\n DEBUG K Global (HIP):\n"); + printf("k_stride_n : %d\n", k_stride_n); + printf("k_stride_h : %d\n", k_stride_h); + printf("kv_head_idx : %d\n", kv_head_idx); + printf("num_qo_heads : %d\n", num_qo_heads); + printf("num_kv_heads : %d\n", num_kv_heads); + printf("k_stride_n : %d\n", k_stride_n); + printf("KTraits::NUM_MMA_D_QK : %d\n", KTraits::NUM_MMA_D_QK); + printf("NUM_MMA_KV : %d\n", NUM_MMA_KV); + printf("NUM_MMA_Q : %d\n", NUM_MMA_Q); + printf("sm_scale : %f\n", variant.sm_scale_log2); +#if 0 + DTypeKV *k_ptr_tmp = k + + (chunk_start + warp_idx * KV_THR_LAYOUT_ROW + + lane_idx / KV_THR_LAYOUT_COL) * + k_stride_n + + kv_head_idx * k_stride_h + + (lane_idx % KV_THR_LAYOUT_COL) * + upcast_size(); + for (auto i = 0; i < NUM_MMA_KV * 16; ++i) { + for (auto j = 0; j < NUM_MMA_D_QK * 16; ++j) { + auto fKval = (float)*(k_ptr_tmp); + k_ptr_tmp += 1; + printf("%f ", fKval); + } + printf("\n"); + } +#endif + } + + // Test K LDS values: + // Prints the (NUM_MMA_KV*16) x (NUM_MMA_D*16) matrix from shared mem. 
+ // Note that LDS is loaded collaboratively by all warps and not each + // warp accesses the whole K matrix loaded into LDS. Each warp will + // only access 1/4 of the K values loaded into LDS. +#endif + +#pragma unroll 1 + for (uint32_t iter = 0; iter < num_iterations; ++iter) { + memory::wait_group<1>(); + block.sync(); + + if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { + k_smem_inplace_apply_rotary(chunk_start + iter * CTA_TILE_KV, &k_smem, + &k_smem_offset_r, rope_freq, tid); + block.sync(); + } +#if Debug1 + +#if 0 + if (warp_idx == 0 && lane_idx == 0) { + printf("\n DEBUG K LDS ORIGINAL (HIP) Iter %d:\n", iter); + uint32_t k_smem_offset_r_debug; + for (auto i = 0; i < NUM_MMA_KV * 16; ++i) { + for (auto j = 0; j < NUM_MMA_D_QK * 4; ++j) { + k_smem_offset_r_debug = + k_smem.template get_permuted_offset(i, + j); + uint32_t a_frag[KTraits::INT32_ELEMS_PER_THREAD]; + k_smem.load_fragment(k_smem_offset_r_debug, a_frag); + auto frag_T = reinterpret_cast<__half *>(a_frag); + for (auto i = 0ul; i < 4; ++i) { + printf("%f ", (float)(*(frag_T + i))); + } + } + printf("\n"); + k_smem.template advance_offset_by_row<16, KTraits::UPCAST_STRIDE_K>( + k_smem_offset_r_debug); + } + } +#endif + +#if 1 + if (warp_idx == 0 && lane_idx == 0) { + uint32_t b_frag[KTraits::INT32_ELEMS_PER_THREAD]; + k_smem.load_fragment(k_smem_offset_r, b_frag); + auto frag_T = reinterpret_cast<__half*>(b_frag); + for (auto reg_id = 0ul; reg_id < 4; ++reg_id) { + for (auto i = 0ul; i < 4; ++i) { + printf("%f ", (float)(*(frag_T + i))); + } + } + printf("\n------------\n"); + k_smem.load_fragment(k_smem_offset_r, b_frag); + frag_T = reinterpret_cast<__half*>(b_frag); + for (auto reg_id = 0ul; reg_id < 4; ++reg_id) { + for (auto i = 0ul; i < 4; ++i) { + printf("%f ", (float)(*(frag_T + i))); + } + } + printf("\n-----===============-------\n"); + } +#endif +#endif + + // compute attention score + compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); 
+#if Debug1 + flashinfer::gpu_iface::debug_utils::hip::write_s_frag_to_lds< + DTypeQKAccum, NUM_MMA_Q, NUM_MMA_KV, HALF_ELEMS_PER_THREAD>(s_frag, qk_scratch, + CTA_TILE_KV, tid); + + // a) Print thread 0's registers to see the source data. + flashinfer::gpu_iface::debug_utils::hip::print_s_frag_register< + DTypeQKAccum, NUM_MMA_Q, NUM_MMA_KV, HALF_ELEMS_PER_THREAD>(s_frag, tid); + + // b) Print the materialized LDS array to see the final result for this iteration. + flashinfer::gpu_iface::debug_utils::hip::print_lds_array(qk_scratch, CTA_TILE_Q, CTA_TILE_KV); +#endif + + logits_transform( + params, variant, /*batch_idx=*/0, qo_packed_idx_base, + chunk_start + (iter * NUM_WARPS_KV + get_warp_idx_kv(tid.z)) * NUM_MMA_KV * 16, + qo_len, kv_len, group_size, s_frag, tid, kv_head_idx); + + // apply mask + if (MASK_MODE == MaskMode::kCustom || (iter >= mask_iteration || iter < window_iteration)) { + logits_mask( + params, variant, /*batch_idx=*/0, qo_packed_idx_base, + chunk_start + (iter * NUM_WARPS_KV + get_warp_idx_kv(tid.z)) * NUM_MMA_KV * 16, + qo_len, kv_len, chunk_end, group_size, s_frag, tid, kv_head_idx); + } +#if Debug1 + flashinfer::gpu_iface::debug_utils::hip::write_s_frag_to_lds< + DTypeQKAccum, NUM_MMA_Q, NUM_MMA_KV, HALF_ELEMS_PER_THREAD>(s_frag, qk_scratch, + CTA_TILE_KV, tid); + + // // a) Print thread 0's registers to see the source data. + // flashinfer::gpu_iface::debug_utils::hip::print_s_frag_register< + // DTypeQKAccum, NUM_MMA_Q, NUM_MMA_KV, HALF_ELEMS_PER_THREAD>(s_frag, tid); + + // b) Print the materialized LDS array to see the final result for this iteration. 
+ flashinfer::gpu_iface::debug_utils::hip::print_lds_array( + qk_scratch, CTA_TILE_Q, CTA_TILE_KV, ("S frag before update_mdo for iteration\n")); + +#endif + // compute m,d states in online softmax + update_mdo_states(variant, s_frag, o_frag, m, d, warp_idx, lane_idx); + +#if Debug1 + flashinfer::gpu_iface::debug_utils::hip::write_s_frag_to_lds< + DTypeQKAccum, NUM_MMA_Q, NUM_MMA_KV, HALF_ELEMS_PER_THREAD>(s_frag, qk_scratch, + CTA_TILE_KV, tid); + + // // a) Print thread 0's registers to see the source data. + // flashinfer::gpu_iface::debug_utils::hip::print_s_frag_register< + // DTypeQKAccum, NUM_MMA_Q, NUM_MMA_KV, HALF_ELEMS_PER_THREAD>(s_frag, tid); + + // b) Print the materialized LDS array to see the final result for this iteration. + flashinfer::gpu_iface::debug_utils::hip::print_lds_array( + qk_scratch, CTA_TILE_Q, CTA_TILE_KV, ("S frag after update_mdo for iteration\n")); + + // c) Print d values after update_mdo_states + flashinfer::gpu_iface::debug_utils::hip::write_d_to_lds( + d, qk_scratch, tid); + flashinfer::gpu_iface::debug_utils::hip::print_lds_array_1d( + qk_scratch, CTA_TILE_Q, "--- d values after update_mdo_states ---"); +#endif + block.sync(); + produce_kv( + k_smem, &k_smem_offset_w, &k_ptr, k_stride_n, (iter + 1) * CTA_TILE_KV, chunk_size, tid); + memory::commit_group(); + memory::wait_group<1>(); + block.sync(); + + // compute sfm*v + compute_sfm_v(&v_smem, &v_smem_offset_r, s_frag, o_frag, d, tid, qk_scratch); + block.sync(); + produce_kv( + v_smem, &v_smem_offset_w, &v_ptr, v_stride_n, (iter + 1) * CTA_TILE_KV, chunk_size, tid); + memory::commit_group(); + } + memory::wait_group<0>(); + block.sync(); + + finalize_m(variant, m); + + // threadblock synchronization + threadblock_sync_mdo_states(o_frag, &smem_storage, m, d, warp_idx, lane_idx, tid); + + // normalize d + normalize_d(o_frag, m, d); + + // write back + write_o_reg_gmem(o_frag, &qo_smem, o_ptr_base, qo_packed_idx_base, qo_len, + /*o_stride_n=*/ + partition_kv ? 
num_chunks * o_stride_n : o_stride_n, + /*o_stride_h=*/o_stride_h, group_size, tid); + + // write lse + if constexpr (variant.use_softmax) { + if (lse != nullptr || partition_kv) { + if (get_warp_idx_kv(tid.z) == 0) { +#pragma unroll + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t j = 0; j < NUM_ACCUM_ROWS_PER_THREAD; ++j) { + uint32_t q, r; + group_size.divmod(qo_packed_idx_base + lane_idx / THREADS_PER_BMATRIX_ROW_SET + + j * LOGITS_INDEX_STRIDE + mma_q * 16, + q, r); + const uint32_t qo_head_idx = kv_head_idx * group_size + r; + const uint32_t qo_idx = q; + if (qo_idx < qo_len) { + if (partition_kv) { + lse[(qo_idx * num_chunks + chunk_idx) * num_qo_heads + qo_head_idx] = + gpu_iface::math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]); + } else { + lse[qo_idx * num_qo_heads + qo_head_idx] = + gpu_iface::math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]); + } + } + } + } + } + } + } +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) + } +#endif +} + +template +__global__ __launch_bounds__(KTraits::NUM_THREADS) void SinglePrefillWithKVCacheKernel( + const __grid_constant__ Params params) { + extern __shared__ uint8_t smem[]; + auto& smem_storage = reinterpret_cast(smem); + SinglePrefillWithKVCacheDevice(params, smem_storage); +} + +template +gpuError_t SinglePrefillWithKVCacheDispatched(Params params, typename Params::DTypeO* tmp, + gpuStream_t stream) { + using DTypeQ = typename Params::DTypeQ; + using DTypeKV = typename Params::DTypeKV; + using DTypeO = typename Params::DTypeO; + const uint32_t num_qo_heads = params.num_qo_heads; + const uint32_t num_kv_heads = params.num_kv_heads; + const uint32_t qo_len = params.qo_len; + const uint32_t kv_len = params.kv_len; + if (kv_len < qo_len && MASK_MODE == MaskMode::kCausal) { + std::ostringstream err_msg; + err_msg << "When mask_mode is set to MaskMode::kCausal, kv_len must be " + "greater than or equal to qo_len, got kv_len" + << kv_len << " and qo_len " << qo_len; + 
FLASHINFER_ERROR(err_msg.str()); + } + + const uint32_t group_size = num_qo_heads / num_kv_heads; + constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM_QK / 16; + constexpr uint32_t NUM_MMA_D_VO = HEAD_DIM_VO / 16; + int64_t packed_qo_len = qo_len * group_size; + uint32_t cta_tile_q = FA2DetermineCtaTileQ(packed_qo_len, HEAD_DIM_VO); + + DISPATCH_CTA_TILE_Q(cta_tile_q, CTA_TILE_Q, { + constexpr uint32_t NUM_WARPS_Q = get_num_warps_q(CTA_TILE_Q); + constexpr uint32_t NUM_WARPS_KV = get_num_warps_kv(CTA_TILE_Q); + constexpr uint32_t NUM_MMA_Q = get_num_mma_q(CTA_TILE_Q); + + using DTypeQKAccum = + typename std::conditional, half, + float>::type; + + int dev_id = 0; + FI_GPU_CALL(gpuGetDevice(&dev_id)); + int max_smem_per_sm = getMaxSharedMemPerMultiprocessor(dev_id); + // we expect each sm execute two threadblocks + const int num_ctas_per_sm = + max_smem_per_sm >= 2 * (CTA_TILE_Q * HEAD_DIM_QK * sizeof(DTypeQ) + + (HEAD_DIM_QK + HEAD_DIM_VO) * 16 * NUM_WARPS_KV * sizeof(DTypeKV)) + ? 2 + : 1; + const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm; + + const uint32_t max_num_mma_kv_reg = + (HEAD_DIM_VO >= 128 && NUM_MMA_Q == 2 && POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && + !USE_FP16_QK_REDUCTION) + ? 
2 + : (8 / NUM_MMA_Q); + const uint32_t max_num_mma_kv_smem = + (max_smem_per_threadblock - CTA_TILE_Q * HEAD_DIM_QK * sizeof(DTypeQ)) / + ((HEAD_DIM_QK + HEAD_DIM_VO) * 16 * NUM_WARPS_KV * sizeof(DTypeKV)); + + // control NUM_MMA_KV for maximum warp occupancy + DISPATCH_NUM_MMA_KV(min(max_num_mma_kv_smem, max_num_mma_kv_reg), NUM_MMA_KV, { + using KTraits = + KernelTraits; + if constexpr (KTraits::IsInvalid()) { + // Invalid configuration, skip + std::ostringstream err_msg; + err_msg << "FlashInfer Internal Error: Invalid " + "configuration : NUM_MMA_Q=" + << NUM_MMA_Q << " NUM_MMA_D_QK=" << NUM_MMA_D_QK << " NUM_MMA_D_VO=" << NUM_MMA_D_VO + << " NUM_MMA_KV=" << NUM_MMA_KV << " NUM_WARPS_Q=" << NUM_WARPS_Q + << " NUM_WARPS_KV=" << NUM_WARPS_KV + << " please create an issue " + "(https://github.com/flashinfer-ai/flashinfer/" + "issues)" + " and report the issue to the developers."; + FLASHINFER_ERROR(err_msg.str()); + } else { + constexpr uint32_t num_threads = (NUM_WARPS_Q * NUM_WARPS_KV) * WARP_SIZE; + auto kernel = SinglePrefillWithKVCacheKernel; + size_t smem_size = sizeof(typename KTraits::SharedStorage); + FI_GPU_CALL( + gpuFuncSetAttribute(kernel, gpuFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + int num_blocks_per_sm = 0; + int num_sm = 0; + FI_GPU_CALL(gpuDeviceGetAttribute(&num_sm, gpuDevAttrMultiProcessorCount, dev_id)); + FI_GPU_CALL(gpuOccupancyMaxActiveBlocksPerMultiprocessor(&num_blocks_per_sm, kernel, + num_threads, smem_size)); + uint32_t max_num_kv_chunks = (num_blocks_per_sm * num_sm) / + (num_kv_heads * ceil_div(qo_len * group_size, CTA_TILE_Q)); + uint32_t num_chunks; + if (max_num_kv_chunks > 0) { + uint32_t chunk_size = max(ceil_div(kv_len, max_num_kv_chunks), 256); + num_chunks = ceil_div(kv_len, chunk_size); + } else { + num_chunks = 0; + } + + if (num_chunks <= 1 || tmp == nullptr) { + // Enough parallelism, do not split-kv + params.partition_kv = false; + void* args[] = {(void*)¶ms}; + dim3 nblks(ceil_div(qo_len * group_size, 
CTA_TILE_Q), 1, num_kv_heads); + dim3 nthrs(WARP_SIZE, NUM_WARPS_Q, NUM_WARPS_KV); + FI_GPU_CALL(gpuLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + } else { + // Use cooperative groups to increase occupancy + params.partition_kv = true; + float* tmp_lse = (float*)(tmp + num_chunks * qo_len * num_qo_heads * HEAD_DIM_VO); + auto o = params.o; + auto lse = params.lse; + params.o = tmp; + params.lse = tmp_lse; + void* args[] = {(void*)¶ms}; + dim3 nblks(ceil_div(qo_len * group_size, CTA_TILE_Q), num_chunks, num_kv_heads); + dim3 nthrs(WARP_SIZE, NUM_WARPS_Q, NUM_WARPS_KV); + FI_GPU_CALL(gpuLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + if constexpr (AttentionVariant::use_softmax) { + FI_GPU_CALL(MergeStates(tmp, tmp_lse, o, lse, num_chunks, qo_len, num_qo_heads, + HEAD_DIM_VO, stream)); + } else { + FI_GPU_CALL( + AttentionSum(tmp, o, num_chunks, qo_len, num_qo_heads, HEAD_DIM_VO, stream)); + } + } + } + }) + }); + return gpuSuccess; +} + +template +__global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithRaggedKVCacheKernel( + const __grid_constant__ Params params) { + using DTypeQ = typename Params::DTypeQ; +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) + if constexpr (std::is_same_v) { + FLASHINFER_RUNTIME_ASSERT("Prefill kernels do not support bf16 on sm75."); + } else { +#endif + using DTypeKV = typename Params::DTypeKV; + using DTypeO = typename Params::DTypeO; + using IdType = typename Params::IdType; + using DTypeQKAccum = typename KTraits::DTypeQKAccum; + using AttentionVariant = typename KTraits::AttentionVariant; + [[maybe_unused]] constexpr uint32_t NUM_MMA_Q = KTraits::NUM_MMA_Q; + [[maybe_unused]] constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; + [[maybe_unused]] constexpr uint32_t NUM_MMA_D_QK = KTraits::NUM_MMA_D_QK; + [[maybe_unused]] constexpr uint32_t NUM_MMA_D_VO = KTraits::NUM_MMA_D_VO; + [[maybe_unused]] constexpr uint32_t HEAD_DIM_QK = KTraits::HEAD_DIM_QK; + 
// NOTE(review): this chunk is the prologue of BatchPrefillWithRaggedKVCacheKernel.
// It (1) pulls KTraits' compile-time tile/layout constants into local constexpr
// aliases — [[maybe_unused]] silences -Wunused warnings on configuration paths
// (mask mode, pos-encoding mode, swizzle mode) that never touch a given constant —
// and (2) unpacks the per-launch pointers, ragged indptr arrays, and Q/K/V strides
// from the __grid_constant__ `params` into local (register) variables so the hot
// loop does not re-read the parameter block.
[[maybe_unused]] constexpr uint32_t HEAD_DIM_VO = KTraits::HEAD_DIM_VO; + [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q; + [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_K = KTraits::UPCAST_STRIDE_K; + [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_V = KTraits::UPCAST_STRIDE_V; + [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_O = KTraits::UPCAST_STRIDE_O; + [[maybe_unused]] constexpr uint32_t CTA_TILE_Q = KTraits::CTA_TILE_Q; + [[maybe_unused]] constexpr uint32_t CTA_TILE_KV = KTraits::CTA_TILE_KV; + [[maybe_unused]] constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q; + [[maybe_unused]] constexpr uint32_t NUM_WARPS_KV = KTraits::NUM_WARPS_KV; + [[maybe_unused]] constexpr SwizzleMode SWIZZLE_MODE_Q = KTraits::SWIZZLE_MODE_Q; + [[maybe_unused]] constexpr SwizzleMode SWIZZLE_MODE_KV = KTraits::SWIZZLE_MODE_KV; + [[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_ROW = KTraits::KV_THR_LAYOUT_ROW; + [[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_COL = KTraits::KV_THR_LAYOUT_COL; + [[maybe_unused]] constexpr MaskMode MASK_MODE = KTraits::MASK_MODE; + [[maybe_unused]] constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH; + + DTypeQ* q = params.q; + IdType* request_indices = params.request_indices; + IdType* qo_tile_indices = params.qo_tile_indices; + IdType* kv_tile_indices = params.kv_tile_indices; + IdType* q_indptr = params.q_indptr; + IdType* kv_indptr = params.kv_indptr; + DTypeKV* k = params.k; + DTypeKV* v = params.v; + IdType* o_indptr = params.o_indptr; + DTypeO* o = params.o; + float* lse = params.lse; + bool* block_valid_mask = params.block_valid_mask; + const bool partition_kv = params.partition_kv; + const uint32_t q_stride_n = params.q_stride_n; + const uint32_t q_stride_h = params.q_stride_h; + const uint32_t k_stride_n = params.k_stride_n; + const uint32_t k_stride_h = params.k_stride_h; + const uint32_t v_stride_n = params.v_stride_n; + const uint32_t v_stride_h = params.v_stride_h; + const 
uint_fastdiv& group_size = params.group_size; + + static_assert(sizeof(DTypeQ) == 2); + const uint32_t kv_chunk_size = *(params.kv_chunk_size_ptr); + const dim3& tid = threadIdx; + + auto block = cg::this_thread_block(); + const uint32_t bx = blockIdx.x, lane_idx = tid.x, + warp_idx = get_warp_idx(tid.y, tid.z), kv_head_idx = blockIdx.z; + if (block_valid_mask && !block_valid_mask[bx]) { + return; + } + const uint32_t num_kv_heads = gridDim.z, num_qo_heads = group_size * num_kv_heads; + const uint32_t request_idx = request_indices[bx], qo_tile_idx = qo_tile_indices[bx], + kv_tile_idx = kv_tile_indices[bx]; + extern __shared__ uint8_t smem[]; + auto& smem_storage = reinterpret_cast(smem); + AttentionVariant variant(params, /*batch_idx=*/request_idx, smem); + const uint32_t qo_len = variant.qo_len, kv_len = variant.kv_len, + window_left = variant.window_left; + const uint32_t kv_len_safe = kv_len > 0 ? kv_len : 1; + const uint32_t max_chunk_size = partition_kv ? kv_chunk_size : kv_len; + const uint32_t chunk_start = partition_kv ? kv_tile_idx * max_chunk_size : 0; + const uint32_t chunk_end = + partition_kv ? 
// NOTE(review): per-CTA work-range and register-state setup for the ragged-KV
// kernel. chunk_end is clamped to kv_len so the last split-KV tile never reads
// past the sequence; qo_upper_bound bounds the packed (head-grouped) query rows
// this tile may touch. s_frag/o_frag/m/d are the per-thread MMA fragments and the
// online-softmax running state (m = running max, d = running denominator — see the
// `update_mdo_states` call later in this kernel). rope_freq is only initialized on
// the kRoPELlama path. o_ptr_base: when partition_kv is set, each KV tile writes
// its partial output at offset (o_indptr[request] + kv_tile_idx) for a later
// merge pass — presumably the stride math matches the merge kernel; confirm there.
min((kv_tile_idx + 1) * max_chunk_size, kv_len) : kv_len; + const uint32_t chunk_size = chunk_end - chunk_start; + const uint32_t qo_upper_bound = + min(qo_len, ceil_div((qo_tile_idx + 1) * CTA_TILE_Q, group_size)); + + DTypeQKAccum s_frag[NUM_MMA_Q][NUM_MMA_KV][8]; + alignas(16) float o_frag[NUM_MMA_Q][NUM_MMA_D_VO][8]; + DTypeQKAccum m[NUM_MMA_Q][2]; + float d[NUM_MMA_Q][2]; + float rope_freq[NUM_MMA_D_QK / 2][4]; + + if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { + const float rope_rcp_scale = params.rope_rcp_scale; + const float rope_rcp_theta = params.rope_rcp_theta; + init_rope_freq(rope_freq, rope_rcp_scale, rope_rcp_theta, tid.x); + } + init_states(variant, o_frag, m, d); + + const uint32_t qo_packed_idx_base = + (qo_tile_idx * NUM_WARPS_Q + get_warp_idx_q(tid.y)) * NUM_MMA_Q * 16; + smem_t qo_smem(smem_storage.q_smem); + const uint32_t o_stride_n = num_qo_heads * HEAD_DIM_VO, o_stride_h = HEAD_DIM_VO; + + DTypeQ* q_ptr_base = + q + q_indptr[request_idx] * q_stride_n + kv_head_idx * group_size * q_stride_h; + + DTypeO* o_ptr_base = partition_kv ? 
o + (o_indptr[request_idx] + kv_tile_idx) * o_stride_n + + (kv_head_idx * group_size) * o_stride_h + : o + o_indptr[request_idx] * o_stride_n + + (kv_head_idx * group_size) * o_stride_h; + + uint32_t q_smem_offset_r = qo_smem.template get_permuted_offset( + get_warp_idx_q(tid.y) * NUM_MMA_Q * 16 + lane_idx % 16, lane_idx / 16); + + load_q_global_smem(qo_packed_idx_base, qo_upper_bound, q_ptr_base, q_stride_n, + q_stride_h, group_size, &qo_smem, tid); + + memory::commit_group(); + + if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { + memory::wait_group<0>(); + block.sync(); + IdType* q_rope_offset = nullptr; + + if constexpr (has_maybe_q_rope_offset_v) { + q_rope_offset = params.maybe_q_rope_offset; + } + if (!q_rope_offset) { + q_smem_inplace_apply_rotary(qo_packed_idx_base, qo_len, kv_len, group_size, + &qo_smem, &q_smem_offset_r, rope_freq, tid); + } else { + q_smem_inplace_apply_rotary_with_pos( + qo_packed_idx_base, q_rope_offset + q_indptr[request_idx], &qo_smem, group_size, + &q_smem_offset_r, rope_freq, tid); + } + block.sync(); + } + + const uint32_t num_iterations = ceil_div( + (MASK_MODE == MaskMode::kCausal + ? min(chunk_size, sub_if_greater_or_zero( + kv_len - qo_len + ((qo_tile_idx + 1) * CTA_TILE_Q) / group_size, + chunk_start)) + : chunk_size), + CTA_TILE_KV); + + const uint32_t window_iteration = + ceil_div(sub_if_greater_or_zero(kv_len + (qo_tile_idx + 1) * CTA_TILE_Q / group_size, + qo_len + window_left + chunk_start), + CTA_TILE_KV); + + const uint32_t mask_iteration = + (MASK_MODE == MaskMode::kCausal + ? 
min(chunk_size, + sub_if_greater_or_zero(kv_len + (qo_tile_idx * CTA_TILE_Q) / group_size - qo_len, + chunk_start)) + : chunk_size) / + CTA_TILE_KV; + + smem_t k_smem(smem_storage.k_smem), v_smem(smem_storage.v_smem); + + uint32_t k_smem_offset_r = k_smem.template get_permuted_offset( + get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + 8 * (lane_idx / 16) + + lane_idx % 8, + (lane_idx % 16) / 8), + v_smem_offset_r = v_smem.template get_permuted_offset( + get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + lane_idx % 16, lane_idx / 16), + k_smem_offset_w = k_smem.template get_permuted_offset( + warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL, + lane_idx % KV_THR_LAYOUT_COL), + v_smem_offset_w = v_smem.template get_permuted_offset( + warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL, + lane_idx % KV_THR_LAYOUT_COL); + + DTypeKV* k_ptr = k + + (kv_indptr[request_idx] + chunk_start + warp_idx * KV_THR_LAYOUT_ROW + + lane_idx / KV_THR_LAYOUT_COL) * + k_stride_n + + kv_head_idx * k_stride_h + + (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(); + DTypeKV* v_ptr = v + + (kv_indptr[request_idx] + chunk_start + warp_idx * KV_THR_LAYOUT_ROW + + lane_idx / KV_THR_LAYOUT_COL) * + v_stride_n + + kv_head_idx * v_stride_h + + (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(); + + produce_kv(k_smem, &k_smem_offset_w, &k_ptr, + k_stride_n, 0, chunk_size, tid); + memory::commit_group(); + produce_kv(v_smem, &v_smem_offset_w, &v_ptr, + v_stride_n, 0, chunk_size, tid); + memory::commit_group(); + +#pragma unroll 1 + for (uint32_t iter = 0; iter < num_iterations; ++iter) { + memory::wait_group<1>(); + block.sync(); + + if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { + IdType* k_rope_offset = nullptr; + if constexpr (has_maybe_k_rope_offset_v) { + k_rope_offset = params.maybe_k_rope_offset; + } + k_smem_inplace_apply_rotary( + (k_rope_offset == nullptr ? 
0 : k_rope_offset[request_idx]) + chunk_start + + iter * CTA_TILE_KV, + &k_smem, &k_smem_offset_r, rope_freq, tid); + block.sync(); + } + + // compute attention score + compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); + + logits_transform( + params, variant, /*batch_idx=*/request_idx, qo_packed_idx_base, + chunk_start + (iter * NUM_WARPS_KV + get_warp_idx_kv(tid.z)) * NUM_MMA_KV * 16, + qo_len, kv_len, group_size, s_frag, tid, kv_head_idx); + + // apply mask + if (MASK_MODE == MaskMode::kCustom || (iter >= mask_iteration || iter < window_iteration)) { + logits_mask( + params, variant, /*batch_idx=*/request_idx, qo_packed_idx_base, + chunk_start + (iter * NUM_WARPS_KV + get_warp_idx_kv(tid.z)) * NUM_MMA_KV * 16, + qo_len, kv_len, chunk_end, group_size, s_frag, tid, kv_head_idx); + } + + // compute m,d states in online softmax + update_mdo_states(variant, s_frag, o_frag, m, d); + + block.sync(); + produce_kv( + k_smem, &k_smem_offset_w, &k_ptr, k_stride_n, (iter + 1) * CTA_TILE_KV, chunk_size, tid); + memory::commit_group(); + memory::wait_group<1>(); + block.sync(); + + // compute sfm*v + compute_sfm_v(&v_smem, &v_smem_offset_r, s_frag, o_frag, d); + + block.sync(); + produce_kv( + v_smem, &v_smem_offset_w, &v_ptr, v_stride_n, (iter + 1) * CTA_TILE_KV, chunk_size, tid); + memory::commit_group(); + } + memory::wait_group<0>(); + block.sync(); + + finalize_m(variant, m); + + // threadblock synchronization + threadblock_sync_mdo_states(o_frag, &smem_storage, m, d, warp_idx, lane_idx, tid); + + // normalize d + normalize_d(o_frag, m, d); + + const uint32_t num_kv_chunks = (kv_len_safe + kv_chunk_size - 1) / kv_chunk_size; + + // write back + write_o_reg_gmem(o_frag, &qo_smem, o_ptr_base, qo_packed_idx_base, qo_len, + /*o_stride_n=*/ + partition_kv ? 
num_kv_chunks * o_stride_n : o_stride_n, + /*o_stride_h=*/o_stride_h, group_size, tid); + + // write lse + if constexpr (AttentionVariant::use_softmax) { + if (lse != nullptr) { + if (get_warp_idx_kv(tid.z) == 0) { +#pragma unroll + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + uint32_t q, r; + group_size.divmod(qo_packed_idx_base + lane_idx / 4 + j * 8 + mma_q * 16, q, r); + const uint32_t qo_head_idx = kv_head_idx * group_size + r; + const uint32_t qo_idx = q; + if (qo_idx < qo_len) { + if (partition_kv) { + lse[(o_indptr[request_idx] + qo_idx * num_kv_chunks + kv_tile_idx) * + num_qo_heads + + qo_head_idx] = gpu_iface::math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]); + } else { + lse[(o_indptr[request_idx] + qo_idx) * num_qo_heads + qo_head_idx] = + gpu_iface::math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]); + } + } + } + } + } + } + } +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) + } +#endif +} + +template +__device__ __forceinline__ void BatchPrefillWithPagedKVCacheDevice( + const Params params, typename KTraits::SharedStorage& smem_storage, const dim3 tid = threadIdx, + const uint32_t bx = blockIdx.x, const uint32_t kv_head_idx = blockIdx.z, + const uint32_t num_kv_heads = gridDim.z) { + using DTypeQ = typename Params::DTypeQ; +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) + if constexpr (std::is_same_v) { + FLASHINFER_RUNTIME_ASSERT("Prefill kernels do not support bf16 on sm75."); + } else { +#endif + using DTypeKV = typename Params::DTypeKV; + using DTypeO = typename Params::DTypeO; + using IdType = typename Params::IdType; + using DTypeQKAccum = typename KTraits::DTypeQKAccum; + using AttentionVariant = typename KTraits::AttentionVariant; + [[maybe_unused]] constexpr uint32_t NUM_MMA_Q = KTraits::NUM_MMA_Q; + [[maybe_unused]] constexpr uint32_t NUM_MMA_KV = KTraits::NUM_MMA_KV; + [[maybe_unused]] constexpr uint32_t NUM_MMA_D_QK = KTraits::NUM_MMA_D_QK; + [[maybe_unused]] 
// NOTE(review): prologue of BatchPrefillWithPagedKVCacheDevice — the paged-KV
// twin of the ragged-kernel prologue above it in this file. Same two steps:
// localize KTraits' compile-time constants ([[maybe_unused]] because not every
// template configuration reads every constant), then unpack launch parameters.
// Differences vs. the ragged variant: K/V come via params.paged_kv (a
// paged_kv_t page table) instead of raw k/v pointers + strides, and the
// kv_chunk_size for split-KV is read through a device pointer
// (params.kv_chunk_size_ptr) rather than a plain field.
constexpr uint32_t NUM_MMA_D_VO = KTraits::NUM_MMA_D_VO; + [[maybe_unused]] constexpr uint32_t HEAD_DIM_QK = KTraits::HEAD_DIM_QK; + [[maybe_unused]] constexpr uint32_t HEAD_DIM_VO = KTraits::HEAD_DIM_VO; + [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_Q = KTraits::UPCAST_STRIDE_Q; + [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_K = KTraits::UPCAST_STRIDE_K; + [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_V = KTraits::UPCAST_STRIDE_V; + [[maybe_unused]] constexpr uint32_t UPCAST_STRIDE_O = KTraits::UPCAST_STRIDE_O; + [[maybe_unused]] constexpr uint32_t NUM_WARPS_Q = KTraits::NUM_WARPS_Q; + [[maybe_unused]] constexpr uint32_t NUM_WARPS_KV = KTraits::NUM_WARPS_KV; + [[maybe_unused]] constexpr SwizzleMode SWIZZLE_MODE_Q = KTraits::SWIZZLE_MODE_Q; + [[maybe_unused]] constexpr SwizzleMode SWIZZLE_MODE_KV = KTraits::SWIZZLE_MODE_KV; + [[maybe_unused]] constexpr uint32_t CTA_TILE_Q = KTraits::CTA_TILE_Q; + [[maybe_unused]] constexpr uint32_t CTA_TILE_KV = KTraits::CTA_TILE_KV; + [[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_ROW = KTraits::KV_THR_LAYOUT_ROW; + [[maybe_unused]] constexpr uint32_t KV_THR_LAYOUT_COL = KTraits::KV_THR_LAYOUT_COL; + [[maybe_unused]] constexpr MaskMode MASK_MODE = KTraits::MASK_MODE; + [[maybe_unused]] constexpr uint32_t VECTOR_BIT_WIDTH = KTraits::VECTOR_BIT_WIDTH; + + IdType* request_indices = params.request_indices; + IdType* qo_tile_indices = params.qo_tile_indices; + IdType* kv_tile_indices = params.kv_tile_indices; + DTypeQ* q = params.q; + IdType* q_indptr = params.q_indptr; + IdType* o_indptr = params.o_indptr; + DTypeO* o = params.o; + float* lse = params.lse; + bool* block_valid_mask = params.block_valid_mask; + const paged_kv_t& paged_kv = params.paged_kv; + const bool partition_kv = params.partition_kv; + const uint_fastdiv& group_size = params.group_size; + + static_assert(sizeof(DTypeQ) == 2); + auto block = cg::this_thread_block(); + const uint32_t kv_chunk_size = *(params.kv_chunk_size_ptr); + + const uint32_t 
lane_idx = tid.x, warp_idx = get_warp_idx(tid.y, tid.z); + if (block_valid_mask && !block_valid_mask[bx]) { + return; + } + const uint32_t num_qo_heads = num_kv_heads * group_size; + + const uint32_t request_idx = request_indices[bx], qo_tile_idx = qo_tile_indices[bx], + kv_tile_idx = kv_tile_indices[bx]; + auto smem = reinterpret_cast(&smem_storage); + AttentionVariant variant(params, /*batch_idx=*/request_idx, smem); + const uint32_t qo_len = variant.qo_len, kv_len = variant.kv_len, + window_left = variant.window_left; + const uint32_t kv_len_safe = kv_len > 0 ? kv_len : 1; + const uint32_t max_chunk_size = partition_kv ? kv_chunk_size : kv_len; + const uint32_t chunk_start = partition_kv ? kv_tile_idx * max_chunk_size : 0; + const uint32_t chunk_end = + partition_kv ? min((kv_tile_idx + 1) * max_chunk_size, kv_len) : kv_len; + const uint32_t chunk_size = chunk_end - chunk_start; + const uint32_t qo_upper_bound = + min(qo_len, ceil_div((qo_tile_idx + 1) * CTA_TILE_Q, group_size)); + + DTypeQKAccum s_frag[NUM_MMA_Q][NUM_MMA_KV][8]; + alignas(16) float o_frag[NUM_MMA_Q][NUM_MMA_D_VO][8]; + DTypeQKAccum m[NUM_MMA_Q][2]; + float d[NUM_MMA_Q][2]; + float rope_freq[NUM_MMA_D_QK / 2][4]; + + if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { + const float rope_rcp_scale = params.rope_rcp_scale; + const float rope_rcp_theta = params.rope_rcp_theta; + init_rope_freq(rope_freq, rope_rcp_scale, rope_rcp_theta, tid.x); + } + init_states(variant, o_frag, m, d); + + const uint32_t qo_packed_idx_base = + (qo_tile_idx * NUM_WARPS_Q + get_warp_idx_q(tid.y)) * NUM_MMA_Q * 16; + const uint32_t q_stride_n = params.q_stride_n, q_stride_h = params.q_stride_h; + smem_t qo_smem(smem_storage.q_smem); + const uint32_t o_stride_n = num_qo_heads * HEAD_DIM_VO, o_stride_h = HEAD_DIM_VO; + DTypeQ* q_ptr_base = + q + q_indptr[request_idx] * q_stride_n + (kv_head_idx * group_size) * q_stride_h; + DTypeO* o_ptr_base = partition_kv ? 
o + (o_indptr[request_idx] + kv_tile_idx) * o_stride_n + + (kv_head_idx * group_size) * o_stride_h + : o + o_indptr[request_idx] * o_stride_n + + (kv_head_idx * group_size) * o_stride_h; + uint32_t q_smem_offset_r = qo_smem.template get_permuted_offset( + get_warp_idx_q(tid.y) * NUM_MMA_Q * 16 + lane_idx % 16, lane_idx / 16); + + load_q_global_smem(qo_packed_idx_base, qo_upper_bound, q_ptr_base, q_stride_n, + q_stride_h, group_size, &qo_smem, tid); + + memory::commit_group(); + + if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { + memory::wait_group<0>(); + block.sync(); + IdType* q_rope_offset = nullptr; + if constexpr (has_maybe_q_rope_offset_v) { + q_rope_offset = params.maybe_q_rope_offset; + } + if (q_rope_offset == nullptr) { + q_smem_inplace_apply_rotary(qo_packed_idx_base, qo_len, kv_len, group_size, + &qo_smem, &q_smem_offset_r, rope_freq, tid); + } else { + q_smem_inplace_apply_rotary_with_pos( + qo_packed_idx_base, q_rope_offset + q_indptr[request_idx], &qo_smem, group_size, + &q_smem_offset_r, rope_freq, tid); + } + block.sync(); + } + + smem_t k_smem(smem_storage.k_smem), v_smem(smem_storage.v_smem); + size_t thr_local_kv_offset[NUM_MMA_KV * KV_THR_LAYOUT_COL / 2 / NUM_WARPS_Q]; + + uint32_t k_smem_offset_r = k_smem.template get_permuted_offset( + get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + 8 * (lane_idx / 16) + + lane_idx % 8, + (lane_idx % 16) / 8), + v_smem_offset_r = v_smem.template get_permuted_offset( + get_warp_idx_kv(tid.z) * NUM_MMA_KV * 16 + lane_idx % 16, lane_idx / 16), + k_smem_offset_w = k_smem.template get_permuted_offset( + warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL, + lane_idx % KV_THR_LAYOUT_COL), + v_smem_offset_w = v_smem.template get_permuted_offset( + warp_idx * KV_THR_LAYOUT_ROW + lane_idx / KV_THR_LAYOUT_COL, + lane_idx % KV_THR_LAYOUT_COL); + const IdType last_indptr = paged_kv.indptr[paged_kv.batch_size]; + + uint32_t packed_page_iter_base = + paged_kv.indptr[request_idx] * 
paged_kv.page_size + chunk_start; +#pragma unroll + for (uint32_t i = 0; + i < NUM_MMA_KV * (SWIZZLE_MODE_KV == SwizzleMode::k128B ? 4 : 2) / NUM_WARPS_Q; ++i) { + uint32_t page_iter, entry_idx; + paged_kv.page_size.divmod(packed_page_iter_base + warp_idx * KV_THR_LAYOUT_ROW + + lane_idx / KV_THR_LAYOUT_COL + + KV_THR_LAYOUT_ROW * NUM_WARPS_Q * NUM_WARPS_KV * i, + page_iter, entry_idx); + thr_local_kv_offset[i] = paged_kv.protective_get_kv_offset( + page_iter, kv_head_idx, entry_idx, + (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(), last_indptr); + } + page_produce_kv(k_smem, &k_smem_offset_w, paged_kv, 0, thr_local_kv_offset, + chunk_size, tid); + memory::commit_group(); + page_produce_kv(v_smem, &v_smem_offset_w, paged_kv, 0, thr_local_kv_offset, + chunk_size, tid); + memory::commit_group(); + + const uint32_t num_iterations = ceil_div( + (MASK_MODE == MaskMode::kCausal + ? min(chunk_size, sub_if_greater_or_zero( + kv_len - qo_len + ((qo_tile_idx + 1) * CTA_TILE_Q) / group_size, + chunk_start)) + : chunk_size), + CTA_TILE_KV); + + const uint32_t window_iteration = + ceil_div(sub_if_greater_or_zero(kv_len + (qo_tile_idx + 1) * CTA_TILE_Q / group_size, + qo_len + window_left + chunk_start), + CTA_TILE_KV); + + const uint32_t mask_iteration = + (MASK_MODE == MaskMode::kCausal + ? min(chunk_size, + sub_if_greater_or_zero(kv_len + (qo_tile_idx * CTA_TILE_Q) / group_size - qo_len, + chunk_start)) + : chunk_size) / + CTA_TILE_KV; + +#pragma unroll 1 + for (uint32_t iter = 0; iter < num_iterations; ++iter) { + packed_page_iter_base += CTA_TILE_KV; +#pragma unroll + for (uint32_t i = 0; + i < NUM_MMA_KV * (SWIZZLE_MODE_KV == SwizzleMode::k128B ? 
4 : 2) / NUM_WARPS_Q; ++i) { + uint32_t page_iter, entry_idx; + paged_kv.page_size.divmod(packed_page_iter_base + warp_idx * KV_THR_LAYOUT_ROW + + lane_idx / KV_THR_LAYOUT_COL + + KV_THR_LAYOUT_ROW * NUM_WARPS_Q * NUM_WARPS_KV * i, + page_iter, entry_idx); + thr_local_kv_offset[i] = paged_kv.protective_get_kv_offset( + page_iter, kv_head_idx, entry_idx, + (lane_idx % KV_THR_LAYOUT_COL) * upcast_size(), last_indptr); + } + memory::wait_group<1>(); + block.sync(); + + if constexpr (KTraits::POS_ENCODING_MODE == PosEncodingMode::kRoPELlama) { + k_smem_inplace_apply_rotary( + (paged_kv.rope_pos_offset == nullptr ? 0 : paged_kv.rope_pos_offset[request_idx]) + + chunk_start + iter * CTA_TILE_KV, + &k_smem, &k_smem_offset_r, rope_freq, tid); + block.sync(); + } + + // compute attention score + compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag); + + logits_transform( + params, variant, /*batch_idx=*/request_idx, qo_packed_idx_base, + chunk_start + (iter * NUM_WARPS_KV + get_warp_idx_kv(tid.z)) * NUM_MMA_KV * 16, + qo_len, kv_len, group_size, s_frag, tid, kv_head_idx); + + // apply mask + if (MASK_MODE == MaskMode::kCustom || (iter >= mask_iteration || iter < window_iteration)) { + logits_mask( + params, variant, /*batch_idx=*/request_idx, qo_packed_idx_base, + chunk_start + (iter * NUM_WARPS_KV + get_warp_idx_kv(tid.z)) * NUM_MMA_KV * 16, + qo_len, kv_len, chunk_end, group_size, s_frag, tid, kv_head_idx); + } + + // compute m,d states in online softmax + update_mdo_states(variant, s_frag, o_frag, m, d); + + block.sync(); + page_produce_kv(k_smem, &k_smem_offset_w, paged_kv, (iter + 1) * CTA_TILE_KV, + thr_local_kv_offset, chunk_size, tid); + memory::commit_group(); + memory::wait_group<1>(); + block.sync(); + + // compute sfm*v + compute_sfm_v(&v_smem, &v_smem_offset_r, s_frag, o_frag, d); + + block.sync(); + page_produce_kv(v_smem, &v_smem_offset_w, paged_kv, (iter + 1) * CTA_TILE_KV, + thr_local_kv_offset, chunk_size, tid); + 
memory::commit_group(); + } + memory::wait_group<0>(); + block.sync(); + + finalize_m(variant, m); + + // threadblock synchronization + threadblock_sync_mdo_states(o_frag, &smem_storage, m, d, warp_idx, lane_idx, tid); + + // normalize d + normalize_d(o_frag, m, d); + + const uint32_t num_kv_chunks = (kv_len_safe + kv_chunk_size - 1) / kv_chunk_size; + + // write_back + write_o_reg_gmem(o_frag, &qo_smem, o_ptr_base, qo_packed_idx_base, qo_len, + /*o_stride_n=*/ + partition_kv ? num_kv_chunks * o_stride_n : o_stride_n, + /*o_stride_h=*/o_stride_h, group_size, tid); + + // write lse + if constexpr (variant.use_softmax) { + if (lse != nullptr) { + if (get_warp_idx_kv(tid.z) == 0) { +#pragma unroll + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + uint32_t q, r; + group_size.divmod(qo_packed_idx_base + lane_idx / 4 + j * 8 + mma_q * 16, q, r); + const uint32_t qo_head_idx = kv_head_idx * group_size + r; + const uint32_t qo_idx = q; + if (qo_idx < qo_upper_bound) { + if (partition_kv) { + lse[(o_indptr[request_idx] + qo_idx * num_kv_chunks + kv_tile_idx) * + num_qo_heads + + qo_head_idx] = gpu_iface::math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]); + } else { + lse[(o_indptr[request_idx] + qo_idx) * num_qo_heads + qo_head_idx] = + gpu_iface::math::ptx_log2(d[mma_q][j]) + float(m[mma_q][j]); + } + } + } + } + } + } + } +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) + } +#endif +} + +template +__global__ __launch_bounds__(KTraits::NUM_THREADS) void BatchPrefillWithPagedKVCacheKernel( + const __grid_constant__ Params params) { + extern __shared__ uint8_t smem[]; + auto& smem_storage = reinterpret_cast(smem); + BatchPrefillWithPagedKVCacheDevice(params, smem_storage); +} + +template +gpuError_t BatchPrefillWithRaggedKVCacheDispatched(Params params, typename Params::DTypeO* tmp_v, + float* tmp_s, gpuStream_t stream) { + using DTypeQ = typename Params::DTypeQ; + using DTypeKV = typename Params::DTypeKV; + 
using DTypeO = typename Params::DTypeO; + const uint32_t padded_batch_size = params.padded_batch_size; + const uint32_t num_qo_heads = params.num_qo_heads; + const uint32_t num_kv_heads = params.num_kv_heads; + constexpr uint32_t NUM_MMA_Q = get_num_mma_q(CTA_TILE_Q); + constexpr uint32_t NUM_WARPS_Q = get_num_warps_q(CTA_TILE_Q); + constexpr uint32_t NUM_WARPS_KV = get_num_warps_kv(CTA_TILE_Q); + + if (padded_batch_size == 0) { + // No request, skip + // this won't happen in CUDAGraph mode because we fixed the + // padded_batch_size + return gpuSuccess; + } + + dim3 nblks(padded_batch_size, 1, num_kv_heads); + dim3 nthrs(WARP_SIZE, NUM_WARPS_Q, NUM_WARPS_KV); + constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM_QK / 16; + constexpr uint32_t NUM_MMA_D_VO = HEAD_DIM_VO / 16; + using DTypeQKAccum = + typename std::conditional, half, + float>::type; + + int dev_id = 0; + FI_GPU_CALL(gpuGetDevice(&dev_id)); + int max_smem_per_sm = getMaxSharedMemPerMultiprocessor(dev_id); + // we expect each sm execute two threadblocks + const int num_ctas_per_sm = + max_smem_per_sm >= 2 * (CTA_TILE_Q * HEAD_DIM_QK * sizeof(DTypeQ) + + (HEAD_DIM_QK + HEAD_DIM_VO) * 16 * NUM_WARPS_KV * sizeof(DTypeKV)) + ? 2 + : 1; + const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm; + + const uint32_t max_num_mma_kv_reg = + (HEAD_DIM_VO >= 128 && NUM_MMA_Q == 2 && POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && + !USE_FP16_QK_REDUCTION) + ? 
2 + : (8 / NUM_MMA_Q); + const uint32_t max_num_mma_kv_smem = + (max_smem_per_threadblock - CTA_TILE_Q * HEAD_DIM_QK * sizeof(DTypeQ)) / + ((HEAD_DIM_QK + HEAD_DIM_VO) * 16 * NUM_WARPS_KV * sizeof(DTypeKV)); + + DISPATCH_NUM_MMA_KV(min(max_num_mma_kv_smem, max_num_mma_kv_reg), NUM_MMA_KV, { + using KTraits = + KernelTraits; + if constexpr (KTraits::IsInvalid()) { + // Invalid configuration, skip + std::ostringstream err_msg; + err_msg << "FlashInfer Internal Error: Invalid configuration : " + "NUM_MMA_Q=" + << NUM_MMA_Q << " NUM_MMA_D_QK=" << NUM_MMA_D_QK << " NUM_MMA_D_VO=" << NUM_MMA_D_VO + << " NUM_MMA_KV=" << NUM_MMA_KV << " NUM_WARPS_Q=" << NUM_WARPS_Q + << " NUM_WARPS_KV=" << NUM_WARPS_KV + << " please create an issue " + "(https://github.com/flashinfer-ai/flashinfer/issues)" + " and report the issue to the developers."; + FLASHINFER_ERROR(err_msg.str()); + } else { + size_t smem_size = sizeof(typename KTraits::SharedStorage); + auto kernel = BatchPrefillWithRaggedKVCacheKernel; + FI_GPU_CALL( + gpuFuncSetAttribute(kernel, gpuFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + if (tmp_v == nullptr) { + // do not partition kv + params.partition_kv = false; + void* args[] = {(void*)¶ms}; + FI_GPU_CALL(gpuLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + } else { + // partition kv + params.partition_kv = true; + auto o = params.o; + auto lse = params.lse; + params.o = tmp_v; + params.lse = tmp_s; + void* args[] = {(void*)¶ms}; + FI_GPU_CALL(gpuLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + if constexpr (AttentionVariant::use_softmax) { + FI_GPU_CALL(VariableLengthMergeStates(tmp_v, tmp_s, params.merge_indptr, o, lse, + params.max_total_num_rows, params.total_num_rows, + num_qo_heads, HEAD_DIM_VO, stream)); + } else { + FI_GPU_CALL(VariableLengthAttentionSum(tmp_v, params.merge_indptr, o, + params.max_total_num_rows, params.total_num_rows, + num_qo_heads, HEAD_DIM_VO, stream)); + } + } + } + }); + return 
gpuSuccess; +} + +template +gpuError_t BatchPrefillWithPagedKVCacheDispatched(Params params, typename Params::DTypeO* tmp_v, + float* tmp_s, gpuStream_t stream) { + using DTypeQ = typename Params::DTypeQ; + using DTypeKV = typename Params::DTypeKV; + using DTypeO = typename Params::DTypeO; + const uint32_t padded_batch_size = params.padded_batch_size; + const uint32_t num_qo_heads = params.num_qo_heads; + const uint32_t num_kv_heads = params.paged_kv.num_heads; + constexpr uint32_t NUM_MMA_Q = get_num_mma_q(CTA_TILE_Q); + constexpr uint32_t NUM_WARPS_Q = get_num_warps_q(CTA_TILE_Q); + constexpr uint32_t NUM_WARPS_KV = get_num_warps_kv(CTA_TILE_Q); + + if (padded_batch_size == 0) { + // No request, skip + // this won't happen in CUDAGraph mode because we fixed the + // padded_batch_size + return gpuSuccess; + } + + dim3 nblks(padded_batch_size, 1, num_kv_heads); + dim3 nthrs(WARP_SIZE, NUM_WARPS_Q, NUM_WARPS_KV); + + constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM_QK / 16; + constexpr uint32_t NUM_MMA_D_VO = HEAD_DIM_VO / 16; + using DTypeQKAccum = + typename std::conditional, half, + float>::type; + + int dev_id = 0; + FI_GPU_CALL(gpuGetDevice(&dev_id)); + int max_smem_per_sm = getMaxSharedMemPerMultiprocessor(dev_id); + // we expect each sm execute two threadblocks + const int num_ctas_per_sm = + max_smem_per_sm >= 2 * (CTA_TILE_Q * HEAD_DIM_QK * sizeof(DTypeQ) + + (HEAD_DIM_QK + HEAD_DIM_VO) * 16 * NUM_WARPS_KV * sizeof(DTypeKV)) + ? 2 + : 1; + const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm; + + const uint32_t max_num_mma_kv_reg = + (HEAD_DIM_VO >= 128 && NUM_MMA_Q == 2 && POS_ENCODING_MODE == PosEncodingMode::kRoPELlama && + !USE_FP16_QK_REDUCTION) + ? 
2 + : (8 / NUM_MMA_Q); + const uint32_t max_num_mma_kv_smem = + (max_smem_per_threadblock - CTA_TILE_Q * HEAD_DIM_QK * sizeof(DTypeQ)) / + ((HEAD_DIM_QK + HEAD_DIM_VO) * 16 * NUM_WARPS_KV * sizeof(DTypeKV)); + + DISPATCH_NUM_MMA_KV(min(max_num_mma_kv_smem, max_num_mma_kv_reg), NUM_MMA_KV, { + using KTraits = + KernelTraits; + if constexpr (KTraits::IsInvalid()) { + // Invalid configuration, skip + std::ostringstream err_msg; + err_msg << "FlashInfer Internal Error: Invalid configuration : " + "NUM_MMA_Q=" + << NUM_MMA_Q << " NUM_MMA_D_QK=" << NUM_MMA_D_QK << " NUM_MMA_D_VO=" << NUM_MMA_D_VO + << " NUM_MMA_KV=" << NUM_MMA_KV << " NUM_WARPS_Q=" << NUM_WARPS_Q + << " NUM_WARPS_KV=" << NUM_WARPS_KV + << " please create an issue " + "(https://github.com/flashinfer-ai/flashinfer/issues)" + " and report the issue to the developers."; + FLASHINFER_ERROR(err_msg.str()); + } else { + size_t smem_size = sizeof(typename KTraits::SharedStorage); + auto kernel = BatchPrefillWithPagedKVCacheKernel; + FI_GPU_CALL( + gpuFuncSetAttribute(kernel, gpuFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + if (tmp_v == nullptr) { + // do not partition kv + params.partition_kv = false; + void* args[] = {(void*)¶ms}; + FI_GPU_CALL(gpuLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + } else { + params.partition_kv = true; + auto o = params.o; + auto lse = params.lse; + params.o = tmp_v; + params.lse = tmp_s; + void* args[] = {(void*)¶ms}; + FI_GPU_CALL(gpuLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + if constexpr (AttentionVariant::use_softmax) { + FI_GPU_CALL(VariableLengthMergeStates(tmp_v, tmp_s, params.merge_indptr, o, lse, + params.max_total_num_rows, params.total_num_rows, + num_qo_heads, HEAD_DIM_VO, stream)); + } else { + FI_GPU_CALL(VariableLengthAttentionSum(tmp_v, params.merge_indptr, o, + params.max_total_num_rows, params.total_num_rows, + num_qo_heads, HEAD_DIM_VO, stream)); + } + } + } + }); + return gpuSuccess; +} + +} // 
namespace flashinfer + +#endif // FLASHINFER_PREFILL_CUH_ diff --git a/libflashinfer/include/gpu_iface/backend/hip/mma_debug_utils_hip.h b/libflashinfer/include/gpu_iface/backend/hip/mma_debug_utils_hip.h index f8bc7dd1d2..16a9d610ec 100644 --- a/libflashinfer/include/gpu_iface/backend/hip/mma_debug_utils_hip.h +++ b/libflashinfer/include/gpu_iface/backend/hip/mma_debug_utils_hip.h @@ -125,7 +125,7 @@ __device__ void print_lds_array(float* lds_array, uint32_t dimY, uint32_t dimX, printf("%s (%dx%d):\n", title, dimX, dimY); for (int y = 0; y < dimY; ++y) { for (int x = 0; x < dimX; ++x) { - printf("%8.3f ", lds_array[y * dimX + x]); + printf("%10.6f ", lds_array[y * dimX + x]); } printf("\n"); } @@ -134,36 +134,52 @@ __device__ void print_lds_array(float* lds_array, uint32_t dimY, uint32_t dimX, __syncthreads(); } -/// @brief Materializes a 2D array of accumulator fragments from each thread's registers into a -/// 2D shared memory array. -/// @details This function is the inverse of the hardware's distribution of accumulator results. -/// It reconstructs a logical tile of the S = Q * K^T matrix in shared memory, -/// accounting for the partitioning of work across multiple warps. +/// @brief Prints a 1D LDS array of floats to the console from a single thread. +/// @details Useful for printing row-wise statistics like m or d values. +__device__ void print_lds_array_1d(float* lds_array, uint32_t dim, + const char* title = "LDS Array 1D (float)") { + if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) { + printf("%s (%d elements):\n", title, dim); + for (int i = 0; i < dim; ++i) { + printf("%10.6f ", lds_array[i]); + if ((i + 1) % 16 == 0) printf("\n"); // Line break every 16 elements + } + if (dim % 16 != 0) printf("\n"); + printf("\n"); + } + __syncthreads(); +} + +/// @brief Generic function to materialize 2D fragment arrays into shared memory. +/// @details Works for both s_frag (attention scores) and o_frag (output accumulator). 
+/// Reconstructs a logical tile from distributed register fragments. /// @tparam T The data type of the fragments and LDS array (e.g., float or half). -/// @tparam NUM_MMA_Q The number of fragments along the Q dimension (rows) per thread. -/// @tparam NUM_MMA_KV The number of fragments along the KV dimension (columns) per thread. -/// @tparam ELEMS_PER_FRAGMENT The number of elements per fragment (typically 4 for float/half). -/// @param s_frag The 3D fragment array from the thread's registers. +/// @tparam NUM_MMA_ROW The number of fragments along the rows dimension per thread. +/// @tparam NUM_MMA_COL The number of fragments along the column dimension per thread. +/// For s_frag: NUM_MMA_KV (KV sequence length) +/// For o_frag: NUM_MMA_D_VO (head dimension) +/// @tparam ELEMS_PER_FRAGMENT The number of elements per fragment (typically 4). +/// @param frag The 3D fragment array from the thread's registers. /// @param lds_scratchpad Pointer to the shared memory array. -/// @param lds_stride The width/stride of the lds_scratchpad (e.g., CTA_TILE_KV). +/// @param lds_stride The width/stride of the lds_scratchpad. /// @param tid The thread's index within the block (threadIdx). -template -__device__ void write_s_frag_to_lds(const T (*s_frag)[NUM_MMA_KV][ELEMS_PER_FRAGMENT], - T* lds_scratchpad, const uint32_t lds_stride, - const dim3 tid = threadIdx) { +template +__device__ void write_frag_to_lds(const T (*frag)[NUM_MMA_COL][ELEMS_PER_FRAGMENT], + T* lds_scratchpad, const uint32_t lds_stride, + const dim3 tid = threadIdx) { const int lane_id = tid.x % 64; const int warp_idx_q = tid.y; // Calculate the starting row in the LDS tile for this entire warp. 
- const uint32_t warp_base_row = warp_idx_q * NUM_MMA_Q * MMA_COLS; + const uint32_t warp_base_row = warp_idx_q * NUM_MMA_ROW * MMA_COLS; #pragma unroll - for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { + for (uint32_t mma_q = 0; mma_q < NUM_MMA_ROW; ++mma_q) { #pragma unroll - for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) { + for (uint32_t mma_col = 0; mma_col < NUM_MMA_COL; ++mma_col) { // -- Calculate the top-left corner of the 16x16 fragment this thread contributes to -- const uint32_t frag_row_offset = mma_q * MMA_COLS; - const uint32_t frag_col_offset = mma_kv * MMA_COLS; + const uint32_t frag_col_offset = mma_col * MMA_COLS; // -- Calculate the specific 4x1 element strip this thread writes within that fragment -- // This logic correctly materializes a B-layout fragment (column strip). @@ -174,7 +190,7 @@ __device__ void write_s_frag_to_lds(const T (*s_frag)[NUM_MMA_KV][ELEMS_PER_FRAG const uint32_t thread_col_in_frag = (lane_id % MMA_COLS); // -- Combine all offsets and write the 4x1 column strip to LDS -- - const T* values = s_frag[mma_q][mma_kv]; + const T* values = frag[mma_q][mma_col]; for (int i = 0; i < MMA_ROWS_PER_THREAD; ++i) { // The row for this element is the thread's starting row + the element's index in the strip. const uint32_t final_row = warp_base_row + frag_row_offset + thread_start_row_in_frag + i; @@ -189,13 +205,40 @@ __device__ void write_s_frag_to_lds(const T (*s_frag)[NUM_MMA_KV][ELEMS_PER_FRAG } } +/// @brief Convenience wrapper for s_frag (attention scores). +template +__device__ void write_s_frag_to_lds(const T (*s_frag)[NUM_MMA_KV][ELEMS_PER_FRAGMENT], + T* lds_scratchpad, const uint32_t lds_stride, + const dim3 tid = threadIdx) { + write_frag_to_lds(s_frag, lds_scratchpad, + lds_stride, tid); +} + +/// @brief Convenience wrapper for o_frag (output accumulator). 
+template +__device__ void write_o_frag_to_lds(const T (*o_frag)[NUM_MMA_D_VO][ELEMS_PER_FRAGMENT], + T* lds_scratchpad, const uint32_t lds_stride, + const dim3 tid = threadIdx) { + write_frag_to_lds(o_frag, lds_scratchpad, + lds_stride, tid); +} + +/// @brief Generic function to materialize 1D row-wise values (m or d) into shared memory. +/// @details Writes row-wise statistics (like max or denominator) from register arrays +/// to a 1D shared memory array, with one value per row. +/// @tparam T The data type (typically float). +/// @tparam NUM_MMA_Q The number of fragments along the Q dimension per thread. +/// @tparam NUM_ACCUM_ROWS_PER_THREAD The number of accumulator rows per thread (typically 4). +/// @param values The 2D array from registers [NUM_MMA_Q][NUM_ACCUM_ROWS_PER_THREAD]. +/// @param lds_scratchpad Pointer to the 1D shared memory array. +/// @param tid The thread's index within the block (threadIdx). template -__device__ void write_m_new_to_lds(const T (*m)[NUM_ACCUM_ROWS_PER_THREAD], T* lds_scratchpad, - const dim3 tid = threadIdx) { +__device__ void write_row_values_to_lds(const T (*values)[NUM_ACCUM_ROWS_PER_THREAD], + T* lds_scratchpad, const dim3 tid = threadIdx) { const int lane_idx = tid.x; const int warp_idx_q = tid.y; - // Each group of 16 threads (a "row group") computes the max for 4 rows. + // Each group of 16 threads (a "row group") handles 4 rows. // We only need one thread from each group to write the results. if (lane_idx % MMA_COLS == 0) { // Base row index for this warp's Q tile @@ -212,13 +255,34 @@ __device__ void write_m_new_to_lds(const T (*m)[NUM_ACCUM_ROWS_PER_THREAD], T* l // e.g., lane 0 is in group 0, lane 16 is in group 1, etc. 
const uint32_t row_group_offset = (lane_idx / MMA_COLS) * NUM_ACCUM_ROWS_PER_THREAD; - // The final row index in the logical S matrix + // The final row index in the logical matrix const uint32_t final_row_idx = warp_base_row + mma_base_row + row_group_offset + j; - lds_scratchpad[final_row_idx] = m[mma_q][j]; + lds_scratchpad[final_row_idx] = values[mma_q][j]; } } } } +/// @brief Convenience wrapper for m (row-wise max) values. +template +__device__ void write_m_to_lds(const T (*m)[NUM_ACCUM_ROWS_PER_THREAD], T* lds_scratchpad, + const dim3 tid = threadIdx) { + write_row_values_to_lds(m, lds_scratchpad, tid); +} + +/// @brief Convenience wrapper for d (denominator) values. +template +__device__ void write_d_to_lds(const T (*d)[NUM_ACCUM_ROWS_PER_THREAD], T* lds_scratchpad, + const dim3 tid = threadIdx) { + write_row_values_to_lds(d, lds_scratchpad, tid); +} + +// Legacy alias for backward compatibility +template +__device__ void write_m_new_to_lds(const T (*m)[NUM_ACCUM_ROWS_PER_THREAD], T* lds_scratchpad, + const dim3 tid = threadIdx) { + write_m_to_lds(m, lds_scratchpad, tid); +} + } // namespace flashinfer::gpu_iface::debug_utils::hip diff --git a/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h b/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h index 1316e454ce..1cc9ea8863 100644 --- a/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h +++ b/libflashinfer/include/gpu_iface/backend/hip/mma_hip.h @@ -111,6 +111,24 @@ __device__ __forceinline__ void transpose_inter_quad_fragments(uint32_t* R) { R[1] = __shfl_xor(R[1], xor_mask, 64); } +/// @brief Performs a full 16x16 in-register matrix transpose by combining intra-quad and +/// inter-quad fragment transpositions. +/// @details This function converts between A-matrix layout (row-major) and B/C/D-matrix layout +/// (column-major) for CDNA3 MFMA operations. 
It applies both +/// transpose_intra_quad_fragments and transpose_inter_quad_fragments to fully transpose a +/// 16x16 tile distributed across 64 threads. +/// +/// Use cases: +/// - B→A layout: Convert column slices to row slices (e.g., for rowsum where S must be +/// A-matrix) +/// - A→B layout: Convert row slices to column slices (if needed for other operations) +/// +/// @param R Pointer to 2 uint32_t registers containing the fragment data +__device__ __forceinline__ void transpose_mma_tile(uint32_t* R) { + transpose_intra_quad_fragments(R); + transpose_inter_quad_fragments(R); +} + // Single unified load function for all fragment types /// @param R [in] pointer to the register file to load the fragment into /// @param smem_ptr [in] pointer to the shared memory to load the fragment from @@ -191,7 +209,6 @@ __device__ __forceinline__ void load_quad_transposed_fragment(uint32_t* R, const template __device__ __forceinline__ void m16k16_rowsum_f16f16f32(float* d, DType* s_frag) { static_assert(sizeof(DType) == 2, "DType must be 16-bit type"); - transpose_intra_quad_fragments(reinterpret_cast(s_frag)); f16x4 a = reinterpret_cast(s_frag)[0]; f16x4 b = {f16(1.0f), f16(1.0f), f16(1.0f), f16(1.0f)}; f32x4 c = {d[0], d[1], d[2], d[3]}; diff --git a/libflashinfer/include/gpu_iface/mma_ops.hpp b/libflashinfer/include/gpu_iface/mma_ops.hpp index b015b116a7..78264ac6e2 100644 --- a/libflashinfer/include/gpu_iface/mma_ops.hpp +++ b/libflashinfer/include/gpu_iface/mma_ops.hpp @@ -40,6 +40,16 @@ __device__ __forceinline__ void load_quad_transposed_fragment(uint32_t* R, const "Only __half is supported for load_quad_transposed_fragment"); mma_detail::load_quad_transposed_fragment(R, smem_ptr); } + +/*! + * \brief Performs a full 16x16 in-register matrix transpose for CDNA3 MFMA tiles + * \details Converts between A-matrix layout (row-major) and B/C/D-matrix layout (column-major) + * by combining intra-quad and inter-quad fragment transpositions. 
+ * \param R Pointer to 2 uint32_t registers containing the fragment data + */ +__device__ __forceinline__ void transpose_mma_tile(uint32_t* R) { + mma_detail::transpose_mma_tile(R); +} #endif /*! diff --git a/libflashinfer/tests/hip/test_k_smem_read_pattern.cpp b/libflashinfer/tests/hip/test_k_smem_read_pattern.cpp new file mode 100644 index 0000000000..748f042081 --- /dev/null +++ b/libflashinfer/tests/hip/test_k_smem_read_pattern.cpp @@ -0,0 +1,181 @@ +#include + +#include +#include +#include +#include + +// Constants for MI300 +constexpr uint32_t WARP_SIZE = 64; // 64 threads per wavefront +constexpr uint32_t HALF_ELEMS_PER_THREAD = 4; // Each thread processes 4 half elements +constexpr uint32_t INT32_ELEMS_PER_THREAD = 2; // 2 int32 registers per thread + +// Simplified linear shared memory operations (CPU implementation) +template +uint32_t get_permuted_offset_linear(uint32_t row, uint32_t col) { + return row * stride + col; +} + +template +uint32_t advance_offset_by_column_linear(uint32_t offset, uint32_t step_idx) { + return offset + step_size; +} + +template +uint32_t advance_offset_by_row_linear(uint32_t offset) { + return offset + step_size * row_stride; +} + +// CPU-based simulation of k-matrix access pattern in compute_qk +template +void SimulateKReadPattern(std::vector& thread_ids_reading_offsets) { + // Constants derived from HEAD_DIM + constexpr uint32_t UPCAST_STRIDE_K = HEAD_DIM / HALF_ELEMS_PER_THREAD; + constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM / 16; + constexpr uint32_t grid_width = HEAD_DIM / HALF_ELEMS_PER_THREAD; + constexpr uint32_t grid_height = 16 * NUM_MMA_KV; + + constexpr uint32_t K_SMEM_COLUMN_ADVANCE = 16 / HALF_ELEMS_PER_THREAD; // = 4 for MI300 + + // Initialize with -1 (unread) + thread_ids_reading_offsets.assign(grid_height * grid_width, -1); + + // Simulate each thread's read pattern + for (uint32_t tid = 0; tid < WARP_SIZE; tid++) { + // Map tid to kernel's lane_idx + uint32_t lane_idx = tid; + uint32_t warp_idx_kv = 0; // For 
simplicity, assuming one warp group + + // Exactly match the kernel's initial offset calculation - MI300 version + uint32_t k_smem_offset_r = get_permuted_offset_linear( + warp_idx_kv * NUM_MMA_KV * 16 + 4 * (lane_idx / 16) + lane_idx % 4, (lane_idx % 16) / 4); + + // uint32_t k_smem_offset_r = + // get_permuted_offset_linear( + // warp_idx_kv * NUM_MMA_KV * 16 + + // 4 * (lane_idx / 16), + // (lane_idx % 16)); + + // Follow the same loop structure as in compute_qk + for (uint32_t mma_d = 0; mma_d < NUM_MMA_D_QK; ++mma_d) { + for (uint32_t mma_kv = 0; mma_kv < NUM_MMA_KV; ++mma_kv) { + // Mark grid positions accessed by ldmatrix_m8n8x4 / + // load_fragment + uint32_t read_row = k_smem_offset_r / UPCAST_STRIDE_K; + uint32_t read_col = k_smem_offset_r % UPCAST_STRIDE_K; + + if (tid == 0) { + std::cout << "Thread " << tid << " k_smem_offset_r " << k_smem_offset_r << '\n'; + } + + // Simulate loading a matrix fragment + for (uint32_t reg_id = 0; reg_id < INT32_ELEMS_PER_THREAD; reg_id++) { + if (read_row < grid_height && read_col < grid_width) { + thread_ids_reading_offsets[read_row * grid_width + read_col] = tid; + } + + // Each INT32_ELEMS_PER_THREAD register holds 2 half + // elements For simplicity, we're just recording the base + // offset + } + + // Advance to next row, exactly as in compute_qk + k_smem_offset_r = advance_offset_by_row_linear<16, UPCAST_STRIDE_K>(k_smem_offset_r); + } + + // Reset row position and advance to next column section, exactly as + // in compute_qk For MI300, advance by 4 columns (vs 2 for NVIDIA) + k_smem_offset_r = + advance_offset_by_column_linear(k_smem_offset_r, mma_d) - + NUM_MMA_KV * 16 * UPCAST_STRIDE_K; + } + } +} + +// Helper function to run the test with configurable parameters +template +void RunKReadPatternTest() { + constexpr uint32_t grid_width = HEAD_DIM / HALF_ELEMS_PER_THREAD; + constexpr uint32_t grid_height = 16 * NUM_MMA_KV; + + printf( + "\n=== Testing key read pattern with HEAD_DIM = %u, NUM_MMA_KV = %u " + 
"===\n", + HEAD_DIM, NUM_MMA_KV); + + // Host array to store thread IDs at each offset + std::vector thread_ids(grid_height * grid_width, -1); + + // Run CPU simulation of read pattern + SimulateKReadPattern(thread_ids); + + // Print the grid of thread IDs + printf("Thread IDs reading from each offset (%dx%d grid):\n", grid_height, grid_width); + + // Column headers + printf(" "); + for (int c = 0; c < grid_width; c++) { + printf("%3d ", c); + if (c == 15 && grid_width > 16) printf("| "); // Divider for HEAD_DIM=128 + } + printf("\n +"); + for (int c = 0; c < grid_width; c++) { + printf("----"); + if (c == 15 && grid_width > 16) printf("+"); + } + printf("\n"); + + // Print the grid + for (int r = 0; r < grid_height; r++) { + printf("%2d | ", r); + for (int c = 0; c < grid_width; c++) { + int thread_id = thread_ids[r * grid_width + c]; + if (thread_id >= 0) { + printf("%3d ", thread_id); + } else { + printf(" . "); // Dot for unread positions + } + if (c == 15 && grid_width > 16) printf("| "); // Divider for HEAD_DIM=128 + } + printf("\n"); + } + + // Check for unread positions + int unread = 0; + for (int i = 0; i < grid_height * grid_width; i++) { + if (thread_ids[i] == -1) { + unread++; + } + } + + // Print statistics + printf("\nStatistics:\n"); + printf("- Positions read: %d/%d (%.1f%%)\n", grid_height * grid_width - unread, + grid_height * grid_width, + 100.0f * (grid_height * grid_width - unread) / (grid_height * grid_width)); + printf("- Unread positions: %d/%d (%.1f%%)\n", unread, grid_height * grid_width, + 100.0f * unread / (grid_height * grid_width)); + + // Validate full coverage + EXPECT_EQ(unread, 0) << "Not all positions were read"; +} + +// Tests for different configurations +TEST(MI300KReadPatternTest, HeadDim64_NumMmaKV1) { RunKReadPatternTest<64, 1>(); } + +// TEST(MI300KReadPatternTest, HeadDim128_NumMmaKV1) { +// RunKReadPatternTest<128, 1>(); +// } + +// TEST(MI300KReadPatternTest, HeadDim64_NumMmaKV2) { +// RunKReadPatternTest<64, 2>(); +// 
} + +// TEST(MI300KReadPatternTest, HeadDim128_NumMmaKV2) { +// RunKReadPatternTest<128, 2>(); +// } + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/libflashinfer/tests/hip/test_produce_kv_kernel.cpp b/libflashinfer/tests/hip/test_produce_kv_kernel.cpp new file mode 100644 index 0000000000..e69de29bb2 diff --git a/libflashinfer/tests/hip/test_q_smem_read_pattern.cpp b/libflashinfer/tests/hip/test_q_smem_read_pattern.cpp new file mode 100644 index 0000000000..61fb5b1016 --- /dev/null +++ b/libflashinfer/tests/hip/test_q_smem_read_pattern.cpp @@ -0,0 +1,156 @@ +#include + +#include +#include +#include +#include + +// Constants for MI300 +constexpr uint32_t WARP_STEP_SIZE = 16; // 16 threads per warp row +constexpr uint32_t QUERY_ELEMS_PER_THREAD = 4; +constexpr uint32_t WARP_THREAD_ROWS = 4; // 4 rows of threads in a warp + +// Simplified linear shared memory operations (CPU implementation) +template +uint32_t get_permuted_offset_linear(uint32_t row, uint32_t col) { + return row * stride + col; +} + +template +uint32_t advance_offset_by_column_linear(uint32_t offset, uint32_t step_idx) { + return offset + step_size; +} + +template +uint32_t advance_offset_by_row_linear(uint32_t offset) { + return offset + step_size * row_stride; +} + +// CPU-based simulation of the read pattern in compute_qk +template +void SimulateReadPattern(std::vector& thread_ids_reading_offsets) { + // Constants derived from HEAD_DIM + constexpr uint32_t UPCAST_STRIDE_Q = HEAD_DIM / QUERY_ELEMS_PER_THREAD; + constexpr uint32_t NUM_MMA_D_QK = HEAD_DIM / 16; + constexpr uint32_t grid_width = (HEAD_DIM / QUERY_ELEMS_PER_THREAD); // 16 for 64, 32 for 128 + constexpr uint32_t grid_height = 16 * NUM_MMA_Q; // 16 for NUM_MMA_Q=1, 32 for NUM_MMA_Q=2 + + // Initialize with -1 (unread) + thread_ids_reading_offsets.assign(grid_height * grid_width, -1); + + // Simulate each thread's read pattern + for (uint32_t tid = 0; tid < 64; 
tid++) { + // Map tid to kernel's lane_idx (same for a single warp) + uint32_t lane_idx = tid; + + // Get warp_idx_q (this is 0 for our single warp simulation) + uint32_t warp_idx_q = 0; + + // Exactly match the kernel's initial offset calculation + uint32_t q_smem_offset_r = get_permuted_offset_linear( + warp_idx_q * NUM_MMA_Q * 16 + lane_idx % 16, lane_idx / 16); + + // Follow exactly the same loop structure as in compute_qk + for (uint32_t mma_d = 0; mma_d < NUM_MMA_D_QK; ++mma_d) { + for (uint32_t mma_q = 0; mma_q < NUM_MMA_Q; ++mma_q) { + // This would be a ldmatrix_m8n8x4 call in the actual code + uint32_t read_row = q_smem_offset_r / UPCAST_STRIDE_Q; + uint32_t read_col = q_smem_offset_r % UPCAST_STRIDE_Q; + + if (read_row < grid_height && read_col < grid_width) { + thread_ids_reading_offsets[read_row * grid_width + read_col] = tid; + } + + // Advance to next row, exactly as in compute_qk + q_smem_offset_r = advance_offset_by_row_linear<16, UPCAST_STRIDE_Q>(q_smem_offset_r); + } + + // Reset row position and advance to next column, exactly as in + // compute_qk + q_smem_offset_r = advance_offset_by_column_linear<4>(q_smem_offset_r, mma_d) - + NUM_MMA_Q * 16 * UPCAST_STRIDE_Q; + } + } +} + +// Helper function to run the test with configurable NUM_MMA_Q +template +void RunReadPatternTest() { + constexpr uint32_t grid_width = (HEAD_DIM / QUERY_ELEMS_PER_THREAD); // 16 for 64, 32 for 128 + constexpr uint32_t grid_height = 16 * NUM_MMA_Q; // 16 for NUM_MMA_Q=1, 32 for NUM_MMA_Q=2 + + printf( + "\n=== Testing query read pattern with HEAD_DIM = %u, NUM_MMA_Q = " + "%u ===\n", + HEAD_DIM, NUM_MMA_Q); + + // Host array to store thread IDs at each offset + std::vector thread_ids(grid_height * grid_width, -1); + + // Run CPU simulation of read pattern + SimulateReadPattern(thread_ids); + + // Print the grid of thread IDs + printf("Thread IDs reading from each offset (%dx%d grid):\n", grid_height, grid_width); + + // Column headers + printf(" "); + for (int c = 0; c < 
grid_width; c++) { + printf("%3d ", c); + if (c == 15 && grid_width > 16) printf("| "); // Divider for HEAD_DIM=128 + } + printf("\n +"); + for (int c = 0; c < grid_width; c++) { + printf("----"); + if (c == 15 && grid_width > 16) printf("+"); + } + printf("\n"); + + // Print the grid + for (int r = 0; r < grid_height; r++) { + printf("%2d | ", r); + for (int c = 0; c < grid_width; c++) { + int thread_id = thread_ids[r * grid_width + c]; + if (thread_id >= 0) { + printf("%3d ", thread_id); + } else { + printf(" . "); // Dot for unread positions + } + if (c == 15 && grid_width > 16) printf("| "); // Divider for HEAD_DIM=128 + } + printf("\n"); + } + + // Check for unread positions + int unread = 0; + for (int i = 0; i < grid_height * grid_width; i++) { + if (thread_ids[i] == -1) { + unread++; + } + } + + // Print statistics + printf("\nStatistics:\n"); + printf("- Positions read: %d/%d (%.1f%%)\n", grid_height * grid_width - unread, + grid_height * grid_width, + 100.0f * (grid_height * grid_width - unread) / (grid_height * grid_width)); + printf("- Unread positions: %d/%d (%.1f%%)\n", unread, grid_height * grid_width, + 100.0f * unread / (grid_height * grid_width)); + + // Validate full coverage + EXPECT_EQ(unread, 0) << "Not all positions were read"; +} + +// Tests for different configurations +TEST(MI300ReadPatternTest, HeadDim64_NumMmaQ1) { RunReadPatternTest<64, 1>(); } + +TEST(MI300ReadPatternTest, HeadDim128_NumMmaQ1) { RunReadPatternTest<128, 1>(); } + +TEST(MI300ReadPatternTest, HeadDim64_NumMmaQ2) { RunReadPatternTest<64, 2>(); } + +TEST(MI300ReadPatternTest, HeadDim128_NumMmaQ2) { RunReadPatternTest<128, 2>(); } + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/libflashinfer/tests/hip/test_single_prefill.cpp b/libflashinfer/tests/hip/test_single_prefill.cpp new file mode 100644 index 0000000000..d91d4c7ca4 --- /dev/null +++ b/libflashinfer/tests/hip/test_single_prefill.cpp @@ -0,0 
+1,618 @@ +// SPDX - FileCopyrightText : 2023 - 2025 Flashinfer team +// SPDX - FileCopyrightText : 2025 Advanced Micro Devices, Inc. +// +// SPDX - License - Identifier : Apache 2.0 + +#include + +#include + +#include "../../utils/cpu_reference_hip.h" +#include "../../utils/flashinfer_prefill_ops.hip.h" +#include "../../utils/utils_hip.h" +#include "flashinfer/attention/generic/prefill.cuh" +#include "gpu_iface/gpu_runtime_compat.hpp" + +#define HIP_ENABLE_WARP_SYNC_BUILTINS 1 + +using namespace flashinfer; + +#if 0 +template +void _TestComputeQKCorrectness(size_t qo_len, + size_t kv_len, + size_t num_qo_heads, + size_t num_kv_heads, + size_t head_dim, + bool causal, + QKVLayout kv_layout, + PosEncodingMode pos_encoding_mode, + bool use_fp16_qk_reduction, + float rtol = 1e-3, + float atol = 1e-3) +{ + // std::cout << "Testing compute_qk: qo_len=" << qo_len + // << ", kv_len=" << kv_len << ", num_qo_heads=" << num_qo_heads + // << ", num_kv_heads=" << num_kv_heads << ", head_dim=" << head_dim + // << std::endl; + + // Generate test data (same as original test) + std::vector q(qo_len * num_qo_heads * head_dim); + std::vector k(kv_len * num_kv_heads * head_dim); + std::vector v(kv_len * num_kv_heads * + head_dim); // Still needed for params + std::vector o(qo_len * num_qo_heads * + head_dim); // Still needed for params + + utils::vec_normal_(q); + utils::vec_normal_(k); + utils::vec_normal_(v); // Initialize even though we won't use it + utils::vec_zero_(o); + + // GPU memory allocation (same pattern as original) + DTypeQ *q_d; + FI_GPU_CALL(hipMalloc(&q_d, q.size() * sizeof(DTypeQ))); + FI_GPU_CALL(hipMemcpy(q_d, q.data(), q.size() * sizeof(DTypeQ), + hipMemcpyHostToDevice)); + + DTypeKV *k_d; + FI_GPU_CALL(hipMalloc(&k_d, k.size() * sizeof(DTypeKV))); + FI_GPU_CALL(hipMemcpy(k_d, k.data(), k.size() * sizeof(DTypeKV), + hipMemcpyHostToDevice)); + + DTypeKV *v_d; + FI_GPU_CALL(hipMalloc(&v_d, v.size() * sizeof(DTypeKV))); + FI_GPU_CALL(hipMemcpy(v_d, v.data(), 
v.size() * sizeof(DTypeKV), + hipMemcpyHostToDevice)); + + DTypeO *o_d; + FI_GPU_CALL(hipMalloc(&o_d, o.size() * sizeof(DTypeO))); + FI_GPU_CALL(hipMemcpy(o_d, o.data(), o.size() * sizeof(DTypeO), + hipMemcpyHostToDevice)); + + DTypeO *tmp_d; + FI_GPU_CALL(hipMalloc(&tmp_d, 16 * 1024 * 1024 * sizeof(DTypeO))); + + // Allocate output buffer for QK scores + const size_t qk_output_size = qo_len * kv_len * num_qo_heads; + float *qk_scores_d; + FI_GPU_CALL(hipMalloc(&qk_scores_d, qk_output_size * sizeof(float))); + + // std::cout << "Debug: Kernel launch parameters:" << std::endl; + // std::cout << " qo_len=" << qo_len << ", kv_len=" << kv_len << std::endl; + // std::cout << " num_qo_heads=" << num_qo_heads + // << ", num_kv_heads=" << num_kv_heads << std::endl; + // std::cout << " head_dim=" << head_dim << std::endl; + // std::cout << " qk_output_size=" << qk_output_size << std::endl; + // std::cout << " Launching ComputeQKStubCaller..." << std::endl; + + // Call ComputeQKStubCaller instead of SinglePrefillWithKVCache + hipError_t status = + flashinfer::ComputeQKStubCaller( + q_d, k_d, v_d, o_d, tmp_d, + /*lse=*/nullptr, qk_scores_d, // Add qk_scores_d parameter + num_qo_heads, num_kv_heads, qo_len, kv_len, head_dim, causal, + kv_layout, pos_encoding_mode, use_fp16_qk_reduction); + + // std::cout << " Kernel launch status: " << hipGetErrorString(status) + // << std::endl; + EXPECT_EQ(status, hipSuccess) + << "ComputeQKStubCaller kernel launch failed, error message: " + << hipGetErrorString(status); + + // Get GPU QK scores + std::vector gpu_qk_scores(qk_output_size); + FI_GPU_CALL(hipMemcpy(gpu_qk_scores.data(), qk_scores_d, + qk_output_size * sizeof(float), + hipMemcpyDeviceToHost)); + + // Check if GPU output is not empty + bool isEmpty = gpu_qk_scores.empty(); + EXPECT_EQ(isEmpty, false) << "GPU QK scores vector is empty"; + + // Compute CPU reference using our cpu_reference::compute_qk + std::vector cpu_qk_scores = cpu_reference::compute_qk( + q, k, qo_len, kv_len, 
num_qo_heads, num_kv_heads, head_dim, kv_layout); + + // Validate results (same pattern as original test) + size_t num_results_error_atol = 0; + bool nan_detected = false; + + // Compare element-by-element + size_t comparison_size = + std::min(gpu_qk_scores.size(), cpu_qk_scores.size()); + for (size_t i = 0; i < comparison_size; ++i) { + float gpu_val = gpu_qk_scores[i]; + float cpu_val = cpu_qk_scores[i]; + + if (isnan(gpu_val)) { + nan_detected = true; + } + + if (!utils::isclose(cpu_val, gpu_val, rtol, atol)) { + num_results_error_atol++; + if (num_results_error_atol <= 10) + { // Only print first 10 mismatches + std::cout << "QK mismatch at i=" << i << ", cpu_val=" << cpu_val + << ", gpu_val=" << gpu_val << std::endl; + } + } + } +#if 0 + // Calculate and report accuracy + float result_accuracy = + 1.0f - float(num_results_error_atol) / float(comparison_size); + + std::cout << "compute_qk results: num_qo_heads=" << num_qo_heads + << ", num_kv_heads=" << num_kv_heads << ", qo_len=" << qo_len + << ", kv_len=" << kv_len << ", head_dim=" << head_dim + << ", causal=" << causal + << ", kv_layout=" << QKVLayoutToString(kv_layout) + << ", pos_encoding_mode=" + << PosEncodingModeToString(pos_encoding_mode) + << ", qk_accuracy=" << result_accuracy << " (" + << num_results_error_atol << "/" << comparison_size + << " mismatches)" << std::endl; + + // Print some sample values for debugging + std::cout << "Sample QK scores (first 10): GPU vs CPU" << std::endl; + for (size_t i = 0; i < std::min(size_t(10), comparison_size); ++i) { + std::cout << " [" << i << "] GPU=" << gpu_qk_scores[i] + << ", CPU=" << cpu_qk_scores[i] << std::endl; + } + + // Assertions (slightly relaxed for initial testing) + EXPECT_GT(result_accuracy, 0.80) + << "compute_qk accuracy too low"; // Start with 80% + EXPECT_FALSE(nan_detected) << "NaN detected in compute_qk results"; +#endif + // Cleanup + FI_GPU_CALL(hipFree(q_d)); + FI_GPU_CALL(hipFree(k_d)); + FI_GPU_CALL(hipFree(v_d)); + 
FI_GPU_CALL(hipFree(o_d)); + FI_GPU_CALL(hipFree(tmp_d)); + FI_GPU_CALL(hipFree(qk_scores_d)); +} + +#endif + +template +void _TestSinglePrefillKernelCorrectness(size_t qo_len, size_t kv_len, size_t num_qo_heads, + size_t num_kv_heads, size_t head_dim, bool causal, + QKVLayout kv_layout, PosEncodingMode pos_encoding_mode, + bool use_fp16_qk_reduction, uint32_t debug_thread_id, + uint32_t debug_warp_id, float rtol = 1e-3, + float atol = 1e-3) { + std::vector q(qo_len * num_qo_heads * head_dim); + std::vector k(kv_len * num_kv_heads * head_dim); + std::vector v(kv_len * num_kv_heads * head_dim); + std::vector o(qo_len * num_qo_heads * head_dim); + + utils::vec_normal_(q); + utils::vec_normal_(k); + utils::vec_normal_(v); + // utils::vec_lexicographic_(q); + // utils::vec_lexicographic_(k); + // utils::vec_fill_(v, __float2half(1.0f)); + utils::vec_zero_(o); + + DTypeQ* q_d; + FI_GPU_CALL(hipMalloc(&q_d, q.size() * sizeof(DTypeQ))); + FI_GPU_CALL(hipMemcpy(q_d, q.data(), q.size() * sizeof(DTypeQ), hipMemcpyHostToDevice)); + + DTypeKV* k_d; + FI_GPU_CALL(hipMalloc(&k_d, k.size() * sizeof(DTypeKV))); + FI_GPU_CALL(hipMemcpy(k_d, k.data(), k.size() * sizeof(DTypeKV), hipMemcpyHostToDevice)); + + DTypeKV* v_d; + FI_GPU_CALL(hipMalloc(&v_d, v.size() * sizeof(DTypeKV))); + FI_GPU_CALL(hipMemcpy(v_d, v.data(), v.size() * sizeof(DTypeKV), hipMemcpyHostToDevice)); + + DTypeO* o_d; + FI_GPU_CALL(hipMalloc(&o_d, o.size() * sizeof(DTypeO))); + FI_GPU_CALL(hipMemcpy(o_d, o.data(), o.size() * sizeof(DTypeO), hipMemcpyHostToDevice)); + + DTypeO* tmp_d; + FI_GPU_CALL(hipMalloc(&tmp_d, 16 * 1024 * 1024 * sizeof(DTypeO))); + + hipError_t status = flashinfer::SinglePrefillWithKVCache( + q_d, k_d, v_d, o_d, tmp_d, + /*lse=*/nullptr, num_qo_heads, num_kv_heads, qo_len, kv_len, head_dim, causal, kv_layout, + pos_encoding_mode, use_fp16_qk_reduction, debug_thread_id, debug_warp_id); + + EXPECT_EQ(status, hipSuccess) << "SinglePrefillWithKVCache kernel launch failed, error message: " + << 
hipGetErrorString(status); + + std::vector o_h(o.size()); + FI_GPU_CALL(hipMemcpy(o_h.data(), o_d, o_h.size() * sizeof(DTypeO), hipMemcpyDeviceToHost)); + + // Print the first 10 elements of the output vector for debugging + // std::cout << "Output vector (first 10 elements):"; + // std::cout << "[" << std::endl; + // for (int i = 0; i < 10; ++i) { + // std::cout << fi::con::explicit_casting(o_h[i]) << " + // "; + // } + // std::cout << "]" << std::endl; + + bool isEmpty = o_h.empty(); + EXPECT_EQ(isEmpty, false) << "Output vector is empty"; + + std::vector att_out; + std::vector o_ref = cpu_reference::single_mha( + q, k, v, qo_len, kv_len, num_qo_heads, num_kv_heads, head_dim, causal, kv_layout, + pos_encoding_mode); + size_t num_results_error_atol = 0; + bool nan_detected = false; + + for (size_t i = 0; i < o_ref.size(); ++i) { + float o_h_val = fi::con::explicit_casting(o_h[i]); + float o_ref_val = fi::con::explicit_casting(o_ref[i]); + + if (isnan(o_h_val)) { + nan_detected = true; + } + + num_results_error_atol += (!utils::isclose(o_ref_val, o_h_val, rtol, atol)); + // if (!utils::isclose(o_ref_val, o_h_val, rtol, atol)) { + // std::cout << "i=" << i << ", o_ref[i]=" << o_ref_val + // << ", o_h[i]=" << o_h_val << std::endl; + // } + } + // std::cout<<"Printing att_out vector:\n"; + // for(auto i: att_out) { + // std::cout << i << "\n"; + // } +#if 1 + float result_accuracy = 1. 
- float(num_results_error_atol) / float(o_ref.size()); + std::cout << "num_qo_heads=" << num_qo_heads << ", num_kv_heads=" << num_kv_heads + << ", qo_len=" << qo_len << ", kv_len=" << kv_len << ", head_dim=" << head_dim + << ", causal=" << causal << ", kv_layout=" << QKVLayoutToString(kv_layout) + << ", pos_encoding_mode=" << PosEncodingModeToString(pos_encoding_mode) + << ", result_accuracy=" << result_accuracy << std::endl; + + EXPECT_GT(result_accuracy, 0.90) << "Result correctness test failed."; + EXPECT_FALSE(nan_detected) << "Nan detected in the result."; +#endif + FI_GPU_CALL(hipFree(q_d)); + FI_GPU_CALL(hipFree(k_d)); + FI_GPU_CALL(hipFree(v_d)); + FI_GPU_CALL(hipFree(o_d)); + FI_GPU_CALL(hipFree(tmp_d)); +} + +// template +// void TestSinglePrefillKernelLongContextCorrectness(bool +// use_fp16_qk_reduction) +// { +// for (size_t qo_len : {1, 31, 63, 127}) { +// for (size_t kv_len : {31717}) { +// for (size_t num_heads : {1}) { +// for (size_t head_dim : {64, 128, 256}) { +// for (bool causal : {false, true}) { +// for (size_t pos_encoding_mode : {0, 1}) { +// for (size_t kv_layout : {0, 1}) { +// _TestSinglePrefillKernelCorrectness< +// DTypeIn, DTypeIn, DTypeO>( +// qo_len, kv_len, num_heads, num_heads, +// head_dim, causal, QKVLayout(kv_layout), +// PosEncodingMode(pos_encoding_mode), +// use_fp16_qk_reduction); +// } +// } +// } +// } +// } +// } +// } +// } +//*********************************************************************** +// The following tests are disabled because we dont support fp8 <-> float +// conversions + +// template +// void TestSinglePrefillFP8KernelLongContextCorrectness(bool +// use_fp16_qk_reduction) { +// for (size_t qo_len : {1, 31, 63, 127}) { +// for (size_t kv_len : {31717}) { +// for (size_t num_heads : {1}) { +// for (size_t head_dim : {64, 128, 256}) { +// for (bool causal : {false, true}) { +// for (size_t pos_encoding_mode : {0}) { +// for (size_t kv_layout : {0, 1}) { +// _TestSinglePrefillKernelCorrectness<__half, 
DTypeKV, __half>( +// qo_len, kv_len, num_heads, num_heads, head_dim, causal, +// QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode), +// use_fp16_qk_reduction); +// } +// } +// } +// } +// } +// } +// } +// } + +// template +// void TestSinglePrefillKernelShortContextCorrectness(bool +// use_fp16_qk_reduction) +// { +// float rtol = std::is_same::value ? 1e-2 : 1e-3; +// float atol = std::is_same::value ? 1e-2 : 1e-3; +// for (size_t qkv_len : {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37}) { +// for (size_t num_qo_heads : {32}) { +// for (size_t num_kv_heads : {4, 8, 32}) { +// for (size_t head_dim : {64, 128, 256}) { +// for (bool causal : {false, true}) { +// for (size_t pos_encoding_mode : {0, 1}) { +// for (size_t kv_layout : {0, 1}) { +// _TestSinglePrefillKernelCorrectness< +// DTypeIn, DTypeIn, DTypeO>( +// qkv_len, qkv_len, num_qo_heads, +// num_kv_heads, head_dim, causal, +// QKVLayout(kv_layout), +// PosEncodingMode(pos_encoding_mode), +// use_fp16_qk_reduction, rtol, atol); +// } +// } +// } +// } +// } +// } +// } +// } + +//*********************************************************************** +// The following tests are disabled because we dont support fp8 <-> float +// conversions + +// template +// void TestSinglePrefillFP8KernelShortContextCorrectness(bool +// use_fp16_qk_reduction) { +// float rtol = 1e-3; +// float atol = 1e-3; +// for (size_t qkv_len : {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37}) { +// for (size_t num_qo_heads : {32}) { +// for (size_t num_kv_heads : {4, 8, 32}) { +// for (size_t head_dim : {64, 128, 256}) { +// for (bool causal : {false, true}) { +// for (size_t pos_encoding_mode : {0}) { +// for (size_t kv_layout : {0, 1}) { +// _TestSinglePrefillKernelCorrectness<__half, DTypeKV, __half>( +// qkv_len, qkv_len, num_qo_heads, num_kv_heads, head_dim, +// causal, QKVLayout(kv_layout), +// PosEncodingMode(pos_encoding_mode), +// use_fp16_qk_reduction, rtol, atol); +// } +// } +// } +// } +// } +// } +// } +// } + +// 
template +// void TestSinglePrefillKernelCorrectness(bool use_fp16_qk_reduction) +// { +// for (size_t qo_len : {399, 400, 401}) { +// for (size_t kv_len : {533, 534, 535}) { +// for (size_t num_heads : {12}) { +// for (size_t head_dim : {64, 128, 256}) { +// for (bool causal : {false, true}) { +// for (size_t pos_encoding_mode : {0, 1}) { +// for (size_t kv_layout : {0, 1}) { +// _TestSinglePrefillKernelCorrectness< +// DTypeIn, DTypeIn, DTypeO>( +// qo_len, kv_len, num_heads, num_heads, +// head_dim, causal, QKVLayout(kv_layout), +// PosEncodingMode(pos_encoding_mode), +// use_fp16_qk_reduction); +// } +// } +// } +// } +// } +// } +// } +// } + +// template +// void TestSinglePrefillFP8KernelCorrectness(bool use_fp16_qk_reduction) +// { +// for (size_t qo_len : {399, 400, 401}) { +// for (size_t kv_len : {533, 534, 535}) { +// for (size_t num_heads : {12}) { +// for (size_t head_dim : {64, 128, 256}) { +// for (bool causal : {false, true}) { +// for (size_t pos_encoding_mode : {0}) { +// for (size_t kv_layout : {0, 1}) { +// _TestSinglePrefillKernelCorrectness< +// __half, DTypeKV, __half>( +// qo_len, kv_len, num_heads, num_heads, +// head_dim, causal, QKVLayout(kv_layout), +// PosEncodingMode(pos_encoding_mode), +// use_fp16_qk_reduction); +// } +// } +// } +// } +// } +// } +// } +// } + +// TEST(FlashInferCorrectnessTest, +// TestSinglePrefillKernelLongContextCorrectnessFP16) +// { +// TestSinglePrefillKernelLongContextCorrectness<__half, __half>(false); +// } + +// TEST(FlashInferCorrectnessTest, +// TestSinglePrefillKernelLongContextCorrectnessFP16QKHalfAccum) +// { +// TestSinglePrefillKernelLongContextCorrectness<__half, __half>(true); +// } + +// TEST(FlashInferCorrectnessTest, +// TestSinglePrefillKernelShortContextCorrectnessFP16) +// { +// TestSinglePrefillKernelShortContextCorrectness<__half, __half>(false); +// } + +// TEST(FlashInferCorrectnessTest, +// TestSinglePrefillKernelShortContextCorrectnessFP16QKHalfAccum) +// { +// 
TestSinglePrefillKernelShortContextCorrectness<__half, __half>(true); +// } + +// TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelCorrectnessTestFP16) +// { +// TestSinglePrefillKernelCorrectness<__half, __half>(false); +// } + +// TEST(FlashInferCorrectnessTest, +// TestSinglePrefillKernelCorrectnessTestFP16QKHalfAccum) +// { +// TestSinglePrefillKernelCorrectness<__half, __half>(true); +// } + +// #ifdef FLASHINFER_ENABLE_BF16 +// TEST(FlashInferCorrectnessTest, +// TestSinglePrefillKernelLongContextCorrectnessBF16) +// { +// TestSinglePrefillKernelLongContextCorrectness<__hip_bfloat16, +// __hip_bfloat16>( +// false); +// } +// TEST(FlashInferCorrectnessTest, +// TestSinglePrefillKernelShortContextCorrectnessBF16) +// { +// TestSinglePrefillKernelShortContextCorrectness<__hip_bfloat16, +// __hip_bfloat16>( +// false); +// } +// TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelCorrectnessTestBF16) +// { +// TestSinglePrefillKernelCorrectness<__hip_bfloat16, +// __hip_bfloat16>(false); +// } +// #endif + +//*********************************************************************** +// The following tests are disabled because we dont support fp8 <-> float +// conversions + +// #ifdef FLASHINFER_ENABLE_FP8_E4M3 +// TEST(FlashInferCorrectnessTest, +// TestSinglePrefillKernelShortContextCorrectnessE4M3) { +// TestSinglePrefillFP8KernelShortContextCorrectness<__nv_fp8_e4m3>(false); +// } +// TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelCorrectnessTestE4M3) { +// TestSinglePrefillFP8KernelCorrectness<__nv_fp8_e4m3>(false); +// } +// TEST(FlashInferCorrectnessTest, +// TestSinglePrefillKernelLongContextCorrectnessE4M3) { +// TestSinglePrefillFP8KernelLongContextCorrectness<__nv_fp8_e4m3>(false); +// } +// #endif + +// #ifdef FLASHINFER_ENABLE_FP8_E5M2 +// TEST(FlashInferCorrectnessTest, +// TestSinglePrefillKernelShortContextCorrectnessE5M2) { +// TestSinglePrefillFP8KernelShortContextCorrectness<__nv_fp8_e5m2>(false); +// } +// 
TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelCorrectnessTestE5M2) { +// TestSinglePrefillFP8KernelCorrectness<__nv_fp8_e5m2>(false); +// } +// TEST(FlashInferCorrectnessTest, +// TestSinglePrefillKernelLongContextCorrectnessE5M2) { +// TestSinglePrefillFP8KernelLongContextCorrectness<__nv_fp8_e5m2>(false); +// } +// #endif + +int main(int argc, char** argv) { + // ::testing::InitGoogleTest(&argc, argv); + // return RUN_ALL_TESTS(); + using DTypeIn = __half; + using DTypeO = __half; + uint32_t debug_thread_id = 0; + uint32_t debug_warp_id = 0; + bool use_fp16_qk_reduction = false; + size_t qo_len = 128; + size_t kv_len = 128; + size_t num_heads = 1; + size_t head_dim = 64; + bool causal = false; + size_t pos_encoding_mode = 0; // 1 == kRopeLLama + size_t kv_layout = 0; + + for (int i = 1; i < argc; i++) { + std::string arg = argv[i]; + + if (arg == "--thread" && i + 1 < argc) { + debug_thread_id = std::stoi(argv[++i]); + std::cout << "Debug thread ID set to: " << debug_thread_id << std::endl; + } else if (arg == "--warp" && i + 1 < argc) { + debug_warp_id = std::stoi(argv[++i]); + std::cout << "Debug warp ID set to: " << debug_warp_id << std::endl; + } else if (arg == "--qo_len" && i + 1 < argc) { + qo_len = std::stoi(argv[++i]); + } else if (arg == "--kv_len" && i + 1 < argc) { + kv_len = std::stoi(argv[++i]); + } else if (arg == "--heads" && i + 1 < argc) { + num_heads = std::stoi(argv[++i]); + } else if (arg == "--help") { + std::cout << "Usage: " << argv[0] << " [options]\n" + << "Options:\n" + << " --thread Debug thread ID (0-255 for 4 warps)\n" + << " --warp Debug warp ID (0-3 for 4 warps)\n" + << " --qo_len Query/Output length (default: 128)\n" + << " --kv_len Key/Value length (default: 128)\n" + << " --heads Number of heads (default: 1)\n" + << " --help Show this help message\n"; + return 0; + } + } + + _TestSinglePrefillKernelCorrectness( + qo_len, kv_len, num_heads, num_heads, head_dim, causal, QKVLayout(kv_layout), + 
PosEncodingMode(pos_encoding_mode), use_fp16_qk_reduction, debug_thread_id, debug_warp_id); +} + +// int main(int argc, char **argv) +// { +// // Test compute_qk first with simple parameters +// std::cout << "=== Testing compute_qk function ===" << std::endl; +// using DTypeIn = __half; +// using DTypeO = __half; +// bool use_fp16_qk_reduction = false; +// bool causal = false; +// size_t pos_encoding_mode = 0; +// size_t kv_layout = 0; + +// // Start with small dimensions for easier debugging +// _TestComputeQKCorrectness( +// 16, // qo_len - small for debugging +// 32, // kv_len +// 1, // num_qo_heads - single head +// 1, // num_kv_heads - single head +// 64, // head_dim +// causal, QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode), +// use_fp16_qk_reduction); + +// std::cout << "\n=== Testing full single prefill ===" << std::endl; +// // Your existing test... +// size_t qo_len = 399; +// size_t kv_len = 533; +// size_t num_heads = 1; +// size_t head_dim = 64; + +// _TestSinglePrefillKernelCorrectness( +// qo_len, kv_len, num_heads, num_heads, head_dim, causal, +// QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode), +// use_fp16_qk_reduction); + +// return 0; +// } diff --git a/libflashinfer/utils/cpu_reference_hip.h b/libflashinfer/utils/cpu_reference_hip.h index 9703f30657..f502ded1f3 100644 --- a/libflashinfer/utils/cpu_reference_hip.h +++ b/libflashinfer/utils/cpu_reference_hip.h @@ -8,6 +8,9 @@ #include #include +#include +#include + #include "flashinfer/attention/generic/page.cuh" #include "flashinfer/attention/generic/pos_enc.cuh" #include "flashinfer/exception.h" @@ -64,7 +67,7 @@ inline std::vector apply_llama_rope(const T* input, size_t D, size_t offs if (std::is_same_v) rst[k] = cos * fi::con::explicit_casting(input[k]) + sin * permuted_input[k]; } - return std::move(rst); + return rst; } template @@ -73,16 +76,38 @@ std::vector single_mha(const std::vector& q, const std::vect size_t num_qo_heads, size_t num_kv_heads, size_t head_dim, bool 
causal = true, QKVLayout kv_layout = QKVLayout::kHND, PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, - float rope_scale = 1.f, float rope_theta = 1e4) { + float logits_soft_cap = 8.0f, float rope_scale = 1.f, + float rope_theta = 1e4, bool use_soft_cap = false) { assert(qo_len <= kv_len); assert(num_qo_heads % num_kv_heads == 0); float sm_scale = 1.f / std::sqrt(float(head_dim)); + // float sm_scale = 1.0; std::vector o(qo_len * num_qo_heads * head_dim); std::vector att(kv_len); std::vector q_rotary_local(head_dim); std::vector k_rotary_local(head_dim); + DISPATCH_head_dim(head_dim, HEAD_DIM, { tensor_info_t info(qo_len, kv_len, num_qo_heads, num_kv_heads, kv_layout, HEAD_DIM); +#if Debug1 + std::cout << "DEBUG: Original Q (CPU): " << '\n'; + for (auto i = 0ul; i < 128; ++i) { + for (int j = 0; j < 64; ++j) { + std::cout << (float)q[info.get_q_elem_offset(i, 0, j)] << " "; + } + std::cout << std::endl; + } + std::cout << std::endl; + + std::cout << "DEBUG: Original K (CPU): " << '\n'; + for (auto i = 0ul; i < 128; ++i) { + for (int j = 0ul; j < 64; ++j) { + std::cout << (float)k[info.get_kv_elem_offset(i, 0, j)] << " "; + } + std::cout << std::endl; + } + std::cout << std::endl; +#endif for (size_t qo_head_idx = 0; qo_head_idx < num_qo_heads; ++qo_head_idx) { const size_t kv_head_idx = qo_head_idx / info.get_group_size(); for (size_t q_idx = 0; q_idx < qo_len; ++q_idx) { @@ -126,6 +151,18 @@ std::vector single_mha(const std::vector& q, const std::vect } max_val = std::max(max_val, att[kv_idx]); } + +#if Debug1 + if (qo_head_idx == 0) { + // for qo_len = 128, each warp on the GPU will store 128/4, + // that is, 32 attention scores. For CDNA3, these 32 scores + // are spread across 4 threads. 
+ for (auto i = 0ul; i < 128; ++i) { + std::cout << att[i] / sm_scale << " "; + } + std::cout << std::endl; + } +#endif // exp minus max float denom = 0; for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { @@ -133,6 +170,27 @@ std::vector single_mha(const std::vector& q, const std::vect denom += att[kv_idx]; } +#if Debug1 + if (qo_head_idx == 0) { + // for qo_len = 128, each warp on the GPU will store 128/4, + // that is, 32 attention scores. For CDNA3, these 32 scores + // are spread across 4 threads. + for (auto i = 0ul; i < 128; ++i) { + std::cout << att[i] << " "; + } + std::cout << std::endl; + } +#endif + +#if Debug1 + if (qo_head_idx == 0) { + for (auto i = 0ul; i < 128; ++i) { + std::cout << denom << " "; + } + std::cout << std::endl; + } +#endif + // divide by denom for (size_t kv_idx = 0; kv_idx < kv_len; ++kv_idx) { att[kv_idx] /= denom; diff --git a/libflashinfer/utils/flashinfer_prefill_ops.hip.h b/libflashinfer/utils/flashinfer_prefill_ops.hip.h new file mode 100644 index 0000000000..971907203d --- /dev/null +++ b/libflashinfer/utils/flashinfer_prefill_ops.hip.h @@ -0,0 +1,188 @@ +// SPDX - FileCopyrightText : 2023 - 2025 Flashinfer team +// SPDX - FileCopyrightText : 2025 Advanced Micro Devices, Inc. 
+// +// SPDX - License - Identifier : Apache 2.0 + +#pragma once + +#include "utils_hip.h" + +// #include "compute_qk_stub.cuh" +#include "flashinfer/attention/generic/allocator.h" +#include "flashinfer/attention/generic/default_prefill_params.cuh" +#include "flashinfer/attention/generic/exception.h" +#include "flashinfer/attention/generic/prefill.cuh" +// #include "flashinfer/attention/generic/prefill_tester.cuh" +#include + +#include "flashinfer/attention/generic/scheduler.cuh" +#include "flashinfer/attention/generic/variants.cuh" +#include "gpu_iface/enums.hpp" +#include "gpu_iface/layout.cuh" + +namespace flashinfer { + +template +hipError_t SinglePrefillWithKVCacheDispatched(Params params, typename Params::DTypeO* tmp, + hipStream_t stream); + +template +hipError_t SinglePrefillWithKVCacheCustomMask( + DTypeIn* q, DTypeIn* k, DTypeIn* v, uint8_t* custom_mask, DTypeO* o, DTypeO* tmp, float* lse, + uint32_t num_qo_heads, uint32_t num_kv_heads, uint32_t qo_len, uint32_t kv_len, + uint32_t head_dim, QKVLayout kv_layout = QKVLayout::kNHD, + PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, bool use_fp16_qk_reduction = false, + std::optional maybe_sm_scale = std::nullopt, float rope_scale = 1.f, + float rope_theta = 1e4, hipStream_t stream = nullptr) { + const float sm_scale = 1.f; + auto [qo_stride_n, qo_stride_h, kv_stride_n, kv_stride_h] = + get_qkv_strides(kv_layout, kv_len, num_qo_heads, num_kv_heads, head_dim); + DISPATCH_use_fp16_qk_reduction( + static_cast(use_fp16_qk_reduction), USE_FP16_QK_REDUCTION, + {DISPATCH_head_dim( + head_dim, HEAD_DIM, {DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, { + using Params = SinglePrefillParams; + using AttentionVariant = DefaultAttention< + /*use_custom_mask=*/true, /*use_sliding_window=*/false, + /*use_logits_soft_cap=*/false, /*use_alibi=*/false>; + Params params(q, k, v, custom_mask, o, lse, + /*alibi_slopes=*/nullptr, num_qo_heads, num_kv_heads, qo_len, kv_len, + qo_stride_n, qo_stride_h, 
kv_stride_n, kv_stride_h, head_dim, + /*window_left=*/-1, + /*logits_soft_cap=*/0.f, sm_scale, rope_scale, rope_theta); + return SinglePrefillWithKVCacheDispatched(params, tmp, stream); + })})}); + return hipSuccess; +} + +/*! + * \brief FlashAttention prefill hip function for a single request. + * \tparam DTypeIn The data type of input + * \tparam DTypeO The data type of output + * \param q The query tensor. + * \param k The key tensor. + * \param v The value tensor. + * \param o The output tensor. + * \param tmp The temporary storage (only used for cooperative kernel). + * \param lse The logsumexp values. + * \param num_qo_heads The number of query and output heads. + * \param num_kv_heads The number of key and value heads. + * \param qo_len The length of query and output. + * \param kv_len The length of key and value. + * \param head_dim The dimension of each head. + * \param causal Whether to use causal attention. + * \param kv_layout The layout of input and output. + * \param pos_encoding_mode The positional encoding mode. + * \param use_fp16_qk_reduction Whether to allow accumulating q*k^T with fp16. + * \param rope_scale The scaling factor used in RoPE interpolation. + * \param rope_theta The theta used in RoPE. + * \param stream The hip stream to execute the kernel on. 
+ * \return status Indicates whether hip calls are successful + */ +template +hipError_t SinglePrefillWithKVCache(DTypeQ* q, DTypeKV* k, DTypeKV* v, DTypeO* o, DTypeO* tmp, + float* lse, uint32_t num_qo_heads, uint32_t num_kv_heads, + uint32_t qo_len, uint32_t kv_len, uint32_t head_dim, + bool causal = true, QKVLayout kv_layout = QKVLayout::kNHD, + PosEncodingMode pos_encoding_mode = PosEncodingMode::kNone, + bool use_fp16_qk_reduction = false, + uint32_t debug_thread_id = 0, uint32_t debug_warp_id = 0, + std::optional maybe_sm_scale = std::nullopt, + float rope_scale = 1.f, float rope_theta = 1e4, + hipStream_t stream = nullptr) { + const float sm_scale = 1.f; + const MaskMode mask_mode = causal ? MaskMode::kCausal : MaskMode::kNone; + auto [qo_stride_n, qo_stride_h, kv_stride_n, kv_stride_h] = + get_qkv_strides(kv_layout, kv_len, num_qo_heads, num_kv_heads, head_dim); + DISPATCH_use_fp16_qk_reduction( + static_cast(use_fp16_qk_reduction), USE_FP16_QK_REDUCTION, + {DISPATCH_mask_mode( + mask_mode, MASK_MODE, + {DISPATCH_head_dim(head_dim, HEAD_DIM, + {DISPATCH_pos_encoding_mode(pos_encoding_mode, POS_ENCODING_MODE, { + using Params = SinglePrefillParams; + using AttentionVariant = DefaultAttention< + /*use_custom_mask=*/(MASK_MODE == MaskMode::kCustom), + /*use_sliding_window=*/false, + /*use_logits_soft_cap=*/true, /*use_alibi=*/false>; + Params params(q, k, v, /*custom_mask=*/nullptr, o, lse, + /*alibi_slopes=*/nullptr, num_qo_heads, num_kv_heads, + qo_len, kv_len, qo_stride_n, qo_stride_h, kv_stride_n, + kv_stride_h, head_dim, + /*window_left=*/-1, + /*logits_soft_cap=*/8.f, sm_scale, rope_scale, + rope_theta, debug_thread_id, debug_warp_id); + return SinglePrefillWithKVCacheDispatched< + HEAD_DIM, HEAD_DIM, POS_ENCODING_MODE, USE_FP16_QK_REDUCTION, + MASK_MODE, AttentionVariant, Params>(params, tmp, stream); + })})})}); + return hipSuccess; +} + +// template +// hipError_t +// ComputeQKStubCaller(DTypeQ *q, +// DTypeKV *k, +// DTypeKV *v, +// DTypeO *o, +// 
DTypeO *tmp, +// float *lse, +// float *qk_scores_output, +// uint32_t num_qo_heads, +// uint32_t num_kv_heads, +// uint32_t qo_len, +// uint32_t kv_len, +// uint32_t head_dim, +// bool causal = true, +// QKVLayout kv_layout = QKVLayout::kNHD, +// PosEncodingMode pos_encoding_mode = +// PosEncodingMode::kNone, bool use_fp16_qk_reduction = +// false, std::optional maybe_sm_scale = +// std::nullopt, float rope_scale = 1.f, float rope_theta = +// 1e4, hipStream_t stream = nullptr) +// { +// const float sm_scale = +// maybe_sm_scale.value_or(1.f / std::sqrt(float(head_dim))); +// const MaskMode mask_mode = causal ? MaskMode::kCausal : MaskMode::kNone; +// auto [qo_stride_n, qo_stride_h, kv_stride_n, kv_stride_h] = +// get_qkv_strides( +// kv_layout, kv_len, num_qo_heads, num_kv_heads, head_dim); +// DISPATCH_use_fp16_qk_reduction( +// static_cast(use_fp16_qk_reduction), USE_FP16_QK_REDUCTION, +// {DISPATCH_mask_mode( +// mask_mode, MASK_MODE, +// {DISPATCH_head_dim( +// head_dim, HEAD_DIM, +// {DISPATCH_pos_encoding_mode( +// pos_encoding_mode, POS_ENCODING_MODE, { +// using Params = +// SinglePrefillParams; +// using AttentionVariant = DefaultAttention< +// /*use_custom_mask=*/(MASK_MODE == +// MaskMode::kCustom), +// /*use_sliding_window=*/false, +// /*use_logits_soft_cap=*/false, +// /*use_alibi=*/false>; +// Params params(q, k, v, /*custom_mask=*/nullptr, o, +// lse, +// /*alibi_slopes=*/nullptr, num_qo_heads, +// num_kv_heads, qo_len, kv_len, +// qo_stride_n, qo_stride_h, kv_stride_n, +// kv_stride_h, head_dim, +// /*window_left=*/-1, +// /*logits_soft_cap=*/0.f, sm_scale, +// rope_scale, rope_theta); +// return ComputeQKStubDispatched< +// HEAD_DIM, HEAD_DIM, POS_ENCODING_MODE, +// USE_FP16_QK_REDUCTION, MASK_MODE, +// AttentionVariant, Params>(params, tmp, +// qk_scores_output, stream); +// })})})}); +// return hipSuccess; +// } + +} // namespace flashinfer diff --git a/libflashinfer/utils/utils_hip.h b/libflashinfer/utils/utils_hip.h index 
bed4b1d533..de6b6956c8 100644 --- a/libflashinfer/utils/utils_hip.h +++ b/libflashinfer/utils/utils_hip.h @@ -51,10 +51,41 @@ namespace utils { +enum Predicate { + Linear, + Ones, + Zeros, +}; + +template +void generate_data(std::vector& vec) { + if constexpr (Pred == Predicate::Linear) { + assert(vec.size() > 0); + for (int i = 0; i < vec.size(); i++) { + vec[i] = fi::con::explicit_casting(static_cast(i)); + } + } + + else if constexpr (Pred == Predicate::Ones) { + vec_fill_(vec, fi::con::explicit_casting(1.0f)); + } + + else if constexpr (Pred == Predicate::Zeros) { + vec_zero_(vec); + } +} + +template +void vec_lexicographic_(std::vector& vec) { + for (int i = 0; i < vec.size(); i++) { + vec[i] = fi::con::explicit_casting(static_cast(i)); + } +} + template void vec_normal_(std::vector& vec, float mean = 0.f, float std = 1.f) { std::random_device rd{}; - std::mt19937 gen{rd()}; + std::mt19937 gen{1234}; std::normal_distribution d{mean, std}; for (size_t i = 0; i < vec.size(); ++i) { float value = static_cast(d(gen)); @@ -65,7 +96,7 @@ void vec_normal_(std::vector& vec, float mean = 0.f, float std = 1.f) { template void vec_uniform_(std::vector& vec, float a = 0.f, float b = 1.f) { std::random_device rd{}; - std::mt19937 gen{rd()}; + std::mt19937 gen{1234}; std::uniform_real_distribution d{a, b}; for (size_t i = 0; i < vec.size(); ++i) { float value = static_cast(d(gen)); @@ -86,7 +117,7 @@ void vec_fill_(std::vector& vec, T val) { template void vec_randint_(std::vector& vec, int low, int high) { std::random_device rd{}; - std::mt19937 gen{rd()}; + std::mt19937 gen{1234}; std::uniform_int_distribution d{low, high}; for (size_t i = 0; i < vec.size(); ++i) { float value = static_cast(d(gen)); diff --git a/validate_online_softmax_stateful.py b/validate_online_softmax_stateful.py new file mode 100755 index 0000000000..be282d6a7e --- /dev/null +++ b/validate_online_softmax_stateful.py @@ -0,0 +1,289 @@ +#!/usr/bin/env python3 +""" +Stateful validation of online softmax 
# (module docstring and usage notes continue from the file header above)

import argparse
import re
import sys
from pathlib import Path

import numpy as np

# ============================================================================
# PARSING FUNCTIONS
# ============================================================================


def parse_sm_scale(lines):
    """Extract the softmax scale factor from the log.

    Scans for the first line containing an ``sm_scale : <value>`` entry.

    Args:
        lines: Iterable of log-file lines.

    Returns:
        The parsed sm_scale as a float.

    Raises:
        ValueError: If no sm_scale entry is present.
    """
    for line in lines:
        if "sm_scale" in line:
            # fix: also accept scientific notation (e.g. "1.25e-01"), which
            # float formatting commonly emits for small scales.
            match = re.search(r"sm_scale\s*:\s*([\d.]+(?:[eE][+-]?\d+)?)", line)
            if match:
                return float(match.group(1))
    raise ValueError("Could not find sm_scale in log file")


def parse_matrix(lines, start_line, rows=128, cols=64):
    """Parse a rows×cols matrix dumped in the log starting at start_line.

    Reads numeric lines until a section delimiter is hit (a line containing
    "frag" or "DEBUG", or starting with "num_").

    Args:
        lines: Log-file lines.
        start_line: Index of the first data line.
        rows, cols: Expected matrix shape (defaults preserve the original
            hard-coded 128×64 behavior).

    Returns:
        A (rows, cols) numpy array, or None when no data is found or the
        element count does not match the expected shape.
    """
    data = []
    for i in range(start_line, len(lines)):
        line = lines[i]
        if "frag" in line or "DEBUG" in line or line.strip().startswith("num_"):
            break
        if line.strip():
            # fix: accept scientific notation in addition to plain decimals.
            nums = re.findall(r"-?\d+\.\d+(?:[eE][+-]?\d+)?", line)
            if nums:
                data.extend([float(x) for x in nums])

    if len(data) == 0:
        return None

    expected_size = rows * cols
    if len(data) != expected_size:
        print(f"Warning: Expected {expected_size} values, got {len(data)}")
        return None

    return np.array(data).reshape(rows, cols)


def find_iteration_data(lines, iteration_num):
    """Find the before/after S-fragment matrices for one iteration.

    The data for each marker starts two lines after the marker line.

    Returns:
        (before_matrix, after_matrix) as numpy arrays, or (None, None) when
        either marker is missing or its matrix fails to parse.
    """
    # fix: break as soon as the requested marker is found (the original
    # "before" scan kept iterating over the whole log, unlike its "after" twin).
    before_line = None
    iter_count = 0
    for i, line in enumerate(lines):
        if "S frag before update_mdo for iteration" in line:
            if iter_count == iteration_num:
                before_line = i + 2
                break
            iter_count += 1

    after_line = None
    iter_count = 0
    for i, line in enumerate(lines):
        if "S frag after update_mdo for iteration" in line:
            if iter_count == iteration_num:
                after_line = i + 2
                break
            iter_count += 1

    if before_line is None or after_line is None:
        return None, None

    before = parse_matrix(lines, before_line)
    after = parse_matrix(lines, after_line)

    return before, after


def validate_with_state(before_row, after_row, m_prev, sm_scale, tolerance=5e-3):
    """Validate the online-softmax transform of one row with a running max.

    Checks ``after = exp2((before - m_new) * sm_scale)`` where
    ``m_new = max(m_prev, max(before))``.

    Args:
        before_row: Raw scores for this chunk (1-D array).
        after_row: Transformed scores (1-D array).
        m_prev: Running maximum carried over from previous chunks.
        sm_scale: Softmax scale factor.
        tolerance: Max-abs-error threshold (default preserves the original 5e-3).

    Returns:
        (is_valid, m_new, max_error)
    """
    # Update the running maximum with this chunk's maximum.
    m_chunk = before_row.max()
    m_new = max(m_prev, m_chunk)

    # Expected kernel output: base-2 exponential with the scale folded in.
    expected = np.exp2((before_row - m_new) * sm_scale)

    max_error = np.abs(after_row - expected).max()
    is_valid = max_error < tolerance

    return is_valid, m_new, max_error


# ============================================================================
# VALIDATION ORCHESTRATION
# ============================================================================
def validate_all_iterations(lines, sm_scale, num_iterations=2):
    """
    Validate all iterations with proper state management.

    The running row-wise maximum ``m`` is carried across KV chunks, mirroring
    the online-softmax recurrence in the kernel.

    Returns:
        True when every row of every iteration validates, False otherwise.
        (fix: the original docstring claimed a tuple was returned.)
    """
    print(f"{'='*80}")
    print(f"STATEFUL ONLINE SOFTMAX VALIDATION")
    print(f"{'='*80}")
    print(f"sm_scale: {sm_scale}")
    print(f"Formula: s_after = exp2((s_before - m_new) * sm_scale)")
    print(f" where m_new = max(m_prev, max(s_before_chunk))")
    print(f"{'='*80}\n")

    total_passed = 0
    total_rows = 0
    max_error_overall = 0.0

    # fix: size the per-row running maximum from the first parsed matrix
    # instead of hard-coding 128 rows.
    m_prev_per_row = None

    for iteration in range(num_iterations):
        print(f"\n{'─'*80}")
        print(f"ITERATION {iteration}")
        print(f"{'─'*80}")

        before, after = find_iteration_data(lines, iteration)

        if before is None or after is None:
            print(f"❌ Could not find data for iteration {iteration}")
            continue

        num_rows = before.shape[0]
        if m_prev_per_row is None:
            # Initialize to -inf so the first chunk's maximum wins.
            m_prev_per_row = np.full(num_rows, -np.inf)

        print(f"Matrix shape: {before.shape}")

        passed = 0
        failed = 0
        max_error_iter = 0.0

        # Each row carries its own running maximum across iterations.
        for row_idx in range(num_rows):
            before_row = before[row_idx, :]
            after_row = after[row_idx, :]
            m_prev = m_prev_per_row[row_idx]

            is_valid, m_new, max_error = validate_with_state(
                before_row, after_row, m_prev, sm_scale
            )

            m_prev_per_row[row_idx] = m_new

            max_error_iter = max(max_error_iter, max_error)
            max_error_overall = max(max_error_overall, max_error)

            if is_valid:
                passed += 1
            else:
                failed += 1
                if failed <= 3:  # Show first 3 failures
                    print(
                        f" ❌ Row {row_idx}: m_prev={m_prev:.6f}, "
                        f"m_chunk={before_row.max():.6f}, "
                        f"m_new={m_new:.6f}, max_error={max_error:.6e}"
                    )

        total_passed += passed
        total_rows += num_rows

        print(f"\nIteration {iteration} Results:")
        print(f" ✓ Passed: {passed}/{num_rows} rows")
        print(f" ✗ Failed: {failed}/{num_rows} rows")
        print(f" 📊 Max error: {max_error_iter:.6e}")

        if failed == 0:
            print(f" 🎉 ITERATION {iteration} VALIDATED SUCCESSFULLY!")

        # Debug: Show sample row state
        sample_row = 0
        print(f"\n Sample (Row {sample_row}):")
        print(f" m_prev: {m_prev_per_row[sample_row]:.6f}")
        print(f" m_chunk: {before[sample_row, :].max():.6f}")
        print(f" m_new: {m_prev_per_row[sample_row]:.6f}")

    print(f"\n{'='*80}")
    print(f"OVERALL RESULTS")
    print(f"{'='*80}")

    # fix: guard against division by zero when no iteration data was parsed.
    if total_rows == 0:
        print("⚠️ No iteration data found in log; nothing validated")
        return False

    print(
        f"Total rows validated: {total_passed}/{total_rows} ({100*total_passed/total_rows:.1f}%)"
    )
    print(f"Max error across all iterations: {max_error_overall:.6e}")

    if total_passed == total_rows:
        print(f"\n🎉 ALL ROWS VALIDATED SUCCESSFULLY!")
        print(f"✅ Online softmax is correctly implemented with stateful m propagation")
        return True
    else:
        print(f"\n⚠️ VALIDATION INCOMPLETE: {total_rows - total_passed} rows failed")
        return False


# ============================================================================
# MAIN ENTRY POINT
# ============================================================================


def main():
    """CLI entry point: parse arguments, read the log, run validation."""
    parser = argparse.ArgumentParser(
        description="Validate online softmax with stateful m propagation",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  %(prog)s prefill.log        Validate using prefill.log
  %(prog)s my_debug.log       Validate using my_debug.log
  %(prog)s                    Validate using prefill.log (default)
""",
    )
    parser.add_argument(
        "logfile",
        nargs="?",
        default="prefill.log",
        help="Path to log file (default: prefill.log)",
    )
    parser.add_argument(
        "-n",
        "--num-iterations",
        type=int,
        default=2,
        help="Number of iterations to validate (default: 2)",
    )

    args = parser.parse_args()

    logfile = Path(args.logfile)

    if not logfile.exists():
        print(f"❌ Error: Log file '{logfile}' not found")
        sys.exit(1)

    print(f"Reading log file: {logfile}")

    with open(logfile, "r") as f:
        lines = f.readlines()

    print(f"Loaded {len(lines)} lines from {logfile}\n")

    try:
        sm_scale = parse_sm_scale(lines)
    except ValueError as e:
        print(f"❌ Error: {e}")
        sys.exit(1)

    success = validate_all_iterations(lines, sm_scale, args.num_iterations)

    # Exit code 0 on full validation, 1 otherwise (shell-friendly).
    sys.exit(0 if success else 1)


if __name__ == "__main__":
    main()