diff --git a/applications/flash_attention_v2/collective/fmha_fusion.hpp b/applications/flash_attention_v2/collective/fmha_fusion.hpp index a87752588f..d943228538 100644 --- a/applications/flash_attention_v2/collective/fmha_fusion.hpp +++ b/applications/flash_attention_v2/collective/fmha_fusion.hpp @@ -1,5 +1,6 @@ /*************************************************************************************************** -* Copyright (c) 2025 - 2025 Codeplay Software Ltd. All rights reserved. + * Copyright (c) 2025 - 2025 Codeplay Software Ltd. All rights reserved. + * Copyright (C) 2025 Intel Corporation, All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -39,6 +40,7 @@ using namespace cute; struct VariableLength { int max_length; + int total_length = 0; int* cumulative_length = nullptr; CUTE_HOST_DEVICE operator int() const { diff --git a/applications/flash_attention_v2/collective/xe_flash_attn_chunk_prefill_epilogue.hpp b/applications/flash_attention_v2/collective/xe_flash_attn_chunk_prefill_epilogue.hpp new file mode 100644 index 0000000000..68566f9f58 --- /dev/null +++ b/applications/flash_attention_v2/collective/xe_flash_attn_chunk_prefill_epilogue.hpp @@ -0,0 +1,264 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by epilogues. 
+*/ + +#pragma once + +#include +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_epilogue.hpp" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/detail/layout.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// +namespace cutlass { +namespace flash_attention { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template class FlashChunkPrefillEpilogue { + static_assert(cutlass::detail::dependent_false, "Could not find an epilogue specialization."); +}; + +template +class FlashChunkPrefillEpilogue { +public: + // + // Type Aliases + // + static constexpr bool Sink = Sink_; + using DispatchPolicy = epilogue::IntelXeXMX16; + using ElementO = ElementO_; + using StrideO = StrideO_; + using ElementLSE = ElementLSE_; + using ElementSink = ElementSink_; + using CopyOpO = CopyOpO_; + using SubgroupLayout = SubgroupLayout_; + using TileShapeOutput = TileShapeOutput_; + using TiledMmaOutput = typename TiledMMAHelper, Layout, SubgroupLayout>::TiledMMA; + using GmemTiledCopyO = CopyOpO; + using ElementOutput = ElementO_; + using ElementCompute = ElementCompute_; + using ElementAccumulator = ElementCompute_; + using SubgroupTileShape = decltype(cute::shape_div(TileShapeOutput{}, (SubgroupLayout{}.shape()))); + + static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; + + static_assert(cute::rank(TileShapeOutput{}) == 3, "TileShapeOutput must be rank-3: [CTA_M_QO, CTA_N_VO, CTA_K_PV]"); + static_assert(cute::rank(StrideO{}) == 3, "StrideO must be rank-3: [seq_len_qo, head_size_vo, batch * num_heads]"); + + using CopyThreadShape = Shape<_1, Int>; + + using traits_store_O = Copy_Traits; + using atom_load_O = Copy_Atom; + using val_layout_load_O = decltype(make_layout(shape_div(typename traits_store_O::BlockShape{}, CopyThreadShape{}))); + using XE_Copy_O = decltype(make_tiled_copy(atom_load_O{}, Layout{}, val_layout_load_O{})); + +private: + constexpr static bool is_destination_supported = not cute::is_void_v; + +public: + using EmptyType = cute::tuple<>; + + struct TensorStorageImpl : cute::tuple {}; + + struct SharedStorage { + using TensorStorage = TensorStorageImpl; + + TensorStorage tensors; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + + // Host side epilogue arguments + struct Arguments { + ElementO const *ptr_O; + StrideO dO; + ElementSink const* ptr_sink; + }; + + // Device side epilogue params + struct Params { + XE_Copy_O xe_store_o; + ElementSink const* ptr_sink; + }; + + // + // Methods + // + template + CUTLASS_DEVICE auto convert_type(Tensor const &tensor) { + using From_type = typename Engine::value_type; + constexpr int numel = decltype(size(tensor))::value; + cutlass::NumericArrayConverter convert_op; + auto frag = + convert_op(*reinterpret_cast *>( + tensor.data())); + return make_tensor(make_rmem_ptr(&frag), tensor.layout()); + } + + template + static constexpr Params to_underlying_arguments(ProblemShape const &problem_shape, Arguments const &args, + [[maybe_unused]] void *workspace) { + auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache, head_size_qk, head_size_vo] = problem_shape; + auto tensorO = make_tensor(make_gmem_ptr(static_cast(args.ptr_O)), + make_layout(make_shape(seq_len_qo, num_heads_q * head_size_vo, batch), + args.dO)); + XE_Copy_O xe_store_o{XE_Copy_O{}.with(tensorO)}; + return { + 
xe_store_o, args.ptr_sink + }; + } + + template + static size_t get_workspace_size(ProblemShape const &problem_shape, Arguments const &args) { + return 0; + } + + template + static cutlass::Status initialize_workspace(ProblemShape const &problem_shape, Arguments const &args, void *workspace, + cudaStream_t stream, CudaHostAdapter *cuda_adapter = nullptr) { + return Status::kSuccess; + } + + template + CUTLASS_HOST_DEVICE static bool can_implement(ProblemShape const &problem_shape, + [[maybe_unused]] Arguments const &args) { + return true; + } + + CUTLASS_HOST_DEVICE + FlashChunkPrefillEpilogue(Params const ¶ms_, TensorStorage const &) : params(params_) {} + + template + CUTLASS_DEVICE void operator()(ProblemShape problem_shape, SequenceLengthShape sequence_length_shape, TileCoord tile_coord, FragOut &out, + FragMax const& max, FragSum &sum, [[maybe_unused]] FragSink const& sink) { + + using namespace cute; + + static constexpr bool is_var_len = cutlass::fmha::collective::is_variable_length_v>; + + using FragOutLayout = typename FragOut::layout_type; + + constexpr int Vec = shape<0>(FragOutLayout{}); + constexpr int FragsM = shape<1>(FragOutLayout{}); + constexpr int FragsN = size(select<2,3>(shape(FragOutLayout{}))); + + auto sg = compat::get_nd_item<1>().get_sub_group(); + auto out_reg = make_tensor(static_cast(out).data() , Shape, Int, Int>{}); + + CUTLASS_PRAGMA_UNROLL + for (int y = 0; y < FragsM; y++) { + CUTLASS_PRAGMA_UNROLL + for (int x = 0; x < Vec; x++) { + int index = y * Vec + x; + auto cur_sum = reduce_over_group(sg, sum(index), sycl::plus<>()); + if constexpr (Sink) { + constexpr double kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E + auto max_scale_bcast = group_broadcast(sg, max, index); + cur_sum += sycl::native::exp2(static_cast(sink * kLog2e) - max_scale_bcast); + } + auto cur_scale = (cur_sum == 0.f || cur_sum != cur_sum) ? 1.0f : sycl::native::recip(cur_sum); + CUTLASS_PRAGMA_UNROLL + for (int z = 0; z < FragsN; z++) { + out_reg(x, y, z) *= cur_scale; + } + } + } + + // Indexing variables + auto [batch, num_heads_q, num_heads_kv, head_size_vo] = select<0, 1, 2, 7>(problem_shape); + auto [seq_len_qo] = select<0>(sequence_length_shape); + // Represent the full output tensor + Tensor mO_mnl = cute::get_xe_tensor(make_shape(seq_len_qo, head_size_vo, 1)); + + auto [m_coord, n_coord, k_coord, l_coord] = tile_coord; + // Tile the output tensor per WG + Tensor g_wg_O = local_tile(mO_mnl, select<0,1>(TileShapeOutput{}), make_coord(m_coord,n_coord,0)); // (BLK_M,BLK_N,m,n,l) + static constexpr auto ATOM_N = get<2>(typename TiledMmaOutput::ThrLayoutVMNK{}.shape()); + auto m_sg = get_sub_group_id() / ATOM_N; + auto n_sg = get_sub_group_id() % ATOM_N; + // Tile the output tensor per SG + Tensor gO = local_tile(g_wg_O, SubgroupTileShape{}, make_coord(m_sg,n_sg,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + auto thread_xe_store_o = params.xe_store_o.get_thread_slice(ThreadIdxX()); + Tensor tOgO = thread_xe_store_o.partition_D(gO); + + Tensor final_out_reg = make_fragment_like(out_reg); + // iff ElementOutput == ElementAccumulator, then convert_type doesn't do the right conversion + // so we call copy() which internally performs a static_cast op on the data. + // for ElementOutput == bf16 | fp16, convert_type calls relevant NumericConverter specialization. 
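+    // A minimal illustration of the two paths (assuming ElementAccumulator is float):
+    //   ElementOutput == float       -> copy(out_reg, final_out_reg); the element-wise static_cast is
+    //                                   effectively a pass-through for the same type.
+    //   ElementOutput == bf16 / fp16 -> convert_type() packs the whole fragment through
+    //                                   NumericArrayConverter<ElementOutput, ElementAccumulator, numel>
+    //                                   in one shot and re-wraps the converted array as a register
+    //                                   tensor with out_reg's layout.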
+ if constexpr (cute::is_same_v) { + copy(out_reg, final_out_reg); + } else { + Tensor temp = convert_type(out_reg); + copy(temp, final_out_reg); + } + copy(params.xe_store_o, final_out_reg, tOgO); + } + + // SequenceLengthShapeType = Shape + // For Fixed Sequence Length, ProblemShapeType = Shape + // For Variable Sequence Length, ProblemShapeType = Shape + template + CUTLASS_DEVICE static constexpr Params get_updated_copies(Params const& params, ProblemShapeType const& problem_shape, + SequenceLengthShapeType const& sequence_length_shape, int const& l_coord, int const& q_head_coord) { + auto [num_heads_q, num_heads_kv, head_size_vo] = select<1, 2, 7>(problem_shape); + auto [seq_len_qo] = select<0>(sequence_length_shape); + int offset_o = 0; + if constexpr (VarLen) { + auto qo_cumulative_length = get<3>(problem_shape).cumulative_length; + offset_o = num_heads_q * head_size_vo * qo_cumulative_length[l_coord] + q_head_coord * head_size_vo; + } else { + offset_o = num_heads_q * head_size_vo * seq_len_qo * l_coord + q_head_coord * head_size_vo; + } + auto store_traits = static_cast(params.xe_store_o); + ElementO* base_ptr = (ElementO*)store_traits.base_ptr; + auto shape_o = make_shape(static_cast(seq_len_qo), num_heads_q * head_size_vo, 1); + StrideO stride_o = cutlass::make_cute_packed_stride(StrideO{}, shape_o); + auto tensorO = make_tensor(make_gmem_ptr(base_ptr + offset_o), make_layout(shape_o, stride_o)); + XE_Copy_O xe_store_o{XE_Copy_O{}.with(tensorO)}; + return Params{xe_store_o, params.ptr_sink}; + } + +private: + Params const ¶ms; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace flash_attention +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/applications/flash_attention_v2/collective/xe_flash_attn_chunk_prefill_mma.hpp b/applications/flash_attention_v2/collective/xe_flash_attn_chunk_prefill_mma.hpp new file mode 100644 index 0000000000..f3d76fdb41 --- /dev/null +++ b/applications/flash_attention_v2/collective/xe_flash_attn_chunk_prefill_mma.hpp @@ -0,0 +1,541 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "fmha_fusion.hpp" + + +//////////////////////////////////////////////////////////// +namespace { + +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::flash_attention::collective { +using namespace cute; +//////////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct FlashChunkPrefillMma { + static_assert(cutlass::detail::dependent_false, + "Could not find a mainloop specialization."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct FlashChunkPrefillMma< + gemm::MainloopIntelXeXMX16, ProblemShapeType_, ElementQ_, StrideQ_, + ElementK_, StrideK_, ElementV_, StrideV_, MMAOperation_, TileShapeQK_, + TileShapePV_, SubgroupLayout_, GmemTiledCopyQ_, GmemTiledCopyK_, + GmemTiledCopyV_, CausalMask_, LocalMask_, PagedKV_> { + // + // Type Aliases + // + using DispatchPolicy = gemm::MainloopIntelXeXMX16; + using TileShapeQK = TileShapeQK_; + using TileShapePV = TileShapePV_; + using SubgroupLayout = SubgroupLayout_; + using ProblemShapeType = ProblemShapeType_; + using ElementQ = ElementQ_; + using StrideQ = StrideQ_; + using ElementK = ElementK_; + using StrideK = StrideK_; + using ElementV = ElementV_; + using StrideV = StrideV_; + using GmemTiledCopyQ = GmemTiledCopyQ_; + using GmemTiledCopyK = GmemTiledCopyK_; + using GmemTiledCopyV = GmemTiledCopyV_; + using ArchTag = typename DispatchPolicy::ArchTag; + using MmaAtom = MMA_Atom; + + using TiledMmaQK = typename TiledMMAHelper, + SubgroupLayout>::TiledMMA; + + using TiledMmaPV = typename TiledMMAHelper, + SubgroupLayout>::TiledMMA; + using ElementAccumulator = typename TiledMmaQK::ValTypeC; + static constexpr bool CausalMask = CausalMask_; + static constexpr bool LocalMask = LocalMask_; + static constexpr bool PagedKV = PagedKV_; + + static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize; + + using MmaAtomShape = typename MmaAtom::Shape_MNK; + + static constexpr auto PV_ATOM_M = + decltype(get<0>(SubgroupLayout{}.shape()))::value; + static constexpr auto PV_ATOM_N = + decltype(get<1>(SubgroupLayout{}.shape()))::value; + static constexpr auto PV_ATOM_K = + decltype(get<2>(SubgroupLayout{}.shape()))::value; + + using SubgroupTileShapePV = + decltype(cute::shape_div(TileShapePV{}, (SubgroupLayout{}.shape()))); + static constexpr auto QK_BLK_M = get<0>(TileShapeQK{}); + static constexpr auto QK_BLK_N = get<1>(TileShapeQK{}); + static constexpr auto QK_BLK_K = get<2>(TileShapeQK{}); + + // 
This TiledMma is only required to serve the specific tiling requirements + // for matrix K. This is due to the consumption of matrix K by all subgroups + // within a workgroup. + static constexpr auto QK_ATOM_M = PV_ATOM_M; // 8 + static constexpr auto QK_ATOM_N = PV_ATOM_N; // 1 + static constexpr auto QK_ATOM_K = PV_ATOM_K; // 1 + + using SubgroupTileShapeQK = decltype(cute::shape_div( + TileShapeQK{}, + SubgroupLayout{}.shape())); // 128, 64, 32 / 16, 1, 1 = (8, 64, 32 ) + + static constexpr auto QK_SG_M = get<0>(SubgroupTileShapeQK{}); + static constexpr auto QK_SG_N = get<1>(SubgroupTileShapeQK{}); + static constexpr auto QK_SG_K = get<2>(SubgroupTileShapeQK{}); + + static constexpr bool is_var_len = + cutlass::fmha::collective::is_variable_length_v< + tuple_element_t<3, ProblemShapeType>>; + + using FragsShapeS = decltype(cute::shape_div( + take<0, 2>(SubgroupTileShapeQK{}), + take<0, 2>(MmaAtomShape()))); // 8, 64, 32 / 8, 16, 16 (1, 4) + static constexpr int Vec = + (get<0>(MmaAtomShape()) * get<1>(MmaAtomShape())) / SubgroupSize; // 8 + static constexpr int FragsM = get<0>(FragsShapeS{}); + static constexpr int FragsNS = get<1>(FragsShapeS{}); // 4 + + static constexpr uint32_t MaxThreadsPerBlock = + size(SubgroupLayout{}) * SubgroupSize; + using CopyThreadShape = Shape<_1, Int>; + + using traits_load_Q = Copy_Traits; + using atom_load_Q = Copy_Atom; + using val_layout_load_Q = decltype(make_layout( + shape_div(typename traits_load_Q::BlockShape{}, CopyThreadShape{}))); + using XE_Copy_Q = decltype(make_tiled_copy( + atom_load_Q{}, Layout{}, val_layout_load_Q{})); + + using traits_load_K = Copy_Traits; + using atom_load_K = Copy_Atom; + using val_layout_load_K = decltype(make_layout( + shape_div(typename traits_load_K::BlockShape{}, CopyThreadShape{}))); + using XE_Copy_K = decltype(make_tiled_copy( + atom_load_K{}, Layout{}, val_layout_load_K{})); + + using traits_load_V = Copy_Traits; + using atom_load_V = Copy_Atom; + using val_layout_load_V = decltype(make_layout( + shape_div(typename traits_load_V::BlockShape{}, CopyThreadShape{}))); + using XE_Copy_V = decltype(make_tiled_copy( + atom_load_V{}, Layout{}, val_layout_load_V{})); + + // Host side kernel arguments + struct Arguments { + ElementQ const *ptr_Q; + StrideQ dQ; + ElementK const *ptr_K; + StrideK dK; + ElementV const *ptr_V; + StrideV dV; + ElementK const *ptr_K_cache; + StrideK dK_cache; + ElementV const *ptr_V_cache; + StrideV dV_cache; + // Paged KV Cache + int const *ptr_page_table; + int page_size; + int const *num_pages_per_seq; + int window_left; + int window_right; + }; + + struct Params { + XE_Copy_Q gmem_tiled_copy_q; + XE_Copy_K gmem_tiled_copy_k; + XE_Copy_V gmem_tiled_copy_v; + XE_Copy_K gmem_tiled_copy_k_cache; + XE_Copy_V gmem_tiled_copy_v_cache; + // Paged KV Cache + int const *ptr_page_table; + int page_size; + int const *num_pages_per_seq; + int window_left; + int window_right; + }; + + // + // Methods + // + + FlashChunkPrefillMma() = default; + + static constexpr Params + to_underlying_arguments(ProblemShapeType const &problem_shape, + Arguments const &args, void *workspace) { + (void)workspace; + + auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, + seq_len_kv_cache, head_size_qk, head_size_vo] = problem_shape; + + auto tensorQ = make_tensor( + make_gmem_ptr(args.ptr_Q), + make_layout(make_shape(seq_len_qo, num_heads_q * head_size_qk, batch), + args.dQ)); + auto tensorK = make_tensor( + make_gmem_ptr(args.ptr_K), + make_layout(make_shape(seq_len_kv, num_heads_kv * head_size_qk, 
batch), + args.dK)); + auto tensorV = make_tensor( + make_gmem_ptr(args.ptr_V), + make_layout(make_shape(num_heads_kv * head_size_vo, seq_len_kv, batch), + args.dV)); + auto tensorK_cache = + make_tensor(make_gmem_ptr(args.ptr_K_cache), + make_layout(make_shape(seq_len_kv_cache, + num_heads_kv * head_size_qk, batch), + args.dK_cache)); + auto tensorV_cache = make_tensor( + make_gmem_ptr(args.ptr_V_cache), + make_layout( + make_shape(num_heads_kv * head_size_vo, seq_len_kv_cache, batch), + args.dV_cache)); + + XE_Copy_Q copyQ{XE_Copy_Q{}.with(tensorQ)}; + XE_Copy_K copyK{XE_Copy_K{}.with(tensorK)}; + XE_Copy_V copyV{XE_Copy_V{}.with(tensorV)}; + XE_Copy_K copyK_cache{XE_Copy_K{}.with(tensorK_cache)}; + XE_Copy_V copyV_cache{XE_Copy_V{}.with(tensorV_cache)}; + + return Params{copyQ, copyK, + copyV, copyK_cache, + copyV_cache, args.ptr_page_table, + args.page_size, args.num_pages_per_seq, + args.window_left, args.window_right}; + } + + template + CUTLASS_DEVICE void mmaQK(FragQccum &accum, TensorQ gQ, TensorK gK, + FragSrc const &frag_src, int const &k_tile_count, + Params const ¶ms, bool is_KV_cache) { + + auto &gmem_tiled_copy_k = + is_KV_cache ? params.gmem_tiled_copy_k_cache : params.gmem_tiled_copy_k; + + int thread_idx = static_cast(ThreadIdxX()); + auto thr_copy_Q = params.gmem_tiled_copy_q.get_slice(thread_idx); + auto thr_copy_K = gmem_tiled_copy_k.get_slice(thread_idx); + // Instantiate the MMA object + TiledMmaQK tiled_mma; + // To make all threads in a warp have the same global tensors pass in the + // index of thread 0 in each warp + auto sg = compat::get_nd_item<1>().get_sub_group(); + auto first_thread_in_sg_idx = + sg.get_group_id()[0] * DispatchPolicy::SubgroupSize; + auto thread_mma_q = tiled_mma.get_slice(first_thread_in_sg_idx); + auto thread_mma_k = tiled_mma.get_slice(0); + + Tensor tCgQ = thread_mma_q.partition_A(gQ); + Tensor tCgK = thread_mma_k.partition_B(gK); + + // Create fragments + // TODO(Codeplay): fix this, this is probably not general + Tensor tCrQ = make_tensor(make_fragment_layout( + params.gmem_tiled_copy_q, take<0, 3>(tCgQ.shape()))); + Tensor tCrK = make_tensor( + make_fragment_layout(gmem_tiled_copy_k, take<0, 3>(tCgK.shape()))); + + // Retile registers for copies + Tensor tQrQ = thr_copy_Q.retile_D(tCrQ); + Tensor tKrK = thr_copy_K.retile_D(tCrK); + + // Retile global tile for copies + Tensor tQgQ = thr_copy_Q.retile_S(tCgQ); + Tensor tKgK = thr_copy_K.retile_S(tCgK); + + // + // Mainloop + // + + for (int k_tile = 0; k_tile < k_tile_count; ++k_tile) { + copy(params.gmem_tiled_copy_q, tQgQ(_, _, _, k_tile), tQrQ); + copy(gmem_tiled_copy_k, tKgK(_, _, _, k_tile), tKrK); + cute::gemm(tiled_mma, accum, tCrQ, tCrK, frag_src); +#if 0 +#define PRINT(x) \ + print(#x ": "); \ + print(x); \ + print("\n"); + if (cute::thread(0, 0)) { + print("======================= Q: \n"); + PRINT(gQ); + PRINT(tCrQ); + PRINT(tCgQ); + PRINT(tQrQ); + PRINT(tQgQ); + + print("===================== K :\n"); + PRINT(gK); + PRINT(tCrK); + PRINT(tCgK); + PRINT(tKrK); + PRINT(tKgK); + + print("===================== Config: \n"); + PRINT(MaxThreadsPerBlock); + PRINT(SubgroupTileShapeQK{}); + } +#undef PRINT +#endif + } + } + + template + CUTLASS_DEVICE auto convert_type(Tensor const &tensor) { + using From_type = typename Engine::value_type; + constexpr int numel = decltype(size(tensor))::value; + cutlass::NumericArrayConverter convert_op; + auto frag = + convert_op(*reinterpret_cast *>( + tensor.data())); + return make_tensor(make_rmem_ptr(&frag), tensor.layout()); + } + + template + 
CUTLASS_DEVICE void mmaPV(FragQccum &accum, FragS const &tSr, TensorV gV, + FragSrc const &frag_src, Params const ¶ms, + bool is_KV_cache) { + + auto &gmem_tiled_copy_v = + is_KV_cache ? params.gmem_tiled_copy_v_cache : params.gmem_tiled_copy_v; + + int thread_idx = static_cast(ThreadIdxX()); + // Instantiate the MMA object + TiledMmaPV tiled_mma; + // Tile GV to the shape of <64,64> and loop over the HeadSize/64 to avoid + // Register spill + Tensor gV_ = take<0, 3>( + local_tile(gV, select<1, 2>(TileShapePV{}), make_coord(_, _))); + auto sg = compat::get_nd_item<1>().get_sub_group(); + auto first_thread_in_sg_idx = + sg.get_group_id()[0] * DispatchPolicy::SubgroupSize; + auto thread_mma = tiled_mma.get_slice(first_thread_in_sg_idx); + Tensor tCgV = thread_mma.partition_B(gV_); + Tensor tCrV = make_tensor( + make_fragment_layout(gmem_tiled_copy_v, take<0, 3>(tCgV.shape()))); + + // Partition the copying of A and B tiles across the threads + auto gmem_thr_copy_V = gmem_tiled_copy_v.get_slice(thread_idx); + Tensor tVrV = gmem_thr_copy_V.retile_D(tCrV); + Tensor tVgV = gmem_thr_copy_V.retile_S(tCgV); + +#if CUTLASS_ENABLE_DEBUG_PRINTS +#define PRINT(x) \ + print(#x ": "); \ + print(x); \ + print("\n"); + if (cute::thread(LOG_THREAD, LOG_GROUP)) { + print("===================== V :\n"); + PRINT(gV); + PRINT(tCrV); + PRINT(tCgV); + PRINT(tVrV); + PRINT(tVgV); + + print("===================== Config: \n"); + PRINT(MaxThreadsPerBlock); + PRINT(SubgroupTileShapePV{}); + } +#undef PRINT +#endif + + // 7) Convert S to P (FP32 -> BF16) + Tensor tPr = convert_type(tSr); + // + // Mainloop + // + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < tile_count; i++) { + copy(gmem_tiled_copy_v, tVgV(_, _, _, i), tVrV); + cute::gemm(tiled_mma, accum(_, _, _, i), tPr, tCrV, frag_src(_, _, _, i)); + } + } + + // SequenceLengthShape = Shape + // For Fixed Sequence Length, ProblemShape = Shape For Variable Sequence Length, ProblemShape = Shape + template + CUTLASS_DEVICE static constexpr Params + get_updated_copies(Params const ¶ms, ProblemShape const &problem_shape, + SequenceLengthShape const &sequence_length_shape, + int const &l_coord, int const &q_head_coord = 0) { + auto [batch, num_heads_q, num_heads_kv, head_size_qk, head_size_vo] = + select<0, 1, 2, 6, 7>(problem_shape); + auto [seq_len_qo, seq_len_kv, seq_len_kv_cache] = sequence_length_shape; + auto q_group_size = num_heads_q / num_heads_kv; + auto kv_head_coord = q_head_coord / q_group_size; + int offset_q = 0, offset_k = 0, offset_v = 0, offset_k_cache = 0, + offset_v_cache = 0; + int total_seq_len_kv_cache = 0; + if constexpr (is_var_len) { + auto qo_cumulative_length = get<3>(problem_shape).cumulative_length; + auto kv_cumulative_length = get<4>(problem_shape).cumulative_length; + auto kv_cached_cumulative_length = + get<5>(problem_shape).cumulative_length; + + offset_q = num_heads_q * head_size_qk * qo_cumulative_length[l_coord] + + q_head_coord * head_size_qk; + + offset_k = num_heads_kv * head_size_qk * kv_cumulative_length[l_coord] + + kv_head_coord * head_size_qk; + offset_v = num_heads_kv * head_size_vo * kv_cumulative_length[l_coord] + + kv_head_coord * head_size_vo; + offset_k_cache = seq_len_kv_cache == 0 + ? 0 + : PagedKV? // For page_kv, there is no batch dimension. + kv_head_coord * head_size_qk + : num_heads_kv * head_size_qk * kv_cached_cumulative_length[l_coord] + kv_head_coord * head_size_qk; + offset_v_cache = seq_len_kv_cache == 0 + ? 0 + : PagedKV? // For page_kv, there is no batch dimension. 
+ kv_head_coord * head_size_vo + : num_heads_kv * head_size_vo * kv_cached_cumulative_length[l_coord] + kv_head_coord * head_size_vo; + total_seq_len_kv_cache = get<5>(problem_shape).total_length; + } else { + offset_q = num_heads_q * head_size_qk * seq_len_qo * l_coord + + q_head_coord * head_size_qk; + + offset_k = num_heads_kv * head_size_qk * seq_len_kv * l_coord + + kv_head_coord * head_size_qk; + offset_v = num_heads_kv * head_size_vo * seq_len_kv * l_coord + + kv_head_coord * head_size_vo; + offset_k_cache = + seq_len_kv_cache == 0 + ? 0 : + PagedKV? + kv_head_coord * head_size_qk + : num_heads_kv * head_size_qk * seq_len_kv_cache * l_coord + kv_head_coord * head_size_qk; + offset_v_cache = + seq_len_kv_cache == 0 + ? 0 : + PagedKV? + kv_head_coord * head_size_vo + : num_heads_kv * head_size_vo * seq_len_kv_cache * l_coord + kv_head_coord * head_size_vo; + total_seq_len_kv_cache = batch * seq_len_kv_cache; + } + + auto q_traits = + static_cast(params.gmem_tiled_copy_q); + const ElementQ *q_ptr = (const ElementQ *)q_traits.base_ptr; + auto k_traits = + static_cast(params.gmem_tiled_copy_k); + const ElementK *k_ptr = (const ElementK *)k_traits.base_ptr; + auto v_traits = + static_cast(params.gmem_tiled_copy_v); + const ElementV *v_ptr = (const ElementV *)v_traits.base_ptr; + auto k_traits_cache = + static_cast(params.gmem_tiled_copy_k_cache); + const ElementK *k_cache_ptr = (const ElementK *)k_traits_cache.base_ptr; + auto v_traits_cache = + static_cast(params.gmem_tiled_copy_v_cache); + const ElementV *v_cache_ptr = (const ElementV *)v_traits_cache.base_ptr; + // NHD format{batch, seq_len, head, dim_head} + // stride {seq_len*head*dim_head, head*dim_head, dim_head, 1} + auto shape_q = + make_shape(static_cast(seq_len_qo), head_size_qk * num_heads_q, 1); + StrideQ stride_q = cutlass::make_cute_packed_stride(StrideQ{}, shape_q); + auto shape_k = make_shape(static_cast(seq_len_kv), + num_heads_kv * head_size_qk, 1); + StrideK stride_k = cutlass::make_cute_packed_stride(StrideK{}, shape_k); + + auto shape_v = make_shape(head_size_vo * num_heads_kv, + static_cast(seq_len_kv), 1); + StrideV stride_v = cutlass::make_cute_packed_stride(StrideV{}, shape_v); + + auto shape_k_cache = make_shape(static_cast(PagedKV? total_seq_len_kv_cache : seq_len_kv_cache), + head_size_qk * num_heads_kv, 1); + StrideK stride_k_cache = + cutlass::make_cute_packed_stride(StrideK{}, shape_k_cache); + auto shape_v_cache = make_shape(head_size_vo * num_heads_kv, + static_cast(PagedKV? 
total_seq_len_kv_cache : seq_len_kv_cache), 1); + StrideV stride_v_cache = + cutlass::make_cute_packed_stride(StrideV{}, shape_v_cache); + auto tensorQ = make_tensor(make_gmem_ptr(q_ptr + offset_q), + make_layout(shape_q, stride_q)); + auto tensorK = make_tensor(make_gmem_ptr(k_ptr + offset_k), + make_layout(shape_k, stride_k)); + auto tensorV = make_tensor(make_gmem_ptr(v_ptr + offset_v), + make_layout(shape_v, stride_v)); + auto tensorK_cache = + make_tensor(make_gmem_ptr(k_cache_ptr + offset_k_cache), + make_layout(shape_k_cache, stride_k_cache)); + auto tensorV_cache = + make_tensor(make_gmem_ptr(v_cache_ptr + offset_v_cache), + make_layout(shape_v_cache, stride_v_cache)); + XE_Copy_Q copyQ{XE_Copy_Q{}.with(tensorQ)}; + XE_Copy_K copyK{XE_Copy_K{}.with(tensorK)}; + XE_Copy_V copyV{XE_Copy_V{}.with(tensorV)}; + XE_Copy_K copyK_cache{XE_Copy_K{}.with(tensorK_cache)}; + XE_Copy_V copyV_cache{XE_Copy_V{}.with(tensorV_cache)}; + return Params{copyQ, + copyK, + copyV, + copyK_cache, + copyV_cache, + params.ptr_page_table, + params.page_size, + params.num_pages_per_seq, + params.window_left, + params.window_right}; + } +}; + +} // namespace cutlass::flash_attention::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/applications/flash_attention_v2/collective/xe_flash_attn_chunk_prefill_softmax_epilogue.hpp b/applications/flash_attention_v2/collective/xe_flash_attn_chunk_prefill_softmax_epilogue.hpp new file mode 100644 index 0000000000..cc0d99da46 --- /dev/null +++ b/applications/flash_attention_v2/collective/xe_flash_attn_chunk_prefill_softmax_epilogue.hpp @@ -0,0 +1,222 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Functor performing online softmax. +*/ + +#pragma once + +#include +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_epilogue.hpp" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/detail/layout.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace flash_attention { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template class FlashChunkPrefillSoftmaxEpilogue { + static_assert(cutlass::detail::dependent_false, "Could not find an epilogue specialization."); +}; + + +template +class FlashChunkPrefillSoftmaxEpilogue { +public: + + // + // Type Aliases + // + using DispatchPolicy = epilogue::IntelXeXMX16; + using Element = Element_; + + static constexpr bool CausalMask = CausalMask_; + static constexpr bool LocalMask = LocalMask_; + + using GmemTiledCopyOut = void; + + // Host side epilogue arguments + struct Arguments { + Element const scale; + }; + + // Device side epilogue params + using Params = Arguments; + + // + // Methods + // + + static constexpr Params to_underlying_arguments(Arguments const &args) { + constexpr double kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E + Element val = args.scale * static_cast(kLog2e); + return Params{val}; + } + + template + static size_t get_workspace_size() { + return 0; + } + + template + static cutlass::Status initialize_workspace() { + return Status::kSuccess; + } + + template + CUTLASS_HOST_DEVICE static bool can_implement() { + return true; + } + + CUTLASS_HOST_DEVICE + FlashChunkPrefillSoftmaxEpilogue(Params const ¶ms_) : params(params_) {} + + template + CUTLASS_DEVICE void scale_exp_log2(FragAcc &frag_s, FragMax const &max, FragSum &sum) { + auto g = compat::get_nd_item<1>().get_sub_group(); + const auto max_scale = max * params.scale; + CUTLASS_PRAGMA_UNROLL + for (int indx = 0; indx < Vec * FragsM; indx++) { + const auto max_scale_bcast = group_broadcast(g, max_scale, indx); + CUTLASS_PRAGMA_UNROLL + for (int z = 0; z < FragsN; z++) { + auto base_indx = indx + (z * Vec * FragsM); + if constexpr (LocalMask) { + if ((std::isinf(max_scale) && max_scale < 0) || + (std::isinf(frag_s(base_indx)) && frag_s(base_indx) < 0)) { + frag_s(base_indx) = 0.f; + // continue; + } else { + Element eq = frag_s(base_indx) - max_scale_bcast; + frag_s(base_indx) = sycl::native::exp2(eq); + } + } else { + Element eq = frag_s(base_indx) - max_scale_bcast; + frag_s(base_indx) = sycl::native::exp2(eq); + } + sum(indx) += frag_s(base_indx); + } + } + } + + template + CUTLASS_DEVICE void reduce_max(FragSrc &src, FragMax &max) { + auto sg = compat::get_nd_item<1>().get_sub_group(); + CUTLASS_PRAGMA_UNROLL + for (int indx = 0; indx < Vec * FragsM; indx++) { + auto maxptr = group_broadcast(sg, max, indx); + CUTLASS_PRAGMA_UNROLL + for (int z = 0; z < FragsN; z++) { + auto base_indx = indx + (z * Vec * FragsM); + maxptr = sycl::max(maxptr, src(base_indx)); + src(base_indx) *= params.scale; + } + maxptr = reduce_over_group(sg, maxptr, sycl::maximum<>()); + if (indx == sg.get_local_id()[0]) { + max = maxptr; + } + } + } + + template + CUTLASS_DEVICE void operator()(bool is_first, FragAcc &frag_s, FragMax &max, FragSum &sum, FragOut &out) { + auto max_prev = max; + using FragAccLayout = typename FragAcc::layout_type; + using FragOutLayout = typename FragOut::layout_type; + constexpr int Vec = 
get<0>(FragAccLayout{}.shape()); + constexpr int FragsM = get<1>(FragAccLayout{}.shape()); + constexpr int FragsNAcc = get<2>(FragAccLayout{}.shape()); + constexpr int FragsNOut = size(select<2,3>(FragOutLayout{}.shape())); + reduce_max(frag_s, max); + static_assert(Vec * FragsM % 8 == 0, " No. of attention rows per subgroup should be >= 1 MMA Atom worth of rows."); + if (!is_first) { + auto sg = compat::get_nd_item<1>().get_sub_group(); + Element max_scale{max * params.scale}; + Element exp_scale; + if constexpr (LocalMask) { + if ((std::isinf(max_scale) && max_scale < 0) || (std::isinf(max_prev) && max_prev < 0)) { + exp_scale = 0.f; + } else { + exp_scale = sycl::native::exp2(max_prev * params.scale - max_scale); + } + } else { + exp_scale = sycl::native::exp2(max_prev * params.scale - max_scale); + } + + CUTLASS_PRAGMA_UNROLL + for (int indx = 0; indx < Vec * FragsM; indx++) { + auto max_scale_bcast = group_broadcast(sg, max_scale, indx); + auto exp_scale_bcast = group_broadcast(sg, exp_scale, indx); + sum(indx) *= exp_scale_bcast; + CUTLASS_PRAGMA_UNROLL + for (int z = 0; z < FragsNAcc; z++) { + auto base_indx = indx + (z * Vec * FragsM); + if constexpr (LocalMask) { + if ((std::isinf(max_scale) && max_scale < 0) || + (std::isinf(frag_s(base_indx)) && frag_s(base_indx) < 0)) { + frag_s(base_indx) = 0.f; + // continue; + } else { + Element eq = frag_s(base_indx) - max_scale_bcast; + frag_s(base_indx) = sycl::native::exp2(eq); + } + } else { + Element eq = frag_s(base_indx) - max_scale_bcast; + frag_s(base_indx) = sycl::native::exp2(eq); + } + sum(indx) += frag_s(base_indx); + } + CUTLASS_PRAGMA_UNROLL + for (int z = 0; z < FragsNOut; z++) { + auto base_indx = indx + (z * Vec * FragsM); + out(base_indx) *= exp_scale_bcast; + } + } + } else { + scale_exp_log2(frag_s, max, sum); + } + } + Params params; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace flash_attention +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/applications/flash_attention_v2/kernel/tile_scheduler_chunk_prefill.hpp b/applications/flash_attention_v2/kernel/tile_scheduler_chunk_prefill.hpp new file mode 100644 index 0000000000..6d429d52bc --- /dev/null +++ b/applications/flash_attention_v2/kernel/tile_scheduler_chunk_prefill.hpp @@ -0,0 +1,238 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/kernel_hardware_info.h" + +namespace cutlass::flash_attention { + +namespace kernel { + +struct XeFlashIndividualTileScheduler { + + struct Params { + dim3 grid; + // FastDivmod divmod_num_heads; + }; + + bool valid_ = true; + Params params; + + CUTLASS_DEVICE + XeFlashIndividualTileScheduler(Params const ¶ms) : params(params) {} + + template + static Params to_underlying_arguments(ProblemSize const &problem_size, + KernelHardwareInfo hw_info, + TileShape const &tile_shape) { + using namespace cute; + // problem_size = [batch, num_heads_q , num_heads_kv, seq_len_qo, + // seq_len_kv, seq_len_kv_cache, head_size_qk, head_size_vo] + + // dim3 grid(size(ceil_div(shape<7>(problem_size), shape<1>(tile_shape))), + // size(ceil_div(shape<3>(problem_size), shape<0>(tile_shape))), + // size(shape<0>(problem_size) * shape<1>(problem_size))); + + int batch = size<0>(problem_size); + int num_heads_q = size<1>(problem_size); + int num_heads_kv = size<2>(problem_size); + int seq_len_qo = + size<3>(problem_size); // if varlen seq_len_qo = max_seq_len + int seq_len_kv = + size<4>(problem_size); // if varlen seq_len_qo = max_seq_len + int seq_len_kv_cache = size<5>(problem_size); + int head_size_qk = size<6>(problem_size); + int head_size_vo = size<7>(problem_size); + auto group_heads_q = num_heads_q / num_heads_kv; + + dim3 grid(size(ceil_div(shape<3>(problem_size), shape<0>(tile_shape))), + size(shape<1>(problem_size)), size(shape<0>(problem_size))); + return Params{grid}; + } + + + template static dim3 get_grid_shape(Params const ¶ms) { + return params.grid; + } + + CUTLASS_DEVICE + bool is_valid() { return valid_; } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + return make_coord(BlockIdxX(), BlockIdxY(), BlockIdxZ()); + } + + CUTLASS_DEVICE + XeFlashIndividualTileScheduler &operator++() { + valid_ = false; + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +struct XeFlashPersistentTileScheduler { + + struct Params { + int num_blocks; + FastDivmod divmod_seq_len_block; + FastDivmod divmod_head_size_block; + FastDivmod divmod_num_heads; + + KernelHardwareInfo hw_info; + }; + + int block_idx = 0; + Params params; + + CUTLASS_DEVICE + XeFlashPersistentTileScheduler(Params const ¶ms) + : block_idx(BlockIdxX()), params(params) {} + + template + static Params to_underlying_arguments(ProblemSize const &problem_size, + KernelHardwareInfo hw_info, + TileShape const &tile_shape) { + using namespace cute; + // Get SM count if needed, 
otherwise use user supplied SM count + int sm_count = hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST( + " WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments " + "KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count( + hw_info.device_id); + } + + CUTLASS_TRACE_HOST( + "to_underlying_arguments(): Setting persistent grid SM count to " + << sm_count); + hw_info.sm_count = sm_count; + + // problem_size = [batch, num_heads_q, numhead_kv, seq_len_qo, seq_len_kv, + // seq_len_kv_cache, head_size_qk, head_size_vo] + int num_head_size_blocks = + size(ceil_div(shape<7>(problem_size), shape<1>(tile_shape))); + int num_seq_len_blocks = + size(ceil_div(shape<3>(problem_size), shape<0>(tile_shape))); + int num_blocks = num_seq_len_blocks * num_head_size_blocks * + size(shape<0>(problem_size) * shape<1>(problem_size)); + + return Params{num_blocks, + {num_seq_len_blocks}, + {num_head_size_blocks}, + {shape<1>(problem_size)}, + hw_info}; + } + + template static dim3 get_grid_shape(Params const ¶ms) { + auto queue = compat::get_default_queue(); + auto dev = queue.get_device(); + const size_t maxSubgroups = + dev.template get_info(); + // TODO (Codeplay): revert this back to std::min(params.num_blocks, + // params.hw_info.sm_count) once performance issue is fixed. + dim3 grid( + std::min(params.num_blocks, + ceil_div(params.hw_info.sm_count * maxSubgroups, Num_SGs)), + 1, 1); + return grid; + } + + CUTLASS_DEVICE + bool is_valid() { return block_idx < params.num_blocks; } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + int block_decode = block_idx; + int seq_len_block, head_size_block, bidh; + params.divmod_head_size_block(block_decode, head_size_block, block_decode); + params.divmod_seq_len_block(block_decode, seq_len_block, block_decode); + params.divmod_num_heads(block_decode, bidh, block_decode); + return make_coord(head_size_block, seq_len_block, block_decode, bidh); + } + + CUTLASS_DEVICE + XeFlashPersistentTileScheduler &operator++() { + block_idx += GridDimX(); + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////// +} // namespace kernel + +struct IndividualScheduler {}; +struct PersistentScheduler {}; + +namespace detail { + +template +struct TileSchedulerSelector { + static_assert(cutlass::detail::dependent_false, + "Could not select a tile scheduler for given parameters."); +}; + +// Default (void) maps to XeFlashIndividualTileScheduler +template +struct TileSchedulerSelector< + void, ArchTag, + cute::enable_if_t>> { + using Scheduler = + typename TileSchedulerSelector::Scheduler; +}; + +template +struct TileSchedulerSelector< + IndividualScheduler, ArchTag, + cute::enable_if_t>> { + using Scheduler = kernel::XeFlashIndividualTileScheduler; +}; + +template +struct TileSchedulerSelector< + PersistentScheduler, ArchTag, + cute::enable_if_t>> { + using Scheduler = kernel::XeFlashPersistentTileScheduler; +}; +} // namespace detail + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::flash_attention diff --git a/applications/flash_attention_v2/kernel/xe_chunk_prefill.hpp b/applications/flash_attention_v2/kernel/xe_chunk_prefill.hpp new file mode 100644 index 0000000000..8a6e22b355 --- /dev/null +++ b/applications/flash_attention_v2/kernel/xe_chunk_prefill.hpp @@ -0,0 +1,675 @@ 
+/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice,this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE + * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR + * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF + * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS + * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN + * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) + * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE + * POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/kernel_hardware_info.hpp" + +#include "flash_attention_v2/collective/xe_flash_attn_chunk_prefill_mma.hpp" +namespace cutlass::flash_attention::kernel { + +template +class FMHAPrefillChunk; +/////////////////////////////////////////////////////////////////////////////// +template +class FMHAPrefillChunk { + +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + + // ProblemShape: + static_assert( + rank(ProblemShape{}) == 8, + "ProblemShape{} should be "); + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShapeQK = typename CollectiveMainloop::TileShapeQK; + using TileShapePV = typename CollectiveMainloop::TileShapePV; + using TiledMmaQK = typename CollectiveMainloop::TiledMmaQK; + using TiledMmaPV = typename CollectiveMainloop::TiledMmaPV; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementQ = typename CollectiveMainloop::ElementQ; + using StrideQ = typename CollectiveMainloop::StrideQ; + using ElementK = typename CollectiveMainloop::ElementK; + using StrideK = typename CollectiveMainloop::StrideK; + using ElementV = typename CollectiveMainloop::ElementV; + using StrideV = typename CollectiveMainloop::StrideV; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + + using CollectiveSoftmaxEpilogue = CollectiveSoftmaxEpilogue_; + using 
SoftmaxArguments = typename CollectiveSoftmaxEpilogue::Arguments; + using SoftmaxParams = typename CollectiveSoftmaxEpilogue::Params; + + static_assert(cute::is_void_v or + cute::is_same_v or + cute::is_same_v, + "Unsupported TileScheduler for Intel Xe."); + using TileSchedulerTag = TileScheduler_; + using TileScheduler = + typename detail::TileSchedulerSelector::Scheduler; + using TileSchedulerParams = typename TileScheduler::Params; + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementO = typename CollectiveEpilogue::ElementO; + using StrideO = typename CollectiveEpilogue::StrideO; + using ElementLSE = typename CollectiveEpilogue::ElementLSE; + using ElementSink = typename CollectiveEpilogue::ElementSink; + static constexpr bool Sink = CollectiveEpilogue::Sink; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + using TileShapeOutput = typename CollectiveEpilogue::TileShapeOutput; + using TiledMmaOutput = typename CollectiveEpilogue::TiledMmaOutput; + + static_assert( + cute::is_same_v, + "Mainloop and epilogue do not agree on accumulator value type."); + // MSVC requires the cast to fix a warning-as-error. + static constexpr int SharedStorageSize = 0; + + static constexpr bool CausalMask = CollectiveMainloop::CausalMask; + static constexpr bool LocalMask = CollectiveMainloop::LocalMask; + + static_assert(!(CausalMask && LocalMask), "Cannot be both causal and local"); + static constexpr bool PagedKV = CollectiveMainloop::PagedKV; + + + static constexpr int SubgroupSize = + CollectiveMainloop::SubgroupSize; // sub_group size + static constexpr uint32_t MaxThreadsPerBlock = + CollectiveMainloop::MaxThreadsPerBlock; + using MmaAtomShape = typename CollectiveMainloop::MmaAtomShape; // 8,16,16 + + static constexpr int QK_BLK_M = CollectiveMainloop::QK_BLK_M; + static constexpr int QK_BLK_N = CollectiveMainloop::QK_BLK_N; + static constexpr int QK_BLK_K = CollectiveMainloop::QK_BLK_K; + + static constexpr int QK_ATOM_N = CollectiveMainloop::QK_ATOM_N; + static constexpr int QK_ATOM_K = CollectiveMainloop::QK_ATOM_K; + + static constexpr int QK_SG_M = CollectiveMainloop::QK_SG_M; + + static constexpr int Epilogue_BLK_N = get<1>(TileShapeOutput{}); + static constexpr int Epilogue_BLK_K = get<2>(TileShapeOutput{}); + + static constexpr int PV_ATOM_M = CollectiveMainloop::PV_ATOM_M; + static constexpr int PV_ATOM_N = CollectiveMainloop::PV_ATOM_N; + static constexpr int PV_ATOM_K = CollectiveMainloop::PV_ATOM_K; + + static constexpr auto Num_SGs = PV_ATOM_N * PV_ATOM_M * PV_ATOM_K; + static constexpr int Vec = CollectiveMainloop::Vec; + static constexpr int FragsM = CollectiveMainloop::FragsM; + // The FragsN here used for Creation of S matrix so we use the FragsN for S + // shape + static constexpr int FragsN = CollectiveMainloop::FragsNS; + + static constexpr int VSlicer = + get<1>(TileShapeOutput{}) / + (get<1>(TileShapePV{}) * PV_ATOM_N); // ceil_div(FragsNOut,FragsNS); + using AccumeShape = decltype(make_shape( + Int{}, Int{}, get<1>(TileShapePV{}) / get<1>(MmaAtomShape()), + Int{})); + + static constexpr bool is_var_len = CollectiveMainloop::is_var_len; + // Kernel level shared memory storage + struct SharedStorage { + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + EpilogueTensorStorage epilogue; + }; + + // Device side arguments + struct Arguments { + gemm::GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments 
mainloop{}; + SoftmaxArguments softmax{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + }; + + // Kernel entry point API + struct Params { + gemm::GemmUniversalMode mode; + ProblemShape problem_shape; + MainloopParams mainloop; + SoftmaxParams softmax; + EpilogueParams epilogue; + TileSchedulerParams scheduler; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the + // aliased type. + static Params to_underlying_arguments(Arguments const &args, + void *workspace) { + (void)workspace; + return {args.mode, + args.problem_shape, + CollectiveMainloop::to_underlying_arguments( + args.problem_shape, args.mainloop, workspace), + CollectiveSoftmaxEpilogue::to_underlying_arguments(args.softmax), + CollectiveEpilogue::to_underlying_arguments( + args.problem_shape, args.epilogue, workspace), + TileScheduler::to_underlying_arguments( + args.problem_shape, args.hw_info, TileShapeOutput{})}; + } + + static bool can_implement(Arguments const &args) { + bool mode_implementable = args.mode == gemm::GemmUniversalMode::kGemm or + (args.mode == gemm::GemmUniversalMode::kBatched && + rank(ProblemShape{}) == 4); + return mode_implementable; + } + + static int get_workspace_size(Arguments const &args) { return 0; } + + static cutlass::Status + initialize_workspace(Arguments const &args, void *workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { + return Status::kSuccess; + } + + static dim3 get_grid_shape(Params const ¶ms) { + return TileScheduler::template get_grid_shape(params.scheduler); + } + + static dim3 get_block_shape() { return dim3(MaxThreadsPerBlock, 1, 1); } + + CUTLASS_DEVICE + Shape + get_sequence_length_shape(ProblemShape const &problem_shape, + int const &batch) { + if constexpr (is_var_len) { + return cutlass::fmha::collective::apply_variable_length( + select<3, 4, 5>(problem_shape), batch); + } else { + return select<3, 4, 5>(problem_shape); + } + } + + CUTLASS_DEVICE + void operator()(Params const ¶ms, char *smem_buf) { + SharedStorage &shared_storage = + *reinterpret_cast(smem_buf); + // Preconditions + CUTE_STATIC_ASSERT(is_static::value); + CUTE_STATIC_ASSERT(is_static::value); + // Separate out problem shape for convenience + + // "ProblemShape{} should be "); + auto batch = get<0>(params.problem_shape); + auto num_heads_q = get<1>(params.problem_shape); + auto num_heads_kv = get<2>(params.problem_shape); + + auto &head_size_qk = get<6>(params.problem_shape); + auto &head_size_vo = get<7>(params.problem_shape); + // Preconditions + static_assert(cute::rank(StrideQ{}) == 3, + "StrideQ must be rank-3: [seq_len_qo, head_size_qk, batch * " + "num_heads_q]."); + static_assert(cute::rank(StrideK{}) == 3, + "StrideK must be rank-3: [head_size_qk, seq_len_kv, batch * " + "num_heads_kv]."); + static_assert(cute::rank(StrideV{}) == 3, + "StrideV must be rank-3: [seq_len_kv, head_size_vo, batch * " + "num_heads_kv]."); + + int thread_idx = int(ThreadIdxX()); + int sub_group_id = thread_idx / SubgroupSize; + + TileScheduler tile_scheduler{params.scheduler}; + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = + tile_scheduler + .get_block_coord(); // head_size_blk_idx, seq_len_blk_idx, + // batch_blk_idx, num_heads_blk_idx + + auto blk_m_coord = get<0>(blk_coord); // seq_len_blk_idx + auto blk_n_coord = 0; // nums_head_blk_idx + auto q_head_coord = get<1>(blk_coord); // q_heads_idx + auto batch_coord = get<2>(blk_coord); // batch_blk_idx 
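+      // With the individual tile scheduler (the default when TileScheduler_ is void), the launch grid is
+      //   (ceil_div(seq_len_qo, get<0>(TileShapeOutput{})), num_heads_q, batch),
+      // so, for example, block (3, 5, 2) processes the 4th query tile of query head 5 in batch 2.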
+
+      // For the variable sequence length case, batch is treated as 1 (same as
+      // group GEMM). For the fixed sequence length case, the l_coord is the
+      // weighted sum of batch_coord and num_heads_coord. The Flash Attention
+      // implementation combines batch and num_heads to compute the total
+      // batch_size: if is_var_len, batch_size = num_heads (each batch has its
+      // own seq_len_qo and seq_len_kv); if !is_var_len, batch_size = batch * num_heads.
+      // auto blk_l_coord = q_head_coord;
+
+      // Get the problem shape for the current batch_blk_idx. For variable
+      // sequence length, the sequence lengths are loaded from global memory
+      // for the given batch_blk_idx and the appropriate shape is returned.
+      // For fixed sequence length, sequence_length_shape ==
+      // select<3, 4, 5>(params.problem_shape). problem_shape = [batch,
+      // num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache,
+      // head_size_qk, head_size_vo]
+      auto sequence_length_shape =
+          get_sequence_length_shape(params.problem_shape, batch_coord);
+
+      auto [seq_len_qo, seq_len_kv, seq_len_kv_cache] = sequence_length_shape;
+      // int seq_len_kv_total = seq_len_kv_cache + seq_len_kv;
+
+      // Calculate the starting row of this tile (blk_m_coord *
+      // get<0>(TileShapeOutput{})) and check that it is still within bounds of
+      // the actual seq_len_qo (get<0>(sequence_length_shape)).
+      if (blk_m_coord * get<0>(TileShapeOutput{}) >= seq_len_qo) {
+        continue;
+      }
+
+      const int seq_coord =
+          cute::min(seq_len_qo, (blk_m_coord * QK_BLK_M + (sub_group_id / PV_ATOM_N) * QK_SG_M) %
+              seq_len_qo);
+      // e.g. seq_len_qo = 2048, seq_len_kv = 1024:
+      auto offset = cute::min(seq_len_qo, seq_len_kv); // offset = 1024
+      auto discard_seq_coord = seq_len_qo - offset;    // discard_seq_coord = 1024
+      auto full_tile_offset = seq_len_kv - offset;     // full_tile_offset = 0
+
+      const int seq_len =
+          CausalMask
+              ? full_tile_offset +
+                    cute::min(seq_len_kv, seq_coord - discard_seq_coord) +
+                    QK_SG_M
+              : seq_len_kv;
+
+      const int kv_splits_new = cute::ceil_div(seq_len, QK_BLK_N);
+      const int kv_splits_cache = cute::ceil_div(seq_len_kv_cache, QK_BLK_N);
+      const int kv_splits = kv_splits_cache + kv_splits_new;
+
+      int tiles_per_page = params.mainloop.page_size / QK_BLK_N;
+
+      // With causal masking, rows before discard_seq_coord have no keys to
+      // attend to, so skip them.
+      if (CausalMask && seq_coord < discard_seq_coord) {
+        continue;
+      }
+
+      Tensor mQ_mkl = cute::get_xe_tensor(
+          make_shape(seq_len_qo, head_size_qk, 1)); //(m,k,l)
+
+      Tensor mK_nkl = cute::get_xe_tensor(
+          make_shape(seq_len_kv, head_size_qk, 1)); //(n,k,l)
+      Tensor mV_nkl = cute::get_xe_tensor(
+          make_shape(head_size_vo, seq_len_kv, 1)); //(n,k,l)
+      Tensor mK_cache_nkl = cute::get_xe_tensor(
+          make_shape(seq_len_kv_cache, head_size_qk, 1)); // (n_cache,k,l)
+      Tensor mV_cache_nkl = cute::get_xe_tensor(
+          make_shape(head_size_vo, seq_len_kv_cache, 1)); // (n_cache,k,l)
+
+      // block_size and head_size are the same size, so no coord is needed.
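+      // Worked example (illustrative, assuming QK_BLK_N = 64): with
+      // seq_len_kv_cache = 512 and seq_len_kv = 1024, the cached KV is covered
+      // by kv_splits_cache = 8 tiles and, without causal masking, the new KV
+      // by kv_splits_new = 16 tiles, i.e. kv_splits = 24 iterations of the
+      // main loop below. With causal masking, seq_len only reaches the
+      // diagonal of this query tile, so subgroups that own later query rows
+      // iterate over more KV tiles than earlier ones.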
+      Tensor mQ_mk = mQ_mkl(_, _, 0);
+
+      Tensor mK_nk = mK_nkl(_, _, 0); // (n,k)
+      Tensor mV_nk = mV_nkl(_, _, 0);
+
+      Tensor mK_cache_nk = mK_cache_nkl(_, _, 0); // (n_cache, k)
+      Tensor mV_cache_nk = mV_cache_nkl(_, _, 0); // (n_cache, k)
+
+      auto gQ = local_tile(mQ_mk, TileShapeQK{}, make_coord(blk_m_coord, _, _),
+                           Step<_1, X, _1>{});
+      auto gK = local_tile(mK_nk, TileShapeQK{}, make_coord(_, _, _),
+                           Step{});
+      auto gV = local_tile(mV_nk, TileShapeOutput{},
+                           make_coord(_, blk_n_coord, _), Step{});
+      auto gK_cache = local_tile(mK_cache_nk, TileShapeQK{},
+                                 make_coord(_, _, _), Step{});
+      auto gV_cache =
+          local_tile(mV_cache_nk, TileShapeOutput{},
+                     make_coord(_, blk_n_coord, _), Step{});
+
+      auto mainloop_params = CollectiveMainloop::get_updated_copies(
+          params.mainloop, params.problem_shape, sequence_length_shape,
+          batch_coord, q_head_coord);
+
+      // We limit the horizontal prefetch size to two subgroups: empirical
+      // results show that reading two cache lines side by side gives better
+      // performance, and going wider has no further effect. (64 here when
+      // possible; loop over to cover all the data needed.)
+      auto tiled_prefetch_q = cute::prefetch_selector<
+          Shape, Int>,
+          Num_SGs>(mainloop_params.gmem_tiled_copy_q);
+      auto tiled_prefetch_k = cute::prefetch_selector<
+          Shape, Int>,
+          Num_SGs>(mainloop_params.gmem_tiled_copy_k);
+      auto tiled_prefetch_v = cute::prefetch_selector<
+          Shape,
+          Int>,
+          Num_SGs>(mainloop_params.gmem_tiled_copy_v);
+      auto tiled_prefetch_k_cache = cute::prefetch_selector<
+          Shape, Int>,
+          Num_SGs>(mainloop_params.gmem_tiled_copy_k_cache);
+      auto tiled_prefetch_v_cache = cute::prefetch_selector<
+          Shape,
+          Int>,
+          Num_SGs>(mainloop_params.gmem_tiled_copy_v_cache);
+      auto thr_prefetch_Q = tiled_prefetch_q.get_slice(thread_idx);
+      auto thr_prefetch_K = tiled_prefetch_k.get_slice(thread_idx);
+      auto thr_prefetch_V = tiled_prefetch_v.get_slice(thread_idx);
+      auto pQgQ = thr_prefetch_Q.partition_S(gQ);
+      auto pKgK = thr_prefetch_K.partition_S(gK);
+      auto pVgV = thr_prefetch_V.partition_S(gV);
+      // Assuming the copy function is the same; otherwise this needs its own
+      // tiled prefetch.
+      auto pKgK_cache = thr_prefetch_K.partition_S(gK_cache);
+      auto pVgV_cache = thr_prefetch_V.partition_S(gV_cache);
+      CUTLASS_PRAGMA_UNROLL
+      for (int i = 0; i < size<3>(pQgQ); i++) {
+        prefetch(tiled_prefetch_q, pQgQ(_, _, _, i));
+      }
+      auto &prefetch_K =
+          (seq_len_kv_cache == 0) ? tiled_prefetch_k : tiled_prefetch_k_cache;
+      auto &pKgK1_ = (seq_len_kv_cache == 0) ? pKgK : pKgK_cache;
+
+      int cached_nblock = 0;
+      if constexpr (PagedKV) {
+        int curr_batch_pages = ceil_div(seq_len_kv_cache, mainloop_params.page_size);
+        int batch_offset =
+            is_var_len ? mainloop_params.num_pages_per_seq[batch_coord]
+                       : batch_coord * curr_batch_pages;
+        cached_nblock =
+            mainloop_params
+                .ptr_page_table[batch_offset // page table for this batch
+        ] * tiles_per_page;                  // base block idx of physical page
+      }
+      // The head size is the same for the cached and non-cached paths.
+      for (int j = 0; j < size<4>(pKgK1_); j++) {
+        CUTLASS_PRAGMA_UNROLL
+        for (int i = cached_nblock; i < cached_nblock + DispatchPolicy::Stages;
+             i++) {
+          prefetch(prefetch_K, pKgK1_(_, _, _, i, j));
+        }
+      }
+
+      // Allocate the tiled_mma and the accumulators for the (M,N)
+      // workgroup_shape
+      Tensor out_reg = make_tensor(AccumeShape{});
+
+      // There are 16 work-items per subgroup; each work-item holds one running
+      // max and, cumulatively, they compute the max for the subgroup.
+      ElementAccumulator max_reg{-INFINITY};
+      // Each sum register holds a 2D tensor of 8 x 2 elements; this is the
+      // number of sequence positions processed per subgroup.
+      Tensor sum_reg =
+          make_tensor(Shape, Int>{});
+
+      clear(sum_reg);
+      clear(out_reg);
+      // Perform the collective scoped MMA
+      CollectiveMainloop collective_mma;
+      // When the causal mask is enabled, the barrier scope cannot be set to
+      // work-group level because the number of N blocks differs per subgroup
+      // due to the triangular nature of the causal operation.
+      static constexpr int barrier_scope = CausalMask ? 3 : 2;
+      CUTLASS_PRAGMA_UNROLL
+      for (int split = 0; split < kv_splits - static_cast(CausalMask); split++) {
+        barrier_arrive(barrier_scope);
+
+        bool is_KV_cache = split < kv_splits_cache;
+        // 1) Load KV (performed inside mmaQK)
+        auto gK_ = is_KV_cache ? gK_cache(_, _, cached_nblock, _)
+                               : gK(_, _, split - kv_splits_cache, _);
+        auto gV_ = is_KV_cache ? gV_cache(_, _, cached_nblock)
+                               : gV(_, _, split - kv_splits_cache);
+        // 2) Create Tensor S
+        Tensor tSr = make_tensor(
+            Shape, Int, Int>{});
+        clear(tSr);
+        // 3) Perform GEMM S = Q*K
+        // A possible follow-up: reshape to LayoutQ = ((seq_len_q, group_head_q),
+        // head_size_qk, batch * num_heads_q / group_head_q), so that the
+        // q_group_size heads can be merged into one GEMM.
+        collective_mma.mmaQK(tSr, gQ, gK_, tSr,
+                             ceil_div(head_size_qk, QK_BLK_K), mainloop_params,
+                             is_KV_cache);
+
+        if constexpr (LocalMask) {
+          // Sliding window: mask the elements of each tile that fall outside
+          // [row - window_left, row + window_right] relative to the
+          // bottom-right-aligned diagonal.
+          const int item_id = thread_idx % SubgroupSize;
+          int col_idx;
+          if (split < kv_splits_cache) {
+            col_idx = item_id + split * cute::min(QK_BLK_N, seq_len_kv_cache);
+          } else {
+            col_idx = item_id + seq_len_kv_cache + (split - kv_splits_cache) * cute::min(QK_BLK_N, seq_len_kv);
+          }
+
+          CUTLASS_PRAGMA_UNROLL
+          for (int n = 0; n < FragsN;
+               n++, col_idx += get<1>(MmaAtomShape())) { // 4
+            CUTLASS_PRAGMA_UNROLL
+            for (int m = 0; m < FragsM; m++) { // 2
+              int row_idx = m * Vec + seq_coord;
+              int col_ref = seq_len_kv_cache + seq_len_kv - seq_len_qo;
+              CUTLASS_PRAGMA_UNROLL
+              for (int row = 0; row < Vec; row++) { // 8
+                bool left_mask = col_idx < cute::max(0, row + row_idx + col_ref - mainloop_params.window_left);
+                bool right_mask = col_idx > cute::min(seq_len_kv_cache + seq_len_kv, row + row_idx + col_ref + mainloop_params.window_right);
+                if (left_mask || right_mask) {
+                  tSr(row, m, n) = ElementAccumulator{-INFINITY};
+                }
+              }
+            }
+          }
+        }
+
+        if constexpr(!(CausalMask || LocalMask) && PagedKV) {
+          // The KV length is not divisible by the tile size; mask the padding.
+          const int item_id = thread_idx % SubgroupSize;
+          int col_idx = item_id + split * cute::min(QK_BLK_N, seq_len_kv_cache + seq_len_kv);
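+          // Note: -INFINITY is used as the mask value so that, after the
+          // exponentiation inside the fused softmax, masked positions
+          // contribute exp(-inf) = 0 to both the running row maximum and the
+          // row sum, i.e. padded columns beyond seq_len_kv_cache + seq_len_kv
+          // (and rows beyond seq_len_qo) have no effect on the output.
+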
CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < FragsN; n++, col_idx += get<1>(MmaAtomShape())) { // 4 + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < FragsM; m++) { // 2 + int row_idx = m * Vec + seq_coord; + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < Vec; row++) { // 8 + if (col_idx >= seq_len_kv_cache + seq_len_kv || row_idx + row >= seq_len_qo) { + tSr(row, m, n) = ElementAccumulator{-INFINITY}; + } + } + } + } + } + auto &tiled_prefetch_v_ = + is_KV_cache ? tiled_prefetch_v_cache + : tiled_prefetch_v; + auto &pVgV_ = is_KV_cache ? pVgV_cache : pVgV; + int v_prefetch_idx = is_KV_cache ? PagedKV ? cached_nblock : split + : split - kv_splits_cache; + for (int i = 0; i < size<1>(pVgV_); i++) { + prefetch(tiled_prefetch_v_, pVgV_(_, i, _, v_prefetch_idx)); + } + int next_cached_nblock = split + 1; + bool is_next_KV_cache = next_cached_nblock < kv_splits_cache; + if constexpr (PagedKV) { + if (is_next_KV_cache) { + int curr_batch_pages = ceil_div(seq_len_kv_cache, mainloop_params.page_size); + int next_page_logical_idx = + next_cached_nblock * QK_BLK_N / params.mainloop.page_size; + int batch_offset = + is_var_len ? mainloop_params.num_pages_per_seq[batch_coord] + : batch_coord * curr_batch_pages; + bool valid_page = next_page_logical_idx < curr_batch_pages; + // get physical page idx from page table + if (valid_page) { + next_cached_nblock = + params.mainloop.ptr_page_table + [batch_offset + // page table for this batch + next_page_logical_idx // split (tile idx) to logical + // page idx + ] * tiles_per_page + // base block idx of physical page + next_cached_nblock % tiles_per_page; // offset within page + } else { + next_cached_nblock = + curr_batch_pages * + tiles_per_page; // push idx out of bounds to respect the + // boundary between batches + } + } + } + + // 4) Fused softmax + CollectiveSoftmaxEpilogue softmax(params.softmax); + softmax(split == 0, tSr, max_reg, sum_reg, out_reg); + + // 5) Perform GEMM O = S*V + collective_mma.template mmaPV(out_reg, tSr, gV_, out_reg, + mainloop_params, is_KV_cache); + // ... prefetch next tile ... + // Prefetch the next Q tile + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<3>(pQgQ); i++) { + prefetch(tiled_prefetch_q, pQgQ(_, _, _, i)); + } + + is_KV_cache = is_next_KV_cache; + cached_nblock = next_cached_nblock; + // Prefetch the next K tile + // there is no need to gaurd it with if statememt as prefetch will + // ignore out of bound reading + if constexpr (PagedKV) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<4>(pKgK_cache); j++) { + prefetch(tiled_prefetch_k_cache, pKgK_cache(_, _, _, cached_nblock, j)); + } + } else { + bool sel_prefetch_k = + (split + DispatchPolicy::Stages) < kv_splits_cache; + auto &prefetch_k_selector = + sel_prefetch_k ? tiled_prefetch_k_cache : tiled_prefetch_k; + auto &pKgK_ = sel_prefetch_k ? pKgK_cache : pKgK; + int k_prefetch_idx = + sel_prefetch_k + ? PagedKV ? 
cached_nblock : split + DispatchPolicy::Stages + : split + DispatchPolicy::Stages - kv_splits_cache; + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<4>(pKgK_); j++) { + prefetch(prefetch_k_selector, pKgK_(_, _, _, k_prefetch_idx, j)); + } + } + barrier_wait(barrier_scope); + } + + if constexpr (CausalMask) { + // BAND Matrix + // 1) Load K (performed inside mmaQK) + // 2) Create Tensor S + Tensor tSr = make_tensor( + Shape, Int, Int>{}); + clear(tSr); + // 3) Perform GEMM S = Q*K + collective_mma.mmaQK(tSr, gQ, gK(_, _, kv_splits_new - 1, _), tSr, + ceil_div(head_size_qk, QK_BLK_K), mainloop_params, + false); + // we only need one block ahead, there is enough gap to prefetch it + // while doing softmax. because the gap between the two MMA is big, + // prefetching it the same way as cutlass K matrix does not make sense + for (int i = 0; i < size<1>(pVgV); i++) { + prefetch(tiled_prefetch_v, pVgV(_, i, _, kv_splits_new - 1)); + } + // mask the elements of each tile where j > i + const int item_id = thread_idx % SubgroupSize; + int col_idx = item_id + (kv_splits_new - 1) * QK_BLK_N; + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < FragsN; + n++, col_idx += get<1>(MmaAtomShape())) { // 4 + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < FragsM; m++) { // 2 + int row_idx = m * Vec + seq_coord; + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < Vec; row++, row_idx++) { // 8 + if (col_idx - full_tile_offset > row_idx - discard_seq_coord) { + tSr(row, m, n) = ElementAccumulator{-INFINITY}; + } + } + } + } + + CollectiveSoftmaxEpilogue softmax(params.softmax); + softmax((kv_splits - 1) == 0, tSr, max_reg, sum_reg, out_reg); + collective_mma.template mmaPV(out_reg, tSr, + gV(_, _, kv_splits_new - 1), + out_reg, mainloop_params, false); + } + + + // Epilogue + auto epilogue_params = + CollectiveEpilogue::template get_updated_copies( + params.epilogue, params.problem_shape, sequence_length_shape, + batch_coord, q_head_coord); + CollectiveEpilogue epilogue{epilogue_params, shared_storage.epilogue}; + auto blk_coord_mnkl = make_coord(blk_m_coord, blk_n_coord, _, 0); + if constexpr (Sink) { + ElementAccumulator max_scale{max_reg * params.softmax.scale}; + epilogue(params.problem_shape, sequence_length_shape, blk_coord_mnkl, out_reg, max_scale, sum_reg, static_cast(params.epilogue.ptr_sink[q_head_coord])); + } else { + epilogue(params.problem_shape, sequence_length_shape, blk_coord_mnkl, out_reg, max_reg, sum_reg, 0); + } + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::flash_attention::kernel diff --git a/benchmarks/flash_attention/flash_attention_decode/benchmark_runner.hpp b/benchmarks/flash_attention/flash_attention_decode/benchmark_runner.hpp index 2784dbb859..37cf8566ea 100644 --- a/benchmarks/flash_attention/flash_attention_decode/benchmark_runner.hpp +++ b/benchmarks/flash_attention/flash_attention_decode/benchmark_runner.hpp @@ -191,9 +191,9 @@ template struct BenchmarkRunnerFMHADecode { int max_seq_len_q = static_cast(get<3>(problem_size)); int max_seq_len_kv = static_cast(get<4>(problem_size)); int max_seq_len_kv_cache = static_cast(get<5>(problem_size)); - get<3>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_q, cumulative_seqlen_q.data()}; - get<4>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv, cumulative_seqlen_kv.data()}; - get<5>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv_cache, cumulative_seqlen_kv_cache.data()}; + get<3>(problem_size) = 
cutlass::fmha::collective::VariableLength{max_seq_len_q, 0, cumulative_seqlen_q.data()}; + get<4>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv, 0, cumulative_seqlen_kv.data()}; + get<5>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv_cache, 0, cumulative_seqlen_kv_cache.data()}; } auto [batch, num_heads_q, num_heads_kv, head_size_qk, head_size_vo] = cute::select<0,1,2,6,7>(problem_size); diff --git a/benchmarks/flash_attention/flash_attention_prefill/benchmark_runner.hpp b/benchmarks/flash_attention/flash_attention_prefill/benchmark_runner.hpp index d7d60d71fd..66b19bc72d 100644 --- a/benchmarks/flash_attention/flash_attention_prefill/benchmark_runner.hpp +++ b/benchmarks/flash_attention/flash_attention_prefill/benchmark_runner.hpp @@ -165,8 +165,8 @@ template struct BenchmarkRunnerFMHA { if constexpr (isVarLen) { int max_seq_len_q = static_cast(get<3>(problem_size)); int max_seq_len_kv = static_cast(get<4>(problem_size)); - get<3>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_q, cumulative_seqlen_q.data()}; - get<4>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv, cumulative_seqlen_kv.data()}; + get<3>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_q, 0, cumulative_seqlen_q.data()}; + get<4>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv, 0, cumulative_seqlen_kv.data()}; } auto [batch, num_heads_q, num_heads_kv, head_size_qk, head_size_vo] = cute::select<0,1,2,5,6>(problem_size); diff --git a/benchmarks/flash_attention/flash_attention_prefill_cachedKV/benchmark_runner.hpp b/benchmarks/flash_attention/flash_attention_prefill_cachedKV/benchmark_runner.hpp index 801630bd49..75b05ae253 100644 --- a/benchmarks/flash_attention/flash_attention_prefill_cachedKV/benchmark_runner.hpp +++ b/benchmarks/flash_attention/flash_attention_prefill_cachedKV/benchmark_runner.hpp @@ -176,9 +176,9 @@ template struct BenchmarkRunnerFMHA { int max_seq_len_q = static_cast(get<3>(problem_size)); int max_seq_len_kv = static_cast(get<4>(problem_size)); int max_seq_len_kv_cache = static_cast(get<5>(problem_size)); - get<3>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_q, cumulative_seqlen_q.data()}; - get<4>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv, cumulative_seqlen_kv.data()}; - get<5>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv_cache, cumulative_seqlen_kv_cache.data()}; + get<3>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_q, 0, cumulative_seqlen_q.data()}; + get<4>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv, 0, cumulative_seqlen_kv.data()}; + get<5>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv_cache, 0, cumulative_seqlen_kv_cache.data()}; } auto [batch, num_heads_q, num_heads_kv, head_size_qk, head_size_vo] = cute::select<0,1,2,6,7>(problem_size); diff --git a/examples/06_bmg_flash_attention/06_bmg_chunk_prefill.cpp b/examples/06_bmg_flash_attention/06_bmg_chunk_prefill.cpp new file mode 100644 index 0000000000..e3e3c2d9cb --- /dev/null +++ b/examples/06_bmg_flash_attention/06_bmg_chunk_prefill.cpp @@ -0,0 +1,116 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+    \brief Flash Attention V2 Chunk Prefill for Intel BMG
+
+    This example constructs and executes a Flash Attention chunk prefill with KV cache on Intel
+    BMG. The definition of the GEMM, options, etc. for this example are defined in the associated
+    bmg_flash_chunk_prefill_runner.hpp header file.
+
+    See https://arxiv.org/pdf/2307.08691 for details of the Flash Attention V2 algorithm.
+
+    To run this example:
+      $ ./examples/sycl/06_bmg_flash_attention/06_bmg_chunk_prefill_hdim128 --seq_len_qo=512
+      --seq_len_kv=512 --seq_len_kv_cache=512 --head_size_vo=128 --head_size_qk=128
+
+    Causal masking of the first matrix multiplication is supported (`--is_causal`)
+
+    To build & run this example (from your build dir):
+
+      $ ninja 06_bmg_chunk_prefill_hdim128
+      $ ./examples/sycl/06_bmg_flash_attention/06_bmg_chunk_prefill_hdim128
+
+    Call with `--help` for information about available options
+*/
+
+#include "bmg_flash_chunk_prefill_runner.hpp"
+
+int main(int argc, const char **argv) {
+  //
+  // Parse options
+  //
+
+  Options options;
+
+  options.parse(argc, argv);
+
+  if (options.help) {
+    options.print_usage(std::cout) << std::endl;
+    return 0;
+  }
+
+  if (options.error) {
+    std::cerr << "Aborting execution."
<< std::endl; + return -1; + } + + // Define the work-group tile shape depending on the head-size of the second matmul + // Shape<_SequenceLenthOutputBLOCK, _HeadSizeout(NV), SequenceLengthKVBLOCK_KN/KV, HeadSizeQKBLOCK_KQK, HEADSIZEOutSlicerBlock> + // +#if !defined(HEAD_DIM) + std::cerr << "HEAD_DIM must be defined" << std::endl; + return -1; +#endif + if (options.head_size_vo != HEAD_DIM) { + std::cerr << "head_size_vo must be " << HEAD_DIM << ", but got " << options.head_size_vo << std::endl; + return -1; + } + + constexpr int PipelineStages = 2; +#if HEAD_DIM == 64 + using ShapeQK = Shape<_128, _64, _64>; + using ShapePV = Shape<_128, _32, _64>; + using ShapeOutPut = Shape<_128, _64, _64>; + using SubgroupLayout = Layout, Stride<_1, _1, _1>>; +#elif HEAD_DIM == 96 + using ShapeQK = Shape<_128, _64, _32>; + using ShapePV = Shape<_128, _32, _64>; + using ShapeOutPut = Shape<_128, _96, _64>; + using SubgroupLayout = Layout, Stride<_1, _1, _1>>; +#elif HEAD_DIM == 128 + using ShapeQK = Shape<_128, _64, _64>; + using ShapePV = Shape<_128, _32, _64>; + using ShapeOutPut = Shape<_128, _128, _64>; + using SubgroupLayout = Layout, Stride<_1, _1, _1>>; +#elif HEAD_DIM == 192 + using ShapeQK = Shape<_256, _64, _64>; + using ShapePV = Shape<_256, _32, _64>; + using ShapeOutPut = Shape<_256, _192, _64>; + using SubgroupLayout = Layout, Stride<_1, _1, _1>>; +#endif + if (options.is_causal) { + FMHAConfig::run(options); + } else if (options.is_local_mask) { + FMHAConfig::run(options); + } else { + FMHAConfig::run(options); + } +} diff --git a/examples/06_bmg_flash_attention/CMakeLists.txt b/examples/06_bmg_flash_attention/CMakeLists.txt index 39752da4ed..aea6e54b48 100644 --- a/examples/06_bmg_flash_attention/CMakeLists.txt +++ b/examples/06_bmg_flash_attention/CMakeLists.txt @@ -1,4 +1,5 @@ # Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved. +# Copyright (C) 2025 Intel Corporation, All rights reserved. 
# SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without @@ -63,6 +64,11 @@ foreach(HEAD_DIM 64 96 128 192) cutlass_example_add_executable( 06_bmg_decode_attention_fp8_hdim${HEAD_DIM} 06_bmg_decode_attention_fp8.cpp + ) + + cutlass_example_add_executable( + 06_bmg_chunk_prefill_hdim${HEAD_DIM} + 06_bmg_chunk_prefill.cpp TEST_COMMAND_OPTIONS TEST_NO_PAGED TEST_PAGED @@ -72,4 +78,5 @@ foreach(HEAD_DIM 64 96 128 192) target_compile_definitions(06_bmg_decode_attention_hdim${HEAD_DIM} PRIVATE HEAD_DIM=${HEAD_DIM}) target_compile_definitions(06_bmg_prefill_attention_fp8_hdim${HEAD_DIM} PRIVATE HEAD_DIM=${HEAD_DIM}) target_compile_definitions(06_bmg_decode_attention_fp8_hdim${HEAD_DIM} PRIVATE HEAD_DIM=${HEAD_DIM}) + target_compile_definitions(06_bmg_chunk_prefill_hdim${HEAD_DIM} PRIVATE HEAD_DIM=${HEAD_DIM}) endforeach() diff --git a/examples/06_bmg_flash_attention/bmg_flash_attn_decode_runner.hpp b/examples/06_bmg_flash_attention/bmg_flash_attn_decode_runner.hpp index 6bb7ad487f..74cda67a5f 100644 --- a/examples/06_bmg_flash_attention/bmg_flash_attn_decode_runner.hpp +++ b/examples/06_bmg_flash_attention/bmg_flash_attn_decode_runner.hpp @@ -222,9 +222,9 @@ template struct ExampleRunner { int max_seq_len_q = static_cast(get<3>(problem_size)); int max_seq_len_kv = static_cast(get<4>(problem_size)); int max_seq_len_kv_cache = static_cast(get<5>(problem_size)); - get<3>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_q, cumulative_seqlen_q.data()}; - get<4>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv, cumulative_seqlen_kv.data()}; - get<5>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv_cache, cumulative_seqlen_kv_cache.data()}; + get<3>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_q, 0, cumulative_seqlen_q.data()}; + get<4>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv, 0, cumulative_seqlen_kv.data()}; + get<5>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv_cache, 0, cumulative_seqlen_kv_cache.data()}; } auto [batch, num_heads_q, num_heads_kv, head_size_qk, head_size_vo] = cute::select<0,1,2,6,7>(problem_size); diff --git a/examples/06_bmg_flash_attention/bmg_flash_attn_prefill_cachedKV_runner.hpp b/examples/06_bmg_flash_attention/bmg_flash_attn_prefill_cachedKV_runner.hpp index f87cc8af2f..5adb778722 100644 --- a/examples/06_bmg_flash_attention/bmg_flash_attn_prefill_cachedKV_runner.hpp +++ b/examples/06_bmg_flash_attention/bmg_flash_attn_prefill_cachedKV_runner.hpp @@ -214,9 +214,9 @@ template struct ExampleRunner { int max_seq_len_q = static_cast(get<3>(problem_size)); int max_seq_len_kv = static_cast(get<4>(problem_size)); int max_seq_len_kv_cache = static_cast(get<5>(problem_size)); - get<3>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_q, cumulative_seqlen_q.data()}; - get<4>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv, cumulative_seqlen_kv.data()}; - get<5>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv_cache, cumulative_seqlen_kv_cache.data()}; + get<3>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_q, 0, cumulative_seqlen_q.data()}; + get<4>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv, 0, cumulative_seqlen_kv.data()}; + get<5>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv_cache, 0, cumulative_seqlen_kv_cache.data()}; } auto 
[batch, num_heads_q, num_heads_kv, head_size_qk, head_size_vo] = cute::select<0,1,2,6,7>(problem_size); diff --git a/examples/06_bmg_flash_attention/bmg_flash_attn_prefill_runner.hpp b/examples/06_bmg_flash_attention/bmg_flash_attn_prefill_runner.hpp index 264eae22e6..5e0086976e 100644 --- a/examples/06_bmg_flash_attention/bmg_flash_attn_prefill_runner.hpp +++ b/examples/06_bmg_flash_attention/bmg_flash_attn_prefill_runner.hpp @@ -193,8 +193,8 @@ template struct ExampleRunner { if constexpr (isVarLen) { int max_seq_len_q = static_cast(get<3>(problem_size)); int max_seq_len_kv = static_cast(get<4>(problem_size)); - get<3>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_q, cumulative_seqlen_q.data()}; - get<4>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv, cumulative_seqlen_kv.data()}; + get<3>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_q, 0, cumulative_seqlen_q.data()}; + get<4>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv, 0, cumulative_seqlen_kv.data()}; } auto [batch, num_heads_q, num_heads_kv, head_size_qk, head_size_vo] = cute::select<0,1,2,5,6>(problem_size); diff --git a/examples/06_bmg_flash_attention/bmg_flash_chunk_prefill_runner.hpp b/examples/06_bmg_flash_attention/bmg_flash_chunk_prefill_runner.hpp new file mode 100644 index 0000000000..f0e59afe4d --- /dev/null +++ b/examples/06_bmg_flash_attention/bmg_flash_chunk_prefill_runner.hpp @@ -0,0 +1,950 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "flash_attention_v2/collective/fmha_fusion.hpp" +#include "flash_attention_v2/kernel/tile_scheduler_chunk_prefill.hpp" +#include "cutlass/util/packed_stride.hpp" +#include "flash_attention_v2/kernel/xe_chunk_prefill.hpp" +#include "flash_attention_v2/collective/xe_flash_attn_chunk_prefill_epilogue.hpp" +#include "flash_attention_v2/collective/xe_flash_attn_chunk_prefill_softmax_epilogue.hpp" +#include "cutlass/util/GPU_Clock.hpp" +#include "cutlass/util/sycl_event_manager.hpp" + +#include +#include + +#include "helper.h" +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "sycl_common.hpp" + +using namespace cute; + +// Command line options parsing +struct Options { + + bool help; + bool error; + bool is_causal; + bool is_local_mask; + bool varlen = false; + bool use_sink = false; + bool use_paged_kv = false; + std::string scheduler; + + int batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache, page_size, head_size_qk, head_size_vo, iterations, window_left, window_right; + float softmax_scale; + + Options() + : help(false), error(false), is_causal(false), is_local_mask(false), varlen(false), use_sink(false), use_paged_kv(false), batch(32), num_heads_q(16), num_heads_kv(16), seq_len_qo(512), head_size_qk(128), + seq_len_kv(512), seq_len_kv_cache(512), page_size(128), head_size_vo(128), iterations(100), window_left(-1), window_right(-1), softmax_scale(1.f), scheduler("Individual") {} + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + if (cmd.check_cmd_line_flag("is_causal")) { + is_causal = true; + } + + if (cmd.check_cmd_line_flag("varlen")) { + varlen = true; + } + if (cmd.check_cmd_line_flag("use_sink")) { + use_sink = true; + } + + cmd.get_cmd_line_argument("scheduler", scheduler, std::string("Individual")); + + cmd.get_cmd_line_argument("batch", batch, 32); + cmd.get_cmd_line_argument("num_heads_q", num_heads_q, 16); + cmd.get_cmd_line_argument("num_heads_kv", num_heads_kv, num_heads_q); + cmd.get_cmd_line_argument("seq_len_qo", seq_len_qo, 512); + cmd.get_cmd_line_argument("seq_len_kv", seq_len_kv, seq_len_qo); + cmd.get_cmd_line_argument("seq_len_kv_cache", seq_len_kv_cache, 512); + cmd.get_cmd_line_argument("head_size_vo", head_size_vo, HEAD_DIM); + cmd.get_cmd_line_argument("head_size_qk", head_size_qk, head_size_vo); + cmd.get_cmd_line_argument("window_left", window_left, -1); + cmd.get_cmd_line_argument("window_right", window_right, -1); + cmd.get_cmd_line_argument("iterations", iterations, 100); + + if (cmd.check_cmd_line_flag("use_paged_kv")) { + use_paged_kv = true; + cmd.get_cmd_line_argument("page_size", page_size, 128); + seq_len_kv = 0; // seq_len_kv is not used when use paged kv + if (page_size % 128 != 0) { + std::cerr << "Invalid: page_size must be a multiple of 128" << std::endl; + return; + } + if (seq_len_kv_cache % page_size != 0) { + std::cerr << "Invalid: seq_len_kv_cache must be divisible by page_size" << std::endl; + return; + } + } + if (window_left > -1 && window_right > -1) { + is_local_mask = true; + } 
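+    // Sliding-window ("local") masking is enabled implicitly when both window
+    // borders are set: a query token at position i then only attends to keys
+    // in roughly [i - window_left, i + window_right], aligned to the end of
+    // the combined cached + new KV sequence. The default softmax scale below
+    // is the usual 1/sqrt(head_size_qk).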
+    softmax_scale = 1 / sqrt(static_cast(head_size_qk));
+  }
+
+  /// Prints the usage statement.
+  std::ostream &print_usage(std::ostream &out) const {
+
+    out << "BMG Flash Attention v2 Chunk Prefill Example\n\n"
+        << "Options:\n\n"
+        << "  --help                      If specified, displays this usage statement\n\n"
+        << "  --is_causal                 Apply a causal mask to the output of the first matmul\n"
+        << "  --use_sink                  Apply an attention sink\n"
+        << "  --window_left=              Set the left border of the sliding window. If set to -1, attend over the full seq_len\n"
+        << "  --window_right=             Set the right border of the sliding window. If set to -1, attend over the full seq_len\n"
+        << "  --varlen                    Enable variable sequence length\n"
+        << "  --scheduler=\"Value\"         Choose between the Individual or Persistent Scheduler\n"
+        << "  --batch=                    Sets the batch size of the Multi-Head Self Attention module\n"
+        << "  --num_heads_q=              Sets the number of attention heads for the Query input in the Multi-Head Self Attention module\n"
+        << "  --num_heads_kv=             Sets the number of attention heads for the Key-Value pair in the Multi-Head Self Attention module\n"
+        << "  --seq_len_qo=               Sets the sequence length of the Query input in the Multi-Head Self Attention module\n"
+        << "  --seq_len_kv=               Sets the sequence length of the Key-Value pair in the Multi-Head Self Attention module\n"
+        << "  --seq_len_kv_cache=         Sets the sequence length of the cached Key-Value pair in the Multi-Head Self Attention module\n"
+        << "  --use_paged_kv              Use a paged (non-contiguous) KV cache. Default is a contiguous KV cache\n"
+        << "  --page_size=                Block size for the paged KV cache. Default is 128\n"
+        << "  --head_size_qk=             Sets the attention head dimension of the 1st matrix multiplication in the Multi-Head Self Attention module\n"
+        << "  --head_size_vo=             Sets the attention head dimension of the 2nd matrix multiplication in the Multi-Head Self Attention module\n"
+        << "  --iterations=               Number of benchmarking iterations\n\n";
+
+    return out;
+  }
+};
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Flash Attention takes 3 input matrices: (K)eys, (Q)ueries and (V)alues.
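+// Q, V and O are stored row-major and K column-major so that the first GEMM
+// computes S = Q * K^T directly. Each logical tensor packs (num_heads *
+// head_size) into its second mode and batch into its third mode; see the
+// stride construction in ExampleRunner::initialize() below.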
+using LayoutQ = cutlass::layout::RowMajor; +using LayoutK = cutlass::layout::ColumnMajor; +using LayoutV = cutlass::layout::RowMajor; +using LayoutO = cutlass::layout::RowMajor; + +template struct ExampleRunner { + + using StrideQ = typename FMHAChunkPrefillKernel::StrideQ; + using StrideK = typename FMHAChunkPrefillKernel::StrideK; + using StrideV = typename FMHAChunkPrefillKernel::StrideV; + using StrideO = typename FMHAChunkPrefillKernel::StrideO; + + using ElementQ = typename FMHAChunkPrefillKernel::ElementQ; + using ElementK = typename FMHAChunkPrefillKernel::ElementK; + using ElementV = typename FMHAChunkPrefillKernel::ElementV; + using ElementSink = typename FMHAChunkPrefillKernel::ElementSink; + using ElementAcc = typename FMHAChunkPrefillKernel::ElementAccumulator; + + using CollectiveEpilogue = typename FMHAChunkPrefillKernel::CollectiveEpilogue; + using ElementOutput = typename CollectiveEpilogue::ElementOutput; + using ElementCompute = typename CollectiveEpilogue::ElementCompute; + using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; + + using ProblemShapeType = typename FMHAChunkPrefillKernel::ProblemShape; + + // + // Data members + // + + /// Initialization + StrideQ stride_Q; + StrideK stride_K; + StrideV stride_V; + StrideK stride_K_cache; + StrideV stride_V_cache; + StrideO stride_O; + uint64_t seed = 0; + + cutlass::DeviceAllocation block_Q; + cutlass::DeviceAllocation block_K; + cutlass::DeviceAllocation block_V; + cutlass::DeviceAllocation block_K_cache; + cutlass::DeviceAllocation block_V_cache; + cutlass::DeviceAllocation block_Sink; + cutlass::DeviceAllocation block_O; + cutlass::DeviceAllocation block_ref_O; + + std::vector cumulative_seqlen_q; + std::vector cumulative_seqlen_kv; + std::vector cumulative_seqlen_kv_cache; + cutlass::DeviceAllocation device_cumulative_seqlen_q; + cutlass::DeviceAllocation device_cumulative_seqlen_kv; + cutlass::DeviceAllocation device_cumulative_seqlen_kv_cache; + + struct PagedKVParams { + cutlass::DeviceAllocation page_table; + int page_size = 0; + cutlass::DeviceAllocation num_pages_per_seq; + }; + PagedKVParams paged_kv_cache; + + // + // Methods + // + +bool verify(ProblemShapeType problem_size, Options options) { + std::vector host_O(block_ref_O.size()); + std::vector host_Sink; + + if (options.use_sink) { + host_Sink.resize(block_Sink.size()); + compat::memcpy(host_Sink.data(), block_Sink.get(), host_Sink.size()); + compat::wait(); + } + if constexpr (isVarLen) { + int max_seq_len_q = static_cast(get<3>(problem_size)); + int max_seq_len_kv = static_cast(get<4>(problem_size)); + int max_seq_len_kv_cache = static_cast(get<5>(problem_size)); + get<3>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_q, 0, cumulative_seqlen_q.data()}; + get<4>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv, 0, cumulative_seqlen_kv.data()}; + get<5>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv_cache, 0, cumulative_seqlen_kv_cache.data()}; + } + + auto [batch, num_heads_q, num_heads_kv, head_size_qk, head_size_vo] = cute::select<0,1,2,6,7>(problem_size); + int seq_len_qo, seq_len_kv, seq_len_kv_cache; + + int offset_q = 0; + int offset_k = 0; + int offset_v = 0; + int offset_k_cache = 0; + int offset_v_cache = 0; + int offset_o = 0; + // loop over the batch dimension to compute the output + // to avoid the risk of running out of device memory + int q_group_size = num_heads_q / num_heads_kv; + for (int b = 0; b < batch; b++) { + if constexpr (isVarLen) { 
+ auto logical_problem_shape = cutlass::fmha::collective::apply_variable_length(problem_size, b); + seq_len_qo = get<3>(logical_problem_shape); + seq_len_kv = get<4>(logical_problem_shape); + seq_len_kv_cache = get<5>(logical_problem_shape); + } else { + seq_len_qo = get<3>(problem_size); + seq_len_kv = get<4>(problem_size); + seq_len_kv_cache = get<5>(problem_size); + } + ElementQ* q_ptr; + ElementK* k_ptr; + ElementV* v_ptr; + q_ptr = block_Q.get() + offset_q; + int seq_len_kv_total = seq_len_kv_cache + seq_len_kv; + cutlass::DeviceAllocation block_K_concat; + cutlass::DeviceAllocation block_V_concat; + + if (seq_len_kv_cache > 0) { // use_kv_cache + if (options.use_paged_kv) { + int num_pages = paged_kv_cache.page_table.size(); + std::vector host_page_table(paged_kv_cache.page_table.size()); + std::vector host_num_pages_per_seq(paged_kv_cache.num_pages_per_seq.size()); + compat::memcpy(host_page_table.data(), paged_kv_cache.page_table.get(), paged_kv_cache.page_table.size()); + compat::memcpy(host_num_pages_per_seq.data(), paged_kv_cache.num_pages_per_seq.get(), paged_kv_cache.num_pages_per_seq.size()); + + int curr_batch_pages = isVarLen ? host_num_pages_per_seq[b + 1] - host_num_pages_per_seq[b] : ceil_div(seq_len_kv_cache, paged_kv_cache.page_size); + int batch_offset = isVarLen ? host_num_pages_per_seq[b] : b * curr_batch_pages; + block_K_concat.reset((seq_len_kv + curr_batch_pages * paged_kv_cache.page_size) * num_heads_kv * head_size_qk); + block_V_concat.reset((seq_len_kv + curr_batch_pages * paged_kv_cache.page_size) * num_heads_kv * head_size_vo); + + for (int p = 0; p < curr_batch_pages; p++) { + int page_idx = host_page_table[batch_offset + p]; + // copy the page from KV cache to the concatenated buffer + compat::memcpy( + block_K_concat.get() + p * paged_kv_cache.page_size * num_heads_kv * head_size_qk, + block_K_cache.get() + page_idx * paged_kv_cache.page_size * num_heads_kv * head_size_qk, + paged_kv_cache.page_size * num_heads_kv * head_size_qk + ); + compat::memcpy( + block_V_concat.get() + p * paged_kv_cache.page_size * num_heads_kv * head_size_vo, + block_V_cache.get() + page_idx * paged_kv_cache.page_size * num_heads_kv * head_size_vo, + paged_kv_cache.page_size * num_heads_kv * head_size_vo + ); + } + if (seq_len_kv > 0) { + compat::memcpy( + // block_K_concat.get() + curr_batch_pages * paged_kv_cache.page_sze * num_heads_kv *head_size_qk, + block_K_concat.get() + seq_len_kv_cache * num_heads_kv * head_size_qk, + block_K.get() + offset_k, + seq_len_kv * num_heads_kv * head_size_qk + ); + compat::memcpy( + block_V_concat.get() + seq_len_kv_cache * num_heads_kv * head_size_vo, + block_V.get() + offset_v, + seq_len_kv * num_heads_kv * head_size_vo + ); + } + compat::wait(); + } else { + block_K_concat.reset(seq_len_kv_total * num_heads_kv * head_size_qk); + block_V_concat.reset(seq_len_kv_total * num_heads_kv * head_size_vo); + // Concatenate K_cache and K + compat::memcpy( + block_K_concat.get(), + block_K_cache.get() + offset_k_cache, + seq_len_kv_cache * num_heads_kv * head_size_qk + ); + compat::memcpy( + block_K_concat.get() + seq_len_kv_cache * num_heads_kv * head_size_qk, + block_K.get() + offset_k, + seq_len_kv * num_heads_kv * head_size_qk + ); + // Concatenate V_cache and V + compat::memcpy( + block_V_concat.get(), + block_V_cache.get() + offset_v_cache, + seq_len_kv_cache * num_heads_kv * head_size_vo + ); + compat::memcpy( + block_V_concat.get() + seq_len_kv_cache * num_heads_kv * head_size_vo, + block_V.get() + offset_v, + seq_len_kv * num_heads_kv * 
head_size_vo + ); + // compat::wait(); + } + k_ptr = block_K_concat.get(); + v_ptr = block_V_concat.get(); + } else { + k_ptr = block_K.get() + offset_k; + v_ptr = block_V.get() + offset_v; + } + + for (int q_group = 0; q_group < num_heads_q / q_group_size; q_group++) { + for (int q_head = 0; q_head < q_group_size; q_head++) { + cutlass::DeviceAllocation block_S; + block_S.reset(seq_len_qo * seq_len_kv_total); + + cutlass::TensorRef ref_Q(q_ptr, LayoutQ(num_heads_q * head_size_qk)); + cutlass::TensorRef ref_K(k_ptr, LayoutK(num_heads_kv * head_size_qk)); + cutlass::TensorRef ref_V(v_ptr, LayoutV(num_heads_kv * head_size_vo)); + cutlass::TensorRef ref_S(block_S.get(), LayoutQ::packed({seq_len_qo, seq_len_kv_total})); + + cutlass::reference::device::GemmComplex({seq_len_qo, seq_len_kv_total, head_size_qk}, ElementAccumulator{1}, ref_Q, + cutlass::ComplexTransform::kNone, ref_K, cutlass::ComplexTransform::kNone, + ElementAccumulator{0}, ref_S, ref_S, ElementAccumulator{0}, + 1, // batch_count + seq_len_qo * head_size_qk, // batch_stride_Q + seq_len_kv_total * head_size_qk, // batch_stride_K + seq_len_qo * seq_len_kv_total, // batch_stride_S + seq_len_qo * seq_len_kv_total // batch_stride_S + ); + compat::wait(); + std::vector host_S(block_S.size()); + compat::memcpy(host_S.data(), block_S.get(), host_S.size()); + + // delete this memory as it is no longer needed + block_S.reset(); + auto offset = cute::min(seq_len_qo, seq_len_kv); + auto discard_seq_coord = seq_len_qo - offset; + auto full_tile_offset = seq_len_kv - offset; + // apply mask to S + for (int row = 0; row < seq_len_qo; row++) { + for (int col = 0; col < seq_len_kv_total; col++) { + // causal mask + if (options.is_causal && (col - full_tile_offset > row + seq_len_kv_cache - discard_seq_coord)) { + host_S[col + row * seq_len_kv_total] = ElementAccumulator{-INFINITY}; + } + // sliding window mask + int col_ref = seq_len_kv_cache + seq_len_kv - seq_len_qo; + bool left_mask = col < cute::max(0, col_ref + row - options.window_left); + bool right_mask = col > cute::min(seq_len_kv_total, col_ref + row + options.window_right); + if (options.is_local_mask && (left_mask || right_mask)) { + host_S[col + row * seq_len_kv_total] = ElementAccumulator{-INFINITY}; + } + } + } + + // compute max element per row of S + std::vector max_vec(seq_len_qo, ElementAccumulator{-INFINITY}); + for (int row = 0; row < seq_len_qo; row++) { + int idx = row * seq_len_kv_total; + int max_idx = row; + max_vec[max_idx] = host_S[idx++]; + if (options.use_sink) { + ElementAccumulator sink_val = static_cast(host_Sink[q_group * q_group_size + q_head]); + max_vec[max_idx] = max(sink_val, max_vec[max_idx]); + } + for (int col = 1; col < seq_len_kv_total; col++, idx++) { + if (max_vec[max_idx] < host_S[idx]) + max_vec[max_idx] = host_S[idx]; + } + } + // compute exp of S + for (int row = 0; row < seq_len_qo; row++) { + int idx = row * seq_len_kv_total; + int max_idx = row; + for (int col = 0; col < seq_len_kv_total; col++, idx++) { + host_S[idx] = expf((host_S[idx] - max_vec[max_idx]) / options.softmax_scale); + } + } + + // compute sum per row of S + std::vector sum_vec(seq_len_qo, ElementAccumulator{0}); + for (int row = 0; row < seq_len_qo; row++) { + int idx = row * seq_len_kv_total; + int sum_idx = row; + for (int col = 0; col < seq_len_kv_total; col++, idx++) { + sum_vec[sum_idx] += host_S[idx]; + } + if (options.use_sink) { + ElementAccumulator sink_val = static_cast(host_Sink[q_group * q_group_size + q_head]); + auto exp_sink = expf(sink_val - max_vec[row]); + 
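+              // The attention sink only enters the softmax denominator: the
+              // row maximum already accounts for sink_val above, and adding
+              // exp(sink - max) here shrinks every probability without
+              // producing an extra output column.
+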
sum_vec[sum_idx] += exp_sink; + } + + // scale each row with the sum to compute softmax + idx = row * seq_len_kv_total; + sum_idx = row; + int col_ref = seq_len_kv_cache + seq_len_kv - seq_len_qo; + for (int col = 0; col < seq_len_kv_total; col++, idx++) { + if (options.is_causal && row < discard_seq_coord) { + host_S[idx] = 0; + } else if (options.is_local_mask && (col < cute::max(0, col_ref + row - options.window_left) + || col > cute::min(seq_len_kv_total, col_ref + row + options.window_right))) { + host_S[idx] = 0; + } else { + host_S[idx] /= sum_vec[sum_idx]; + } + } + } + std::vector host_P(host_S.size()); + for (int p = 0; p < host_P.size(); p++) + host_P[p] = static_cast(host_S[p]); + + cutlass::DeviceAllocation block_P; + block_P.reset(host_P.size()); + + compat::memcpy(block_P.get(), host_P.data(), host_P.size()); + + cutlass::TensorRef ref_P(block_P.get(), LayoutQ::packed({seq_len_qo, seq_len_kv_total})); + + cutlass::DeviceAllocation block_acc; + block_acc.reset(seq_len_qo * head_size_vo); + cutlass::TensorRef ref_acc(block_acc.get(), LayoutO::packed({seq_len_qo, head_size_vo})); + + cutlass::reference::device::GemmComplex({seq_len_qo, head_size_vo, seq_len_kv_total}, ElementAccumulator{1}, ref_P, + cutlass::ComplexTransform::kNone, ref_V, cutlass::ComplexTransform::kNone, + ElementAccumulator{0}, ref_acc, ref_acc, ElementAccumulator{0}, + 1, // batch_count + seq_len_qo * seq_len_kv_total, // batch_stride_P + seq_len_kv_total * head_size_vo, // batch_stride_V + seq_len_qo * head_size_vo, // batch_stride_O + seq_len_qo * head_size_vo // batch_stride_O + ); + + compat::wait(); + // delete this memory as it is no longer needed + block_P.reset(); + + std::vector vec_acc(block_acc.size()); + compat::memcpy(vec_acc.data(), block_acc.get(), vec_acc.size()); + + // delete this memory as it is no longer needed + block_acc.reset(); + for (int seq = 0; seq < seq_len_qo; seq++) { + for (int hvo = 0; hvo < head_size_vo; hvo++) { + int idx = offset_o + seq * num_heads_q * head_size_vo + (q_group * q_group_size + q_head) * head_size_vo + hvo; + host_O[idx] = static_cast(vec_acc[seq * head_size_vo + hvo]); + } + } + q_ptr += head_size_qk; + } // end of q_group loop + { + k_ptr += head_size_qk; + v_ptr += head_size_vo; + } + } // end of q_head loop + offset_q += seq_len_qo * num_heads_q * head_size_qk; + offset_k += seq_len_kv * num_heads_kv * head_size_qk; + offset_v += seq_len_kv * num_heads_kv * head_size_vo; + offset_k_cache += seq_len_kv_cache * num_heads_kv * head_size_qk; + offset_v_cache += seq_len_kv_cache * num_heads_kv * head_size_vo; + offset_o += seq_len_qo * num_heads_q * head_size_vo; + } // end of batch loop + + compat::wait(); + compat::memcpy(block_ref_O.get(), host_O.data(), host_O.size()); + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_O.get(), block_O.get(), + block_O.size(), ElementOutput{0.5}, ElementOutput{0.5}); + + return passed; + } + + template + auto initialize_varlen(const ProblemShape& problem_size) { + int num_batches = get<0>(problem_size); + int seq_len_kv_cache = get<5>(problem_size); + + // generate Q as --b times + // gaussian (--Q, --Q / 2) sampled positive + // track cumulative + std::mt19937 rng(0x202305151552ull); + std::normal_distribution dist_q(get<3>(problem_size), get<3>(problem_size) / 2); + std::normal_distribution dist_kv(get<4>(problem_size), get<4>(problem_size) / 2); + std::normal_distribution dist_kv_cache(get<5>(problem_size), 
get<5>(problem_size) / 2); + + // Use Cacheline Size to calculate alignment + constexpr int cacheline_bytes = 64; + constexpr int AlignmentQ = cacheline_bytes / sizeof(ElementQ); // Alignment of Q matrix in units of elements + constexpr int AlignmentKV = cacheline_bytes / sizeof(ElementK); // Alignment of Kand V matrix in units of elements + + auto generate_positive_int = [](auto& dist, auto& gen) { + int result = 0; + do { + result = static_cast(dist(gen)); + } while (result <= 0); + return result; + }; + + cumulative_seqlen_q = {0}; + cumulative_seqlen_kv = {0}; + cumulative_seqlen_kv_cache = {0}; + + int total_seqlen_q = 0; + int total_seqlen_kv = 0; + int total_seqlen_kv_cache = 0; + int max_seqlen_q = 0; + int max_seqlen_kv = 0; + int max_seqlen_kv_cache = 0; + + for (int i = 0; i < num_batches; i++) { + int seqlen_q = cutlass::round_up(generate_positive_int(dist_q, rng), AlignmentQ); + int seqlen_kv = cute::get<4>(problem_size) == 0 ? 0 : cutlass::round_up(generate_positive_int(dist_kv, rng), AlignmentKV); + int seqlen_kv_cache = cute::get<5>(problem_size) == 0 ? 0 : cutlass::round_up(generate_positive_int(dist_kv_cache, rng), AlignmentKV); + + total_seqlen_q += seqlen_q; + total_seqlen_kv += seqlen_kv; + total_seqlen_kv_cache += seqlen_kv_cache; + + max_seqlen_q = std::max(max_seqlen_q, seqlen_q); + max_seqlen_kv = std::max(max_seqlen_kv, seqlen_kv); + max_seqlen_kv_cache = std::max(max_seqlen_kv_cache, seqlen_kv_cache); + + cumulative_seqlen_q.push_back(cumulative_seqlen_q.back() + seqlen_q); + cumulative_seqlen_kv.push_back(cumulative_seqlen_kv.back() + seqlen_kv); + cumulative_seqlen_kv_cache.push_back(cumulative_seqlen_kv_cache.back() + seqlen_kv_cache); + } + + ProblemShape problem_size_for_init = problem_size; + get<0>(problem_size_for_init) = 1; + get<3>(problem_size_for_init) = total_seqlen_q; + get<4>(problem_size_for_init) = total_seqlen_kv; + get<5>(problem_size_for_init) = total_seqlen_kv_cache; + + ProblemShapeType problem_size_for_launch; + + get<3>(problem_size_for_launch) = cutlass::fmha::collective::VariableLength{max_seqlen_q, total_seqlen_q}; + get<4>(problem_size_for_launch) = cutlass::fmha::collective::VariableLength{max_seqlen_kv, total_seqlen_kv}; + get<5>(problem_size_for_launch) = cutlass::fmha::collective::VariableLength{max_seqlen_kv_cache, total_seqlen_kv_cache}; + get<6>(problem_size_for_launch) = get<6>(problem_size); + get<7>(problem_size_for_launch) = get<7>(problem_size); + get<0>(problem_size_for_launch) = get<0>(problem_size); + get<1>(problem_size_for_launch) = get<1>(problem_size); + get<2>(problem_size_for_launch) = get<2>(problem_size); + + + return cute::make_tuple(problem_size_for_init, problem_size_for_launch); + } + + /// Initialize operands to be used in the GEMM and reference GEMM + ProblemShapeType initialize(const Options &options) { + auto problem_shape_in = + cute::make_tuple(options.batch, options.num_heads_q, options.num_heads_kv, options.seq_len_qo, options.seq_len_kv, options.seq_len_kv_cache, options.head_size_qk, options.head_size_vo); + + ProblemShapeType problem_shape; + decltype(problem_shape_in) problem_size; + + if constexpr (isVarLen) { + auto [problem_shape_init, problem_shape_launch] = initialize_varlen(problem_shape_in); + problem_shape = problem_shape_launch; + problem_size = problem_shape_init; + } + else { + problem_size = problem_shape_in; + problem_shape = problem_shape_in; + } + + auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache, head_size_qk, head_size_vo] = problem_size; + + 
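+    // For the variable-length path, problem_size here is the "init" shape with
+    // batch collapsed to 1 and the sequence-length modes replaced by the total
+    // lengths across all batches, so the allocations and packed strides below
+    // cover the concatenated sequences; the per-batch boundaries are carried
+    // separately via the cumulative_seqlen_* arrays.
+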
stride_Q = cutlass::make_cute_packed_stride(StrideQ{}, cute::make_shape(seq_len_qo, num_heads_q * head_size_qk, batch)); + stride_K = cutlass::make_cute_packed_stride(StrideK{}, cute::make_shape(seq_len_kv, num_heads_kv * head_size_qk, batch)); + stride_V = cutlass::make_cute_packed_stride(StrideV{}, cute::make_shape(head_size_vo * num_heads_kv, seq_len_kv, batch)); + + stride_K_cache = cutlass::make_cute_packed_stride(StrideK{}, cute::make_shape(seq_len_kv_cache, num_heads_kv * head_size_qk, batch)); + stride_V_cache = cutlass::make_cute_packed_stride(StrideV{}, cute::make_shape(head_size_vo * num_heads_kv, seq_len_kv_cache, batch)); + stride_O = cutlass::make_cute_packed_stride(StrideO{}, cute::make_shape(seq_len_qo, num_heads_q * head_size_vo, batch)); + + block_Q.reset(batch * num_heads_q * seq_len_qo * head_size_qk); + block_K.reset(batch * num_heads_kv * seq_len_kv * head_size_qk); + block_V.reset(batch * num_heads_kv * seq_len_kv * head_size_vo); + if (options.use_sink) { + block_Sink.reset(num_heads_q); + } + if (!options.use_paged_kv) { + block_K_cache.reset(batch * num_heads_kv * seq_len_kv_cache * head_size_qk); + block_V_cache.reset(batch * num_heads_kv * seq_len_kv_cache * head_size_vo); + } + block_O.reset(batch * num_heads_q * seq_len_qo * head_size_vo); + block_ref_O.reset(batch * num_heads_q * seq_len_qo * head_size_vo); + + if (options.use_paged_kv) { + paged_kv_cache.page_size = options.page_size; + std::vector num_pages_per_seq{0}; + int num_pages = 0; + for(int b = 0; b < cute::get<0>(problem_shape); b++) { + int seq_len_cache = isVarLen ? cumulative_seqlen_kv_cache[b + 1] - cumulative_seqlen_kv_cache[b] : seq_len_kv_cache; + int pages_per_seq = ceil_div(seq_len_cache, paged_kv_cache.page_size); + num_pages_per_seq.push_back(num_pages_per_seq.back() + pages_per_seq); + num_pages += pages_per_seq; + } + paged_kv_cache.page_table.reset(num_pages); + + + // initialize block table with random mapping for non-contiguous layout + std::vector page_mapping(num_pages); + for (int b = 0; b < cute::get<0>(problem_shape); ++b) { + std::vector physical_pages(num_pages_per_seq[b + 1] - num_pages_per_seq[b]); + std::iota(physical_pages.begin(), physical_pages.end(), 0); + // shuffle physical pages + std::shuffle(physical_pages.begin(), physical_pages.end(), std::mt19937{ std::random_device{}() }); + for (int blk = 0; blk < physical_pages.size(); ++blk) { + int logical_idx = num_pages_per_seq[b] + blk; + page_mapping[logical_idx] = physical_pages[blk]; + } + } + compat::memcpy(paged_kv_cache.page_table.get(), page_mapping.data(), page_mapping.size() * sizeof(int)); + + paged_kv_cache.num_pages_per_seq.reset(num_pages_per_seq.size()); + compat::memcpy(paged_kv_cache.num_pages_per_seq.get(), num_pages_per_seq.data(), num_pages_per_seq.size() * sizeof(int)); + + block_K_cache.reset(num_pages * paged_kv_cache.page_size * num_heads_kv * head_size_qk); + block_V_cache.reset(num_pages * paged_kv_cache.page_size * num_heads_kv * head_size_vo); + } + // std::vector host_Q(block_Q.size()); + // std::vector host_K(block_K.size()); + // std::vector host_V(block_V.size()); + // std::vector host_K_cache(block_K_cache.size()); + // std::vector host_V_cache(block_V_cache.size()); + // std::vector host_Sink(block_Sink.size()); + + // for (size_t i = 0; i < host_Q.size(); i++) { + // host_Q[i] = static_cast(1.f);//static_cast(rand_r(&seed) % 255 / static_cast(255)); + // } + // for (size_t i = 0; i < host_K.size(); i++) { + // host_K[i] = static_cast(1.f);//static_cast(rand_r(&seed + 1) % 255 / 
+    initialize_block(block_Q, seed + 2023);
+    initialize_block(block_K, seed + 2022);
+    initialize_block(block_V, seed + 2021);
+    initialize_block(block_Sink, seed + 2021);
+    initialize_block(block_K_cache, seed + 2024);
+    initialize_block(block_V_cache, seed + 2025);
+
+    if (!cumulative_seqlen_q.empty()) {
+      device_cumulative_seqlen_q.reset(cumulative_seqlen_q.size());
+      device_cumulative_seqlen_q.copy_from_host(
+          cumulative_seqlen_q.data(), cumulative_seqlen_q.size());
+    }
+
+    if (!cumulative_seqlen_kv.empty()) {
+      device_cumulative_seqlen_kv.reset(cumulative_seqlen_kv.size());
+      device_cumulative_seqlen_kv.copy_from_host(
+          cumulative_seqlen_kv.data(), cumulative_seqlen_kv.size());
+    }
+
+    if (!cumulative_seqlen_kv_cache.empty()) {
+      device_cumulative_seqlen_kv_cache.reset(cumulative_seqlen_kv_cache.size());
+      device_cumulative_seqlen_kv_cache.copy_from_host(
+          cumulative_seqlen_kv_cache.data(), cumulative_seqlen_kv_cache.size());
+    }
+
+    if constexpr (isVarLen) {
+      // max_length and total_length were already populated when the VariableLength entries were
+      // built in initialize_varlen(); only the device-side cumulative-length pointers are attached here.
+      get<3>(problem_shape).cumulative_length = device_cumulative_seqlen_q.get();
+      get<4>(problem_shape).cumulative_length = device_cumulative_seqlen_kv.get();
+      get<5>(problem_shape).cumulative_length = device_cumulative_seqlen_kv_cache.get();
+    }
+
+    return problem_shape;
+  }
+
+  // Note that the GemmUniversalAdapter currently doesn't support flash attention, which is why this
+  // secondary `run` function is required to launch the kernel.
+ static void run(typename FMHAChunkPrefillKernel::Params params) { + dim3 const block = FMHAChunkPrefillKernel::get_block_shape(); + dim3 const grid = FMHAChunkPrefillKernel::get_grid_shape(params); + + // configure smem size and carveout + int smem_size = FMHAChunkPrefillKernel::SharedStorageSize; + + const auto sycl_block = compat::dim3(block.x, block.y, block.z); + const auto sycl_grid = compat::dim3(grid.x, grid.y, grid.z); + +// Launch parameters depend on whether SYCL compiler supports work-group scratch memory extension +#if !defined(SYCL_EXT_ONEAPI_WORK_GROUP_SCRATCH_MEMORY) + using namespace compat::experimental; + auto event = launch>( + launch_policy{sycl_grid, sycl_block, local_mem_size{static_cast(smem_size)}, + kernel_properties{sycl_exp::sub_group_size}}, + params); +#else + compat::experimental::launch_properties launch_props { + sycl::ext::oneapi::experimental::work_group_scratch_size(smem_size), + }; + compat::experimental::kernel_properties kernel_props{ + sycl::ext::oneapi::experimental::sub_group_size + }; + compat::experimental::launch_policy policy{sycl_grid, sycl_block, launch_props, kernel_props}; + auto event = compat::experimental::launch, FMHAChunkPrefillKernel>(policy, params); +#endif + + EventManager::getInstance().addEvent(event); + } + + cutlass::Status run(const Options &options, const cutlass::KernelHardwareInfo &hw_info) { + + ProblemShapeType problem_size = initialize(options); + + typename FMHAChunkPrefillKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {block_Q.get(), stride_Q, + block_K.get(), stride_K, + block_V.get(), stride_V, + block_K_cache.get(), stride_K_cache, + block_V_cache.get(), stride_V_cache, + options.use_paged_kv ? paged_kv_cache.page_table.get() : nullptr, + options.use_paged_kv ? paged_kv_cache.page_size : 0, + options.use_paged_kv ? paged_kv_cache.num_pages_per_seq.get() : nullptr, + options.window_left, + options.window_right}, + {options.softmax_scale}, + {block_O.get(), stride_O, block_Sink.get()}, + hw_info}; + + // Define device-global scratch memory + size_t workspace_size = FMHAChunkPrefillKernel::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + if (!FMHAChunkPrefillKernel::can_implement(arguments)) { + std::cout << "Invalid Problem Size: " << options.batch << 'x' << options.num_heads_q << 'x' << + options.seq_len_qo << 'x' << options.seq_len_kv << 'x' << options.head_size_qk << 'x' << options.head_size_vo + << (options.is_causal ? "xCausal" : "xNonCausal") << (options.is_local_mask ? "xLocalMask" : "xNonLocalMask") << std::endl; + return cutlass::Status::kErrorInvalidProblem; + } + + // Initialize the workspace + CUTLASS_CHECK(FMHAChunkPrefillKernel::initialize_workspace(arguments, workspace.get())); + + // Convert host-side arguments to device-side arguments to be passed to the kernel + auto params = FMHAChunkPrefillKernel::to_underlying_arguments(arguments, workspace.get()); + + // Run the Flash Attention implementation. + run(params); + + compat::wait(); + + // Verify that the result is correct + bool passed = verify(problem_size, options); + std::cout << "Disposition: " << (passed ? 
"Passed" : "Failed") << std::endl; + + if (!passed) { + return cutlass::Status::kErrorInternal; + } + + if (options.iterations > 0) { + GPU_Clock timer; + timer.start(); + for (int i = 0; i < options.iterations; ++i) { + run(params); + } + compat::wait(); + + auto offset = cute::min(options.seq_len_qo, options.seq_len_kv); + auto discard_seq_coord = options.seq_len_qo - offset; + auto full_tile_offset = options.seq_len_kv - offset; + // offset + 1 is going to be ceil_div + auto effective_seq_len_kv = options.seq_len_kv_cache + (options.is_causal ? full_tile_offset + ((offset + 1) / 2.0) : + options.is_local_mask ? (options.window_left + options.window_right) + : options.seq_len_kv); + auto effective_seq_len_qo = options.is_causal ? options.seq_len_qo - discard_seq_coord : options.seq_len_qo; + double cute_time = timer.seconds() / options.iterations; + double flops_qk = 2.0 * options.batch * options.num_heads_q * effective_seq_len_qo * effective_seq_len_kv * options.head_size_qk; + double flops_pv = 2.0 * options.batch * options.num_heads_q * effective_seq_len_qo * options.head_size_vo * effective_seq_len_kv; + double tflops = ((flops_qk + flops_pv) * 1e-12) / cute_time; + double gbps_qk = options.batch * (sizeof(ElementQ) * options.num_heads_q * effective_seq_len_qo * options.head_size_qk + + sizeof(ElementK) * options.num_heads_kv * effective_seq_len_kv * options.head_size_qk); + double gbps_pv = sizeof(ElementV) * options.batch * options.num_heads_kv * effective_seq_len_kv * options.head_size_vo + + sizeof(ElementOutput) * options.batch * options.num_heads_q * effective_seq_len_qo * options.head_size_vo; + double gbps = ((gbps_qk + gbps_pv) * 1e-9) / (cute_time); + std::cout << "Batch: " << options.batch << "\tNumHeads_q: " << options.num_heads_q << "\tNumHeads_kv: " << options.num_heads_kv << "\tSeq Length QO: " << options.seq_len_qo + << "\tSeq Length KV: " << options.seq_len_kv << "\tSeq Length KV Cache: " << options.seq_len_kv_cache + << "\tHead Size QK: " << options.head_size_qk << "\tHead Size VO: " << options.head_size_vo + << "\tCausal Mask: " << (options.is_causal ? "true" : "false") << "\tVariable Sequence Length: " << (options.varlen ? "true" : "false") + << "\t Scheduler: " << options.scheduler << "\t Paged KV cache: " << (options.use_paged_kv ? "true" : "false"); + printf("\nPerformance: %4.3f GB/s, %4.3f TFlop/s, %6.4f ms\n\n", gbps, tflops, cute_time * 1000); + } + + return cutlass::Status::kSuccess; + } +}; + +// the default value used for the case BF16 +template struct FMHAConfig { + + template + static int run(const Options &options) { + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This + // information is used by the underlying kernel. 
+ cutlass::KernelHardwareInfo hw_info; + + using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; + using CollectiveEpilogue = cutlass::flash_attention::collective::FlashChunkPrefillEpilogue< + Sink, EpilogueDispatchPolicy, MMAOperation, TileShapeOutput, SubgroupLayout, ElementComputeEpilogue, ElementOutput, cutlass::gemm::TagToStrideC_t, ElementOutput, ElementSink, GmemTiledCopyStore>; + using CollectiveSoftmaxEpilogue = cutlass::flash_attention::collective::FlashChunkPrefillSoftmaxEpilogue; + + using ProblemShapeRegular = cute::tuple; + using namespace cutlass::fmha::collective; + using ProblemShapeVarlen = cute::tuple; + using ProblemShapeType = std::conditional_t; + + // Mainloop + using CollectiveMainloop = cutlass::flash_attention::collective::FlashChunkPrefillMma< + GEMMDispatchPolicy, ProblemShapeType, ElementInputQ, cutlass::gemm::TagToStrideA_t, ElementInputKV, + cutlass::gemm::TagToStrideB_t, ElementInputKV, cutlass::gemm::TagToStrideB_t, MMAOperation, TileShapeQK, TileShapePV, SubgroupLayout, + GmemTiledCopyQ, // Q + GmemTiledCopyK, // K + GmemTiledCopyV, // V, + Causal, + LocalMask, + PagedKV>; + + using FMHAChunkPrefillKernel = cutlass::flash_attention::kernel::FMHAPrefillChunk; + + ExampleRunner runner; + + CUTLASS_CHECK(runner.run(options, hw_info)); + return 0; + } + + static int run(const Options &options) { + if (options.varlen) { + if (options.use_paged_kv) { + if (options.use_sink) { + return run(options); + } else { + return run(options); + } + } else { // not paged kv + if (options.use_sink) { + return run(options); + } else { + return run(options); + } + } + } else { // not varlen + if (options.use_paged_kv) { + if (options.use_sink) { + return run(options); + } else { + return run(options); + } + } else { // not paged kv + if (options.use_sink) { + return run(options); + } else { + return run(options); + } + } + } + } +}; diff --git a/test/unit/flash_attention/flash_attention_decode/flash_decode_testbed_3x.hpp b/test/unit/flash_attention/flash_attention_decode/flash_decode_testbed_3x.hpp index 30a09d7a69..a837f86b85 100644 --- a/test/unit/flash_attention/flash_attention_decode/flash_decode_testbed_3x.hpp +++ b/test/unit/flash_attention/flash_attention_decode/flash_decode_testbed_3x.hpp @@ -411,9 +411,9 @@ struct TestbedImpl { int max_seq_len_q = static_cast(cute::get<3>(problem_size)); int max_seq_len_kv = static_cast(cute::get<4>(problem_size)); int max_seq_len_kv_cache = static_cast(cute::get<5>(problem_size)); - cute::get<3>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_q, cumulative_seqlen_q.data()}; - cute::get<4>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv, cumulative_seqlen_kv.data()}; - cute::get<5>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv_cache, cumulative_seqlen_kv_cache.data()}; + cute::get<3>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_q, 0, cumulative_seqlen_q.data()}; + cute::get<4>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv, 0, cumulative_seqlen_kv.data()}; + cute::get<5>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv_cache, 0, cumulative_seqlen_kv_cache.data()}; } auto [batch, num_heads_q, num_heads_kv, head_size_qk, head_size_vo] = cute::select<0,1,2,6,7>(problem_size); diff --git a/test/unit/flash_attention/flash_attention_prefill/flash_prefill_testbed_3x.hpp 
b/test/unit/flash_attention/flash_attention_prefill/flash_prefill_testbed_3x.hpp
index ece31f6f7a..13ca1ee3dc 100644
--- a/test/unit/flash_attention/flash_attention_prefill/flash_prefill_testbed_3x.hpp
+++ b/test/unit/flash_attention/flash_attention_prefill/flash_prefill_testbed_3x.hpp
@@ -377,8 +377,8 @@ struct TestbedImpl {
   if constexpr (isVarLen) {
     int max_seq_len_q = static_cast<int>(cute::get<3>(problem_size));
     int max_seq_len_kv = static_cast<int>(cute::get<4>(problem_size));
-    cute::get<3>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_q, cumulative_seqlen_q.data()};
-    cute::get<4>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv, cumulative_seqlen_kv.data()};
+    cute::get<3>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_q, 0, cumulative_seqlen_q.data()};
+    cute::get<4>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv, 0, cumulative_seqlen_kv.data()};
   }

   auto [batch, num_heads_q, num_heads_kv, head_size_qk, head_size_vo] = cute::select<0,1,2,5,6>(problem_size);
diff --git a/test/unit/flash_attention/flash_attention_prefill_cachedkv/flash_prefill_cachedkv_testbed_3x.hpp b/test/unit/flash_attention/flash_attention_prefill_cachedkv/flash_prefill_cachedkv_testbed_3x.hpp
index b758d1b8fd..9a70e379bc 100644
--- a/test/unit/flash_attention/flash_attention_prefill_cachedkv/flash_prefill_cachedkv_testbed_3x.hpp
+++ b/test/unit/flash_attention/flash_attention_prefill_cachedkv/flash_prefill_cachedkv_testbed_3x.hpp
@@ -362,9 +362,9 @@ struct TestbedImpl {
     int max_seq_len_q = static_cast<int>(cute::get<3>(problem_size));
     int max_seq_len_kv = static_cast<int>(cute::get<4>(problem_size));
     int max_seq_len_kv_cache = static_cast<int>(cute::get<5>(problem_size));
-    cute::get<3>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_q, cumulative_seqlen_q.data()};
-    cute::get<4>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv, cumulative_seqlen_kv.data()};
-    cute::get<5>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv_cache, cumulative_seqlen_kv_cache.data()};
+    cute::get<3>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_q, 0, cumulative_seqlen_q.data()};
+    cute::get<4>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv, 0, cumulative_seqlen_kv.data()};
+    cute::get<5>(problem_size) = cutlass::fmha::collective::VariableLength{max_seq_len_kv_cache, 0, cumulative_seqlen_kv_cache.data()};
   }

   auto [batch, num_heads_q, num_heads_kv, head_size_qk, head_size_vo] = cute::select<0,1,2,6,7>(problem_size);
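For reference, the throughput numbers printed by the benchmark loop in the chunk-prefill example above come from a masked-work model: under a causal mask only about half of the KV positions are scored per query row, which is what the (offset + 1) / 2.0 term accounts for. A stand-alone reproduction of that arithmetic with made-up sizes and an assumed 1 ms runtime (the local-mask branch is omitted for brevity):

#include <algorithm>
#include <cstdio>

int main() {
  // Illustrative sizes: batch 4, 16 query heads, 1024x1024 causal, no KV cache, head size 128.
  int batch = 4, num_heads_q = 16, seq_len_qo = 1024, seq_len_kv = 1024, seq_len_kv_cache = 0;
  int head_size_qk = 128, head_size_vo = 128;
  bool is_causal = true;
  double cute_time = 1e-3;  // seconds per iteration, assumed for the sake of the arithmetic

  int offset = std::min(seq_len_qo, seq_len_kv);
  int discard_seq_coord = seq_len_qo - offset;
  int full_tile_offset = seq_len_kv - offset;
  // (offset + 1) / 2.0 is the average number of visible KV positions per query row under causal masking.
  double effective_seq_len_kv = seq_len_kv_cache + (is_causal ? full_tile_offset + (offset + 1) / 2.0 : seq_len_kv);
  double effective_seq_len_qo = is_causal ? seq_len_qo - discard_seq_coord : seq_len_qo;

  double flops_qk = 2.0 * batch * num_heads_q * effective_seq_len_qo * effective_seq_len_kv * head_size_qk;
  double flops_pv = 2.0 * batch * num_heads_q * effective_seq_len_qo * head_size_vo * effective_seq_len_kv;
  double tflops = ((flops_qk + flops_pv) * 1e-12) / cute_time;  // ~17.2 TFLOP/s for these numbers
  std::printf("QK: %.3e FLOP, PV: %.3e FLOP, %.1f TFLOP/s at %.1f ms\n",
              flops_qk, flops_pv, tflops, cute_time * 1e3);
  return 0;
}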