diff --git a/src/sycl/chunked_prefill.cpp b/src/sycl/chunked_prefill.cpp
index d29733f..a21e006 100644
--- a/src/sycl/chunked_prefill.cpp
+++ b/src/sycl/chunked_prefill.cpp
@@ -359,6 +359,7 @@ template <
     typename TileShapeOutput,
     typename SubgroupLayout,
     int PipelineStages,
+    bool RopeEmbedding = false,
     bool LocalMask = false,
     typename ElementInputQ = bfloat16_t,
     typename ElementInputKV = bfloat16_t,
@@ -416,7 +417,8 @@ struct FMHAConfig {
       GmemTiledCopyV, // V,
       Causal,
       LocalMask,
-      PagedKV>;
+      PagedKV,
+      RopeEmbedding>;
 
   using FMHAChunkPrefillKernel = cutlass::flash_attention::kernel::FMHAPrefillChunk<
       ProblemShapeType,
@@ -796,7 +798,7 @@ std::vector mha_fwd(
           cute::Shape<_128, _32, _64>,
           cute::Shape<_128, _64, _64>,
           cute::Layout, cute::Stride<_1, _1, _1>>,
-          PipelineStages>::run(params);
+          PipelineStages, true>::run(params);
       break;
     case 96:
       FMHAConfig<
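The hunks above thread a new compile-time `RopeEmbedding` flag from `FMHAConfig` down into the collective mainloop, and the head-size-64 dispatch is currently the only case instantiated with `true`. Because the flag is a template parameter, switching RoPE on or off at run time would require instantiating both variants and branching on the host. A minimal stand-in sketch of that pattern is below; `FMHAConfigSketch`, `Params`, and the `use_rope` flag are placeholders for illustration, not types or fields from this patch.

```cpp
// Self-contained sketch (stand-in types, not the real FMHAConfig) showing the
// compile-time/run-time split the new RopeEmbedding parameter implies: both
// variants are instantiated, and a host-side flag picks one at run time.
#include <cstdio>

struct Params { /* stand-in for the real kernel parameters */ };

template <int PipelineStages, bool RopeEmbedding = false>
struct FMHAConfigSketch {                       // stand-in for FMHAConfig
  static void run(const Params&) {
    std::printf("PipelineStages=%d, RopeEmbedding=%d\n", PipelineStages, RopeEmbedding);
  }
};

void dispatch(bool use_rope, const Params& p) {  // `use_rope` is hypothetical
  if (use_rope) FMHAConfigSketch<2, /*RopeEmbedding=*/true>::run(p);
  else          FMHAConfigSketch<2, /*RopeEmbedding=*/false>::run(p);
}

int main() { dispatch(true, Params{}); }
```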
diff --git a/src/sycl/kernels/chunk_prefill/xe_chunk_prefill.hpp b/src/sycl/kernels/chunk_prefill/xe_chunk_prefill.hpp
index 1384bca..1e60b44 100644
--- a/src/sycl/kernels/chunk_prefill/xe_chunk_prefill.hpp
+++ b/src/sycl/kernels/chunk_prefill/xe_chunk_prefill.hpp
@@ -35,6 +35,8 @@
 #include "cutlass/gemm/gemm.h"
 #include "cutlass/kernel_hardware_info.hpp"
 #include "xe_flash_attn_chunk_prefill_mma.hpp"
+#define THREAD_ID 0
+#define BLOCK_ID 0
 
 namespace cutlass::flash_attention::kernel {
 
@@ -82,6 +84,8 @@ class FMHAPrefillChunk {
   using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator;
   using MainloopArguments = typename CollectiveMainloop::Arguments;
   using MainloopParams = typename CollectiveMainloop::Params;
+  using traits_load_Q = typename CollectiveMainloop::traits_load_Q;
+  using traits_load_K = typename CollectiveMainloop::traits_load_K;
 
   using CollectiveSoftmaxEpilogue = CollectiveSoftmaxEpilogue_;
   using SoftmaxArguments = typename CollectiveSoftmaxEpilogue::Arguments;
@@ -150,6 +154,11 @@ class FMHAPrefillChunk {
       decltype(make_shape(Int{}, Int{}, get<1>(TileShapePV{}) / get<1>(MmaAtomShape()), Int{}));
 
   static constexpr bool is_var_len = CollectiveMainloop::is_var_len;
+  static constexpr bool rope_enabled = CollectiveMainloop::rope_enabled;
+
+  template <typename T>
+  static constexpr bool is_fp8_v = cute::is_same_v<T, cutlass::float_e4m3_t> || cute::is_same_v<T, cutlass::float_e5m2_t>;
+
   // Kernel level shared memory storage
   struct SharedStorage {
     using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage;
@@ -241,6 +250,7 @@ class FMHAPrefillChunk {
     auto batch = get<0>(params.problem_shape);
     auto num_heads_q = get<1>(params.problem_shape);
     auto num_heads_kv = get<2>(params.problem_shape);
+    auto seq_len_kv = get<4>(params.problem_shape);
     auto& head_size_qk = get<6>(params.problem_shape);
     auto& head_size_vo = get<7>(params.problem_shape);
 
@@ -328,24 +338,162 @@ class FMHAPrefillChunk {
     int tiles_per_page = params.mainloop.page_size / QK_BLK_N;
 
     Tensor mQ_mkl = cute::get_xe_tensor(make_shape(seq_len_qo, head_size_qk, 1)); //(m,k,l)
+    Tensor mK_nkl = cute::get_xe_tensor(make_shape(seq_len_kv, head_size_qk, 1)); //(n,k,l)
     Tensor mK_cache_nkl = cute::get_xe_tensor(make_shape(seq_len_kv_cache, head_size_qk, 1)); // (n_cache,k,l)
     Tensor mV_cache_nkl = cute::get_xe_tensor(make_shape(head_size_vo, seq_len_kv_cache, 1)); // (n_cache,k,l)
+    Tensor mCosQ_mkl = cute::get_xe_tensor(
+        make_shape(seq_len_qo, head_size_qk, 1)); // (m, k, l)
+    Tensor mSinQ_mkl = cute::get_xe_tensor(
+        make_shape(seq_len_qo, head_size_qk, 1)); // (m, k, l)
+    Tensor mCosK_nkl = cute::get_xe_tensor(
+        make_shape(seq_len_kv, head_size_qk, 1)); // (n, k, l)
+    Tensor mSinK_nkl = cute::get_xe_tensor(
+        make_shape(seq_len_kv, head_size_qk, 1)); // (n, k, l)
 
+    // block_size and head_size are the same size, so no coord is needed.
     Tensor mQ_mk = mQ_mkl(_, _, 0);
+    Tensor mK_nk = mK_nkl(_, _, 0); // (n,k)
     Tensor mK_cache_nk = mK_cache_nkl(_, _, 0); // (n_cache, k)
     Tensor mV_cache_nk = mV_cache_nkl(_, _, 0); // (n_cache, k)
+    Tensor mCosQ_mk = mCosQ_mkl(_, _, 0); // (m,k)
+    Tensor mSinQ_mk = mSinQ_mkl(_, _, 0); // (m,k)
+    Tensor mCosK_nk = mCosK_nkl(_, _, 0); // (n,k)
+    Tensor mSinK_nk = mSinK_nkl(_, _, 0);
 
     auto gQ = local_tile(mQ_mk, TileShapeQK{}, make_coord(blk_m_coord, _, _), Step<_1, X, _1>{});
+    auto gK = local_tile(mK_nk, TileShapeQK{}, make_coord(_, _, _), Step<X, _1, _1>{});
     auto gK_cache = local_tile(mK_cache_nk, TileShapeQK{}, make_coord(_, _, _), Step<X, _1, _1>{});
     auto gV_cache = local_tile(mV_cache_nk, TileShapeOutput{}, make_coord(_, blk_n_coord, _), Step<X, _1, _1>{});
+    auto gCosQ = local_tile(mCosQ_mk, TileShapeQK{},
+                            make_coord(blk_m_coord, _, _), Step<_1, X, _1>{});
+    auto gSinQ = local_tile(mSinQ_mk, TileShapeQK{},
+                            make_coord(blk_m_coord, _, _), Step<_1, X, _1>{});
+    auto gCosK = local_tile(mCosK_nk, TileShapeQK{},
+                            make_coord(_, _, _), Step<X, _1, _1>{});
+    auto gSinK = local_tile(mSinK_nk, TileShapeQK{},
+                            make_coord(_, _, _), Step<X, _1, _1>{});
 
     auto mainloop_params = CollectiveMainloop::get_updated_copies(
         params.mainloop, params.problem_shape, sequence_length_shape, batch_coord, q_head_coord);
 
+    // Currently RoPE is not supported for fp8.
+    if constexpr (rope_enabled && !is_fp8_v<ElementQ>) {
+      if (cute::thread(THREAD_ID, BLOCK_ID)) {
+        print("inside rope in kernel\n");
+      }
+      int block_idx = static_cast<int>(BlockIdxX());
+      int block_idy = static_cast<int>(BlockIdxY());
+      int block_idz = static_cast<int>(BlockIdxZ());
+      int block_dimx = static_cast<int>(BlockDimX());
+      int block_dimy = static_cast<int>(BlockDimY());
+      int block_dimz = static_cast<int>(BlockDimZ());
+      int thread_idx = static_cast<int>(ThreadIdxX());
+      int thread_idy = static_cast<int>(ThreadIdxY());
+      int thread_idz = static_cast<int>(ThreadIdxZ());
+      int grid_dimx = static_cast<int>(GridDimX());
+      int grid_dimy = static_cast<int>(GridDimY());
+      int grid_dimz = static_cast<int>(GridDimZ());
+      int block_id = block_idx + block_idy * grid_dimx + block_idz * grid_dimx * grid_dimy;
+      int thread_id = block_id * block_dimx * block_dimy * block_dimz + thread_idz * block_dimx * block_dimy + thread_idy * block_dimx + thread_idx;
+
+
+      // calculate the base_ptr and offset for Q, K.
+      // also calculate the layout for Q, K.
+      // then apply RoPE on Q and K accordingly.
+      auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache, head_size_qk, head_size_vo] = params.problem_shape;
+
+      int offset_q = num_heads_q * head_size_qk * seq_len_qo * batch_coord + // Jump to the correct batch
+                     q_head_coord * head_size_qk +                          // Jump to the correct head
+                     (blk_m_coord * QK_BLK_M * head_size_qk);               // Jump to the correct seq_len_qo block
+
+      auto q_group_size = num_heads_q / num_heads_kv;
+      auto kv_head_coord = q_head_coord / q_group_size;
+      int offset_k = num_heads_kv * head_size_qk * seq_len_kv * batch_coord +
+                     kv_head_coord * head_size_qk;
+
+      // calculate Q/cosQ/sinQ ptr
+      auto q_traits = static_cast<traits_load_Q>(mainloop_params.gmem_tiled_copy_q);
+      ElementQ* base_ptr_q = (ElementQ*)q_traits.base_ptr;
+
+      auto q_traits_cos = static_cast<traits_load_Q>(mainloop_params.gmem_tiled_copy_q_cos);
+      ElementQ* base_ptr_q_cos = (ElementQ*)q_traits_cos.base_ptr;
+
+      auto q_traits_sin = static_cast<traits_load_Q>(mainloop_params.gmem_tiled_copy_q_sin);
+      ElementQ* base_ptr_q_sin = (ElementQ*)q_traits_sin.base_ptr;
+
+      auto static_shape_q = make_shape(size<0>(gQ), size<1>(gQ) * size<2>(gQ));
+      int s = head_size_qk * num_heads_q;
+      auto stride_q = make_stride(s, Int<1>{});
+      auto layout_q = make_layout(static_shape_q, stride_q);
+
+      // calculate K/cosK/sinK ptr
+      auto k_traits = static_cast<traits_load_K>(mainloop_params.gmem_tiled_copy_k);
+      ElementK* base_ptr_k = (ElementK*)k_traits.base_ptr;
+
+      auto k_traits_cos = static_cast<traits_load_K>(mainloop_params.gmem_tiled_copy_k_cos);
+      ElementK* base_ptr_k_cos = (ElementK*)k_traits_cos.base_ptr;
+
+      auto k_traits_sin = static_cast<traits_load_K>(mainloop_params.gmem_tiled_copy_k_sin);
+      ElementK* base_ptr_k_sin = (ElementK*)k_traits_sin.base_ptr;
+
+      auto static_shape_k = make_shape(size<0>(gK), size<1>(gK) * size<3>(gK));
+      auto layout_k = make_layout(static_shape_k, LayoutRight{});
+      auto gK_dim3 = size<3>(gK);
+
+      // calculating RoPE for Q
+      auto tensorQ = make_tensor(make_gmem_ptr(base_ptr_q + offset_q), layout_q);
+      auto tensorCosQ = make_tensor(make_gmem_ptr(base_ptr_q_cos + offset_q), layout_q);
+      auto tensorSinQ = make_tensor(make_gmem_ptr(base_ptr_q_sin + offset_q), layout_q);
+      cutlass::flash_attention::collective::apply_rope_interleaved_gmem(thread_idx, tensorQ, tensorCosQ, tensorSinQ, tensorQ);
+
+      // calculating RoPE for K:
+      // need to consider the case when there are multiple blocks in the y direction;
+      // each block in the y direction will handle a different set of K,
+      // so the base pointer of K needs to be adjusted accordingly.
+      if (grid_dimx == 4) {
+        if (block_id % 4 == 1) {
+          offset_k += QK_BLK_N * QK_BLK_K * gK_dim3;
+        } else if (block_id % 4 == 2) {
+          offset_k += 2 * QK_BLK_N * QK_BLK_K * gK_dim3;
+        } else if (block_id % 4 == 3) {
+          offset_k += 3 * QK_BLK_N * QK_BLK_K * gK_dim3;
+        }
+
+        auto new_offset_k = offset_k;
+        for (int i = 0; i < size<2>(gK); i += 4) {
+          auto tensorK = make_tensor(make_gmem_ptr(base_ptr_k + new_offset_k), layout_k);
+          auto tensorCosK = make_tensor(make_gmem_ptr(base_ptr_k_cos + new_offset_k), layout_k);
+          auto tensorSinK = make_tensor(make_gmem_ptr(base_ptr_k_sin + new_offset_k), layout_k);
+          // fix next
+          // cutlass::flash_attention::collective::apply_rope_interleaved_gmem(thread_idx, tensorK, tensorCosK, tensorSinK, tensorK);
+          new_offset_k += 4 * QK_BLK_N * QK_BLK_K * gK_dim3;
+        }
+      } else if (grid_dimx == 2) {
+        if (block_id % 2 == 1) {
+          offset_k += QK_BLK_N * QK_BLK_K * gK_dim3;
+        }
+        auto new_offset_k = offset_k;
+        for (int i = 0; i < size<2>(gK); i += 2) {
+          auto tensorK = make_tensor(make_gmem_ptr(base_ptr_k + new_offset_k), layout_k);
+          auto tensorCosK = make_tensor(make_gmem_ptr(base_ptr_k_cos + new_offset_k), layout_k);
+          auto tensorSinK = make_tensor(make_gmem_ptr(base_ptr_k_sin + new_offset_k), layout_k);
+          // fix next
+          // cutlass::flash_attention::collective::apply_rope_interleaved_gmem(thread_idx, tensorK, tensorCosK, tensorSinK, tensorK);
+          new_offset_k += 2 * QK_BLK_N * QK_BLK_K * gK_dim3;
+        }
+      }
+
+      if (cute::thread(THREAD_ID, BLOCK_ID)) {
+        print("after rope\n");
+      }
+    }
+
     // we limit the horizontal size to two subgroup, the empirical results
     // show that reading the two cacheline side by side in gives better
     // performance and anything after that does not have an effect on
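The `grid_dimx == 4` and `grid_dimx == 2` branches in the hunk above both implement the same ownership rule: block `b` starts at K tile `block_id % grid_dimx` and then strides through the tiles by `grid_dimx`. The standalone sketch below (illustrative names only, not part of the patch) prints that assignment for both grid widths; in the kernel, the same generalization would amount to a single `offset_k += (block_id % grid_dimx) * QK_BLK_N * QK_BLK_K * gK_dim3;` followed by one loop with stride `grid_dimx`, removing the per-grid-size branches.

```cpp
// Sketch only: shows the tile-ownership pattern that the grid_dimx == 4 and
// grid_dimx == 2 branches share, expressed once and parameterized by grid_dimx.
#include <cstdio>

void print_k_tile_assignment(int grid_dimx, int num_k_tiles) {
  for (int block_id = 0; block_id < grid_dimx; ++block_id) {
    std::printf("block %d:", block_id);
    // equivalent of: offset_k += (block_id % grid_dimx) * tile_step;
    //                then advancing by grid_dimx * tile_step per iteration
    for (int tile = block_id % grid_dimx; tile < num_k_tiles; tile += grid_dimx)
      std::printf(" %d", tile);
    std::printf("\n");
  }
}

int main() {
  print_k_tile_assignment(4, 16);  // mirrors the grid_dimx == 4 branch
  print_k_tile_assignment(2, 16);  // mirrors the grid_dimx == 2 branch
}
```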
diff --git a/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_mma.hpp b/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_mma.hpp
index 4c21c3b..234e3cb 100644
--- a/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_mma.hpp
+++ b/src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_mma.hpp
@@ -36,6 +36,7 @@
 #include "cutlass/cutlass.h"
 #include "cutlass/gemm/dispatch_policy.hpp"
 #include "fmha_fusion.hpp"
+#include "xe_rope.h"
 
 ////////////////////////////////////////////////////////////
 
 namespace {}
 
@@ -66,7 +67,8 @@ template <
     class GmemTiledCopyV_,
     bool CausalMask_,
     bool LocalMask_,
-    bool PagedKV_>
+    bool PagedKV_,
+    bool RopeEmbedding_ = false>
 struct FlashChunkPrefillMma {
   static_assert(cutlass::detail::dependent_false, "Could not find a mainloop specialization.");
 };
 
@@ -91,7 +93,8 @@ template <
     class GmemTiledCopyV_,
     bool CausalMask_,
     bool LocalMask_,
-    bool PagedKV_>
+    bool PagedKV_,
+    bool RopeEmbedding_>
 struct FlashChunkPrefillMma<
     gemm::MainloopIntelXeXMX16,
     ProblemShapeType_,
@@ -110,7 +113,8 @@ struct FlashChunkPrefillMma<
     GmemTiledCopyV_,
     CausalMask_,
     LocalMask_,
-    PagedKV_> {
+    PagedKV_,
+    RopeEmbedding_> {
   //
   // Type Aliases
   //
@@ -138,6 +142,7 @@ struct FlashChunkPrefillMma<
   static constexpr bool CausalMask = CausalMask_;
   static constexpr bool LocalMask = LocalMask_;
   static constexpr bool PagedKV = PagedKV_;
+  static constexpr bool rope_enabled = RopeEmbedding_;
 
   static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize;
 
@@ -207,6 +212,9 @@ struct FlashChunkPrefillMma<
     int max_num_pages_per_seq;
     int window_left;
     int window_right;
+    // for RoPE case
+    ElementQ const *ptr_cos = nullptr;
+    ElementQ const *ptr_sin = nullptr;
   };
 
   struct Params {
@@ -218,6 +226,11 @@ struct FlashChunkPrefillMma<
     int max_num_pages_per_seq;
     int window_left;
     int window_right;
+    // RoPE
+    XE_Copy_Q gmem_tiled_copy_q_cos;
+    XE_Copy_Q gmem_tiled_copy_q_sin;
+    XE_Copy_K gmem_tiled_copy_k_cos;
+    XE_Copy_K gmem_tiled_copy_k_sin;
   };
 
   //
@@ -241,10 +254,19 @@ struct FlashChunkPrefillMma<
     auto tensorV_cache = make_tensor(
         make_gmem_ptr(args.ptr_V_cache),
         make_layout(make_shape(num_heads_kv * head_size_vo, seq_len_kv_cache, batch), args.dV_cache));
+
+    auto tensorQCos = make_tensor(make_gmem_ptr(args.ptr_cos), make_layout(make_shape(seq_len_qo, num_heads_q * head_size_qk, batch), args.dQ));
+    auto tensorQSin = make_tensor(make_gmem_ptr(args.ptr_sin), make_layout(make_shape(seq_len_qo, num_heads_q * head_size_qk, batch), args.dQ));
+    auto tensorKCos = make_tensor(make_gmem_ptr(args.ptr_cos), make_layout(make_shape(seq_len_kv, num_heads_kv * head_size_qk, batch), args.dK));
+    auto tensorKSin = make_tensor(make_gmem_ptr(args.ptr_sin), make_layout(make_shape(seq_len_kv, num_heads_kv * head_size_qk, batch), args.dK));
 
     XE_Copy_Q copyQ{XE_Copy_Q{}.with(tensorQ)};
     XE_Copy_K copyK_cache{XE_Copy_K{}.with(tensorK_cache)};
     XE_Copy_V copyV_cache{XE_Copy_V{}.with(tensorV_cache)};
+    XE_Copy_Q copyQCos{XE_Copy_Q{}.with(tensorQCos)};
+    XE_Copy_Q copyQSin{XE_Copy_Q{}.with(tensorQSin)};
+    XE_Copy_K copyKCos{XE_Copy_K{}.with(tensorKCos)};
+    XE_Copy_K copyKSin{XE_Copy_K{}.with(tensorKSin)};
 
     return Params{
         copyQ,
@@ -254,7 +276,11 @@ struct FlashChunkPrefillMma<
         args.page_size,
         args.max_num_pages_per_seq,
         args.window_left,
-        args.window_right};
+        args.window_right,
+        copyQCos,
+        copyQSin,
+        copyKCos,
+        copyKSin};
   }
 
   template
@@ -406,7 +432,7 @@ struct FlashChunkPrefillMma<
       SequenceLengthShape const& sequence_length_shape,
       int const& l_coord,
       int const& q_head_coord = 0) {
-    auto [batch, num_heads_q, num_heads_kv, head_size_qk, head_size_vo] = select<0, 1, 2, 6, 7>(problem_shape);
+    auto [batch, num_heads_q, num_heads_kv, seq_len_kv, head_size_qk, head_size_vo] = select<0, 1, 2, 4, 6, 7>(problem_shape);
     auto [seq_len_qo, seq_len_kv_cache] = sequence_length_shape;
     auto q_group_size = num_heads_q / num_heads_kv;
     auto kv_head_coord = q_head_coord / q_group_size;
@@ -435,6 +461,11 @@ struct FlashChunkPrefillMma<
     auto shape_q = make_shape(static_cast<int>(seq_len_qo), head_size_qk * num_heads_q, 1);
     StrideQ stride_q = cutlass::make_cute_packed_stride(StrideQ{}, shape_q);
 
+    // added for K (may need to change in the future)
+    auto shape_k = make_shape(static_cast<int>(seq_len_kv),
+                              num_heads_kv * head_size_qk, 1);
+    StrideK stride_k = cutlass::make_cute_packed_stride(StrideK{}, shape_k);
+
     auto shape_k_cache = make_shape(
         static_cast<int>(PagedKV ? total_seq_len_kv_cache : seq_len_kv_cache),
         head_size_qk * num_heads_kv, 1);
     StrideK stride_k_cache = cutlass::make_cute_packed_stride(StrideK{}, shape_k_cache);
 
@@ -446,9 +477,34 @@ struct FlashChunkPrefillMma<
         make_tensor(make_gmem_ptr(k_cache_ptr + offset_k_cache), make_layout(shape_k_cache, stride_k_cache));
     auto tensorV_cache = make_tensor(make_gmem_ptr(v_cache_ptr + offset_v_cache), make_layout(shape_v_cache, stride_v_cache));
+
+    // for RoPE
+    auto q_traits_cos = static_cast<traits_load_Q>(params.gmem_tiled_copy_q_cos);
+    ElementQ* base_ptr_q_cos = (ElementQ*)q_traits_cos.base_ptr;
+
+    auto q_traits_sin = static_cast<traits_load_Q>(params.gmem_tiled_copy_q_sin);
+    ElementQ* base_ptr_q_sin = (ElementQ*)q_traits_sin.base_ptr;
+
+    auto k_traits_cos = static_cast<traits_load_K>(params.gmem_tiled_copy_k_cos);
+    ElementK* base_ptr_k_cos = (ElementK*)k_traits_cos.base_ptr;
+
+    auto k_traits_sin = static_cast<traits_load_K>(params.gmem_tiled_copy_k_sin);
+    ElementK* base_ptr_k_sin = (ElementK*)k_traits_sin.base_ptr;
+
+    auto tensorQCos = make_tensor(make_gmem_ptr(base_ptr_q_cos + offset_q), make_layout(shape_q, stride_q));
+    auto tensorQSin = make_tensor(make_gmem_ptr(base_ptr_q_sin + offset_q), make_layout(shape_q, stride_q));
+    auto tensorKCos = make_tensor(make_gmem_ptr(base_ptr_k_cos + offset_k), make_layout(shape_k, stride_k));
+    auto tensorKSin = make_tensor(make_gmem_ptr(base_ptr_k_sin + offset_k), make_layout(shape_k, stride_k));
+
+
     XE_Copy_Q copyQ{XE_Copy_Q{}.with(tensorQ)};
     XE_Copy_K copyK_cache{XE_Copy_K{}.with(tensorK_cache)};
     XE_Copy_V copyV_cache{XE_Copy_V{}.with(tensorV_cache)};
+    XE_Copy_Q copyQCos{XE_Copy_Q{}.with(tensorQCos)};
+    XE_Copy_Q copyQSin{XE_Copy_Q{}.with(tensorQSin)};
+    XE_Copy_K copyKCos{XE_Copy_K{}.with(tensorKCos)};
+    XE_Copy_K copyKSin{XE_Copy_K{}.with(tensorKSin)};
+
     return Params{
         copyQ,
         copyK_cache,
@@ -457,7 +513,12 @@ struct FlashChunkPrefillMma<
         params.page_size,
         params.max_num_pages_per_seq,
         params.window_left,
-        params.window_right};
+        params.window_right,
+        copyQCos,
+        copyQSin,
+        copyKCos,
+        copyKSin
+    };
   }
 };
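The mainloop's new `Arguments` fields `ptr_cos` and `ptr_sin` are wrapped in copies that reuse Q's and K's own shapes and strides, so the tables are expected to be laid out like Q and K themselves (one row per token, `num_heads * head_size` columns), and the device helper introduced in the new header below reads only the even columns, pairing column `j` with `j + 1`. A host-side sketch of building such a table follows; interleaved RoPE with `rotary_dim == head_size` and base 10000 are assumptions for illustration, not something the patch fixes.

```cpp
// Host-side sketch (assumptions: interleaved RoPE, rotary_dim == head_size,
// base 10000). Builds a [seq_len, num_heads * head_size] table in which column
// h*head_size + j of row `pos` holds cos/sin(pos * inv_freq[j/2]), i.e. the
// per-head block is replicated, matching how the patch indexes ptr_cos/ptr_sin
// with Q's and K's strides.
#include <cmath>
#include <vector>

std::vector<float> build_rope_table(int seq_len, int num_heads, int head_size,
                                    float base, bool want_sin) {
  std::vector<float> table(static_cast<size_t>(seq_len) * num_heads * head_size);
  for (int pos = 0; pos < seq_len; ++pos) {
    for (int h = 0; h < num_heads; ++h) {
      for (int j = 0; j < head_size; j += 2) {
        float inv_freq = std::pow(base, -static_cast<float>(j) / head_size);
        float angle = pos * inv_freq;
        float value = want_sin ? std::sin(angle) : std::cos(angle);
        size_t row = static_cast<size_t>(pos) * num_heads * head_size;
        table[row + h * head_size + j] = value;
        table[row + h * head_size + j + 1] = value;  // odd column is not read by the kernel
      }
    }
  }
  return table;
}
```

Note that the patch builds both the Q-side and the K-side copies from the same `ptr_cos` / `ptr_sin` pointers, with Q-shaped and K-shaped layouts respectively.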
diff --git a/src/sycl/kernels/chunk_prefill/xe_rope.h b/src/sycl/kernels/chunk_prefill/xe_rope.h
new file mode 100644
index 0000000..4f65e14
--- /dev/null
+++ b/src/sycl/kernels/chunk_prefill/xe_rope.h
@@ -0,0 +1,61 @@
+/***************************************************************************************************
+ * Copyright (c) 2025 Intel Corporation. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ *    list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ *    this list of conditions and the following disclaimer in the documentation
+ *    and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ *    contributors may be used to endorse or promote products derived from
+ *    this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+#pragma once
+
+#include "cutlass/cutlass.h"
+
+namespace cutlass::flash_attention::collective {
+using namespace cute;
+
+template <typename Tensor, typename TensorCos, typename TensorSin, typename TensorOut>
+CUTLASS_DEVICE void apply_rope_interleaved_gmem(
+    int thread_idx,
+    Tensor const &srcTensor,
+    TensorCos const &gCos,
+    TensorSin const &gSin,
+    TensorOut &destTensor) {
+  if (thread_idx < size<0>(srcTensor)) {
+    for (int j = 0; j < size<1>(gCos); j += 2) {
+      auto real = static_cast<float>(srcTensor[make_coord(thread_idx, j)]);
+      auto imag = static_cast<float>(srcTensor[make_coord(thread_idx, j + 1)]);
+      auto cos_val = static_cast<float>(gCos[make_coord(thread_idx, j)]);
+      auto sin_val = static_cast<float>(gSin[make_coord(thread_idx, j)]);
+
+      auto new_real = real * cos_val - imag * sin_val;
+      auto new_imag = real * sin_val + imag * cos_val;
+
+      destTensor[make_coord(thread_idx, j)] = static_cast<typename TensorOut::value_type>(new_real);
+      destTensor[make_coord(thread_idx, j + 1)] = static_cast<typename TensorOut::value_type>(new_imag);
+    }
+  }
+  syncthreads();
+}
+} // namespace cutlass::flash_attention::collective
\ No newline at end of file
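For unit-testing the new path, the per-row rotation performed by `apply_rope_interleaved_gmem` can be reproduced on the host with a few lines of scalar code. The sketch below is an illustrative CPU oracle (names and the `float` accumulation type are assumptions, not from the patch); it mirrors the device helper's arithmetic. The in-place use in the kernel (passing `tensorQ` as both source and destination) is safe because each `(j, j + 1)` pair is read into locals before being written back.

```cpp
// CPU oracle sketch for the interleaved rotation in xe_rope.h (illustrative,
// not part of the patch): for each pair (x[2i], x[2i+1]) of a row,
//   x'[2i]   = x[2i] * cos - x[2i+1] * sin
//   x'[2i+1] = x[2i] * sin + x[2i+1] * cos
// with cos/sin taken from column 2i of the row's cos/sin tables.
#include <vector>

void rope_interleaved_reference(const std::vector<float>& row_in,
                                const std::vector<float>& cos_row,
                                const std::vector<float>& sin_row,
                                std::vector<float>& row_out) {
  row_out.resize(row_in.size());
  for (size_t j = 0; j + 1 < row_in.size(); j += 2) {
    float real = row_in[j];
    float imag = row_in[j + 1];
    row_out[j]     = real * cos_row[j] - imag * sin_row[j];
    row_out[j + 1] = real * sin_row[j] + imag * cos_row[j];
  }
}
```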