6 changes: 4 additions & 2 deletions src/sycl/chunked_prefill.cpp
@@ -359,6 +359,7 @@ template <
typename TileShapeOutput,
typename SubgroupLayout,
int PipelineStages,
bool RopeEmbedding = false,
bool LocalMask = false,
typename ElementInputQ = bfloat16_t,
typename ElementInputKV = bfloat16_t,
@@ -416,7 +417,8 @@ struct FMHAConfig {
GmemTiledCopyV, // V,
Causal,
LocalMask,
PagedKV>;
PagedKV,
RopeEmbedding>;

using FMHAChunkPrefillKernel = cutlass::flash_attention::kernel::FMHAPrefillChunk<
ProblemShapeType,
@@ -796,7 +798,7 @@ std::vector<at::Tensor> mha_fwd(
cute::Shape<_128, _32, _64>,
cute::Shape<_128, _64, _64>,
cute::Layout<cute::Shape<_8, _1, _1>, cute::Stride<_1, _1, _1>>,
PipelineStages>::run(params);
PipelineStages, true>::run(params);
break;
case 96:
FMHAConfig<
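Note: `RopeEmbedding` is appended after `PipelineStages` and defaults to `false`, so existing instantiations compile unchanged; only the head-size-64 case above opts in with `true`. A minimal sketch of the dispatch pattern (the tile-shape aliases here are placeholders, not names from this repo):

// Hypothetical call sites; only the trailing bool differs.
FMHAConfig<ShapeQK, ShapePV, ShapeOut, SgLayout, PipelineStages>::run(params);        // RoPE off (default)
FMHAConfig<ShapeQK, ShapePV, ShapeOut, SgLayout, PipelineStages, true>::run(params);  // RoPE on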
148 changes: 148 additions & 0 deletions src/sycl/kernels/chunk_prefill/xe_chunk_prefill.hpp
@@ -35,6 +35,8 @@
#include "cutlass/gemm/gemm.h"
#include "cutlass/kernel_hardware_info.hpp"
#include "xe_flash_attn_chunk_prefill_mma.hpp"
#define THREAD_ID 0
#define BLOCK_ID 0

namespace cutlass::flash_attention::kernel {

@@ -82,6 +84,8 @@ class FMHAPrefillChunk {
using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator;
using MainloopArguments = typename CollectiveMainloop::Arguments;
using MainloopParams = typename CollectiveMainloop::Params;
using traits_load_Q = typename CollectiveMainloop::traits_load_Q;
using traits_load_K = typename CollectiveMainloop::traits_load_K;

using CollectiveSoftmaxEpilogue = CollectiveSoftmaxEpilogue_;
using SoftmaxArguments = typename CollectiveSoftmaxEpilogue::Arguments;
@@ -150,6 +154,11 @@ class FMHAPrefillChunk {
decltype(make_shape(Int<Vec>{}, Int<FragsM>{}, get<1>(TileShapePV{}) / get<1>(MmaAtomShape()), Int<VSlicer>{}));

static constexpr bool is_var_len = CollectiveMainloop::is_var_len;
static constexpr bool rope_enabled = CollectiveMainloop::rope_enabled;

template <typename T>
static constexpr bool is_fp8_v = cute::is_same_v<T,float_e4m3_t> || cute::is_same_v<T,float_e5m2_t>;
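// For example, is_fp8_v<float_e4m3_t> and is_fp8_v<float_e5m2_t> are true, while
// is_fp8_v<bfloat16_t> is false; this trait guards the RoPE path below, which
// currently supports only non-fp8 element types.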

// Kernel level shared memory storage
struct SharedStorage {
using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage;
@@ -241,6 +250,7 @@
auto batch = get<0>(params.problem_shape);
auto num_heads_q = get<1>(params.problem_shape);
auto num_heads_kv = get<2>(params.problem_shape);
auto seq_len_kv = get<4>(params.problem_shape);

auto& head_size_qk = get<6>(params.problem_shape);
auto& head_size_vo = get<7>(params.problem_shape);
@@ -328,24 +338,162 @@
int tiles_per_page = params.mainloop.page_size / QK_BLK_N;

Tensor mQ_mkl = cute::get_xe_tensor(make_shape(seq_len_qo, head_size_qk, 1)); //(m,k,l)
Tensor mK_nkl = cute::get_xe_tensor(make_shape(seq_len_kv, head_size_qk, 1)); //(n,k,l)

Tensor mK_cache_nkl = cute::get_xe_tensor(make_shape(seq_len_kv_cache, head_size_qk, 1)); // (n_cache,k,l)
Tensor mV_cache_nkl = cute::get_xe_tensor(make_shape(head_size_vo, seq_len_kv_cache, 1)); // (n_cache,k,l)

Tensor mCosQ_mkl = cute::get_xe_tensor(
make_shape(seq_len_qo, head_size_qk, 1)); // (m, k, l)
Tensor mSinQ_mkl = cute::get_xe_tensor(
make_shape(seq_len_qo, head_size_qk, 1)); // (m, k, l)
Tensor mCosK_nkl = cute::get_xe_tensor(
make_shape(seq_len_kv, head_size_qk, 1)); // (n, k, l)
Tensor mSinK_nkl = cute::get_xe_tensor(
make_shape(seq_len_kv, head_size_qk, 1)); // (n, k, l)

// block_size and head_size are the same size, so no coord is needed.
Tensor mQ_mk = mQ_mkl(_, _, 0);
Tensor mK_nk = mK_nkl(_, _, 0); // (n,k)

Tensor mK_cache_nk = mK_cache_nkl(_, _, 0); // (n_cache, k)
Tensor mV_cache_nk = mV_cache_nkl(_, _, 0); // (n_cache, k)

Tensor mCosQ_mk = mCosQ_mkl(_, _, 0); // (m,k)
Tensor mSinQ_mk = mSinQ_mkl(_, _, 0); // (m,k)
Tensor mCosK_nk = mCosK_nkl(_, _, 0); // (n,k)
Tensor mSinK_nk = mSinK_nkl(_, _, 0); // (n,k)

auto gQ = local_tile(mQ_mk, TileShapeQK{}, make_coord(blk_m_coord, _, _), Step<_1, X, _1>{});
auto gK = local_tile(mK_nk, TileShapeQK{}, make_coord(_, _, _), Step<X, _1, _1>{});

auto gK_cache = local_tile(mK_cache_nk, TileShapeQK{}, make_coord(_, _, _), Step<X, _1, _1>{});
auto gV_cache = local_tile(mV_cache_nk, TileShapeOutput{}, make_coord(_, blk_n_coord, _), Step<X, _1, _1>{});

auto gCosQ = local_tile(mCosQ_mk, TileShapeQK{},
make_coord(blk_m_coord, _, _), Step<_1, X, _1>{});
auto gSinQ = local_tile(mSinQ_mk, TileShapeQK{},
make_coord(blk_m_coord, _, _), Step<_1, X, _1>{});
auto gCosK = local_tile(mCosK_nk, TileShapeQK{},
make_coord(_, _, _), Step<X, _1, _1>{});
auto gSinK = local_tile(mSinK_nk, TileShapeQK{},
make_coord(_, _, _), Step<X, _1, _1>{});

auto mainloop_params = CollectiveMainloop::get_updated_copies(
params.mainloop, params.problem_shape, sequence_length_shape, batch_coord, q_head_coord);

// currently RoPE is not supported for fp8.
if constexpr (rope_enabled && !is_fp8_v<ElementQ>) {
if (cute::thread(THREAD_ID, BLOCK_ID)) {
print("inside rope in kernel\n");
}
int block_idx = static_cast<int>(BlockIdxX());
int block_idy = static_cast<int>(BlockIdxY());
int block_idz = static_cast<int>(BlockIdxZ());
int block_dimx = static_cast<int>(BlockDimX());
int block_dimy = static_cast<int>(BlockDimY());
int block_dimz = static_cast<int>(BlockDimZ());
int thread_idx = static_cast<int>(ThreadIdxX());
int thread_idy = static_cast<int>(ThreadIdxY());
int thread_idz = static_cast<int>(ThreadIdxZ());
int grid_dimx = static_cast<int>(GridDimX());
int grid_dimy = static_cast<int>(GridDimY());
int grid_dimz = static_cast<int>(GridDimZ());
int block_id = block_idx + block_idy * grid_dimx + block_idz * grid_dimx * grid_dimy;
int thread_id = block_id * block_dimx * block_dimy * block_dimz + thread_idz * block_dimx * block_dimy + thread_idy * block_dimx + thread_idx;
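
// Standard linearization: block_id = bx + by*Gx + bz*Gx*Gy and
// thread_id = block_id*Bx*By*Bz + tz*Bx*By + ty*Bx + tx. For example, a
// (4,2,1) grid of (16,1,1) blocks gives, for block (3,1,0) and thread (5,0,0),
// block_id = 3 + 1*4 = 7 and thread_id = 7*16 + 5 = 117.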


// Calculate the base pointer, offset, and layout for Q and K,
// then apply RoPE to Q and K accordingly.
auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache, head_size_qk, head_size_vo] = params.problem_shape;

int offset_q = num_heads_q * head_size_qk * seq_len_qo * batch_coord + // Jump to the correct batch
q_head_coord * head_size_qk + // Jump to the correct head
(blk_m_coord*QK_BLK_M*head_size_qk); // Jump to the correct seq_len_qo block

auto q_group_size = num_heads_q / num_heads_kv;
auto kv_head_coord = q_head_coord / q_group_size;
int offset_k = num_heads_kv * head_size_qk * seq_len_kv * batch_coord +
kv_head_coord * head_size_qk;
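// GQA mapping: q_group_size query heads share one KV head, e.g. 32 query heads
// over 8 KV heads gives q_group_size = 4, so query head 13 reads KV head 13/4 = 3.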

// calculate Q/cosQ/sinQ ptr
auto q_traits = static_cast<traits_load_Q const&>(mainloop_params.gmem_tiled_copy_q);
ElementQ* base_ptr_q = (ElementQ*)q_traits.base_ptr;

auto q_traits_cos = static_cast<traits_load_Q const&>(mainloop_params.gmem_tiled_copy_q_cos);
ElementQ* base_ptr_q_cos = (ElementQ*)q_traits_cos.base_ptr;

auto q_traits_sin = static_cast<traits_load_Q const&>(mainloop_params.gmem_tiled_copy_q_sin);
ElementQ* base_ptr_q_sin = (ElementQ*)q_traits_sin.base_ptr;

auto static_shape_q = make_shape(size<0>(gQ), size<1>(gQ)*size<2>(gQ));
int s = head_size_qk * num_heads_q;
auto stride_q = make_stride(s, Int<1>{});
auto layout_q = make_layout(static_shape_q, stride_q);
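// The row stride is head_size_qk * num_heads_q because all heads of a token are
// packed contiguously; stepping one row moves to the same head of the next token.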

// calculate K/cosK/sinK ptr
auto k_traits = static_cast<traits_load_K const&>(mainloop_params.gmem_tiled_copy_k);
ElementK* base_ptr_k = (ElementK*)k_traits.base_ptr;

auto k_traits_cos = static_cast<traits_load_K const&>(mainloop_params.gmem_tiled_copy_k_cos);
ElementK* base_ptr_k_cos = (ElementK*)k_traits_cos.base_ptr;

auto k_traits_sin = static_cast<traits_load_K const&>(mainloop_params.gmem_tiled_copy_k_sin);
ElementK* base_ptr_k_sin = (ElementK*)k_traits_sin.base_ptr;

auto static_shape_k = make_shape(size<0>(gK), size<1>(gK)*size<3>(gK));
auto layout_k = make_layout(static_shape_k, LayoutRight{});
auto gK_dim3 = size<3>(gK);

// Apply RoPE to Q.
auto tensorQ = make_tensor(make_gmem_ptr(base_ptr_q+offset_q), layout_q);
auto tensorCosQ = make_tensor(make_gmem_ptr(base_ptr_q_cos+offset_q), layout_q);
auto tensorSinQ = make_tensor(make_gmem_ptr(base_ptr_q_sin+offset_q), layout_q);
cutlass::flash_attention::collective::apply_rope_interleaved_gmem(thread_idx, tensorQ, tensorCosQ, tensorSinQ, tensorQ);
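// tensorQ is passed as both source and destination, so Q is rotated in place.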

// Apply RoPE to K.
// When the grid has multiple blocks along x, each block handles a different
// subset of K tiles, so the base pointer of K is adjusted accordingly.
if (grid_dimx == 4) {
if (block_id % 4 == 1) {
offset_k += QK_BLK_N * QK_BLK_K * gK_dim3;
} else if (block_id % 4 == 2) {
offset_k += 2 * QK_BLK_N * QK_BLK_K * gK_dim3;
} else if (block_id % 4 == 3) {
offset_k += 3 * QK_BLK_N * QK_BLK_K * gK_dim3;
}

auto new_offset_k = offset_k;
for (int i = 0; i < size<2>(gK); i += 4) {
auto tensorK = make_tensor(make_gmem_ptr(base_ptr_k + new_offset_k), layout_k);
auto tensorCosK = make_tensor(make_gmem_ptr(base_ptr_k_cos + new_offset_k), layout_k);
auto tensorSinK = make_tensor(make_gmem_ptr(base_ptr_k_sin + new_offset_k), layout_k);
// fix next
// cutlass::flash_attention::collective::apply_rope_interleaved_gmem(thread_idx, tensorK, tensorCosK, tensorSinK, tensorK);
new_offset_k += 4 * QK_BLK_N * QK_BLK_K * gK_dim3;
}
} else if (grid_dimx == 2) {
if (block_id % 2 == 1) {
offset_k += QK_BLK_N * QK_BLK_K * gK_dim3;
}
auto new_offset_k = offset_k;
for (int i = 0; i < size<2>(gK); i += 2) {
auto tensorK = make_tensor(make_gmem_ptr(base_ptr_k + new_offset_k), layout_k);
auto tensorCosK = make_tensor(make_gmem_ptr(base_ptr_k_cos + new_offset_k), layout_k);
auto tensorSinK = make_tensor(make_gmem_ptr(base_ptr_k_sin + new_offset_k), layout_k);
// fix next
// cutlass::flash_attention::collective::apply_rope_interleaved_gmem(thread_idx, tensorK, tensorCosK, tensorSinK, tensorK);
new_offset_k += 2 * QK_BLK_N * QK_BLK_K * gK_dim3;
}
}
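// Both branches implement the same round-robin split: a block starts at K tile
// block_id % grid_dimx and strides by grid_dimx, so each x-block covers a
// disjoint subset of the K tiles.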

if (cute::thread(THREAD_ID, BLOCK_ID)) {
print("after rope\n");
}
}

// we limit the horizontal size to two subgroups; the empirical results
// show that reading two cachelines side by side gives better
// performance and anything after that does not have an effect on
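For reference, `apply_rope_interleaved_gmem` (declared in `xe_rope.h`, which this diff does not show) is expected to perform the standard interleaved rotary embedding on consecutive element pairs. A minimal sketch of that rotation — the real helper's signature and thread partitioning may differ, and `kThreads` here is an assumed work-group size:

// Sketch only: rotates each row's pairs (2i, 2i+1) by the angle encoded in the
// interleaved cos/sin tables (for interleaved RoPE, cos(row, 2i) == cos(row, 2i+1)).
template <class TensorSrc, class TensorCS, class TensorDst>
void apply_rope_interleaved_sketch(int thread_idx, TensorSrc src, TensorCS cos_t,
                                   TensorCS sin_t, TensorDst dst) {
  constexpr int kThreads = 256;  // assumed number of cooperating threads
  for (int row = thread_idx; row < size<0>(src); row += kThreads) {
    for (int col = 0; col < size<1>(src); col += 2) {
      auto x0 = src(row, col);
      auto x1 = src(row, col + 1);
      dst(row, col)     = x0 * cos_t(row, col)     - x1 * sin_t(row, col);
      dst(row, col + 1) = x0 * sin_t(row, col + 1) + x1 * cos_t(row, col + 1);
    }
  }
}

Reading both elements of a pair before writing makes the in-place call (destination aliasing the source, as in the Q path above) safe.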
73 changes: 67 additions & 6 deletions src/sycl/kernels/chunk_prefill/xe_flash_attn_chunk_prefill_mma.hpp
@@ -36,6 +36,7 @@
#include "cutlass/cutlass.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "fmha_fusion.hpp"
#include "xe_rope.h"

////////////////////////////////////////////////////////////
namespace {}
@@ -66,7 +67,8 @@ template <
class GmemTiledCopyV_,
bool CausalMask_,
bool LocalMask_,
bool PagedKV_>
bool PagedKV_,
bool RopeEmbedding_ = false>
struct FlashChunkPrefillMma {
static_assert(cutlass::detail::dependent_false<ElementQ_>, "Could not find a mainloop specialization.");
};
@@ -91,7 +93,8 @@ template <
class GmemTiledCopyV_,
bool CausalMask_,
bool LocalMask_,
bool PagedKV_>
bool PagedKV_,
bool RopeEmbedding_>
struct FlashChunkPrefillMma<
gemm::MainloopIntelXeXMX16<Stages>,
ProblemShapeType_,
@@ -110,7 +113,8 @@ struct FlashChunkPrefillMma<
GmemTiledCopyV_,
CausalMask_,
LocalMask_,
PagedKV_> {
PagedKV_,
RopeEmbedding_> {
//
// Type Aliases
//
@@ -138,6 +142,7 @@ struct FlashChunkPrefillMma<
static constexpr bool CausalMask = CausalMask_;
static constexpr bool LocalMask = LocalMask_;
static constexpr bool PagedKV = PagedKV_;
static constexpr bool rope_enabled = RopeEmbedding_;

static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize;

@@ -207,6 +212,9 @@ struct FlashChunkPrefillMma<
int max_num_pages_per_seq;
int window_left;
int window_right;
// for RoPE case
ElementQ const *ptr_cos = nullptr;
ElementQ const *ptr_sin = nullptr;
};

struct Params {
Expand All @@ -218,6 +226,11 @@ struct FlashChunkPrefillMma<
int max_num_pages_per_seq;
int window_left;
int window_right;
// RoPE
XE_Copy_Q gmem_tiled_copy_q_cos;
XE_Copy_Q gmem_tiled_copy_q_sin;
XE_Copy_K gmem_tiled_copy_k_cos;
XE_Copy_K gmem_tiled_copy_k_sin;
};

//
Expand All @@ -241,10 +254,19 @@ struct FlashChunkPrefillMma<
auto tensorV_cache = make_tensor(
make_gmem_ptr(args.ptr_V_cache),
make_layout(make_shape(num_heads_kv * head_size_vo, seq_len_kv_cache, batch), args.dV_cache));

auto tensorQCos = make_tensor(make_gmem_ptr(args.ptr_cos), make_layout(make_shape(seq_len_qo, num_heads_q * head_size_qk, batch), args.dQ));
auto tensorQSin = make_tensor(make_gmem_ptr(args.ptr_sin), make_layout(make_shape(seq_len_qo, num_heads_q * head_size_qk, batch), args.dQ));
auto tensorKCos = make_tensor(make_gmem_ptr(args.ptr_cos), make_layout(make_shape(seq_len_kv, num_heads_kv * head_size_qk, batch), args.dK));
auto tensorKSin = make_tensor(make_gmem_ptr(args.ptr_sin), make_layout(make_shape(seq_len_kv, num_heads_kv * head_size_qk, batch), args.dK));
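// Note: the Q and K views share the same underlying cos/sin tables (args.ptr_cos
// and args.ptr_sin); only the sequence-length extent of each view differs.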

XE_Copy_Q copyQ{XE_Copy_Q{}.with(tensorQ)};
XE_Copy_K copyK_cache{XE_Copy_K{}.with(tensorK_cache)};
XE_Copy_V copyV_cache{XE_Copy_V{}.with(tensorV_cache)};
XE_Copy_Q copyQCos{XE_Copy_Q{}.with(tensorQCos)};
XE_Copy_Q copyQSin{XE_Copy_Q{}.with(tensorQSin)};
XE_Copy_K copyKCos{XE_Copy_K{}.with(tensorKCos)};
XE_Copy_K copyKSin{XE_Copy_K{}.with(tensorKSin)};

return Params{
copyQ,
Expand All @@ -254,7 +276,11 @@ struct FlashChunkPrefillMma<
args.page_size,
args.max_num_pages_per_seq,
args.window_left,
args.window_right};
args.window_right,
copyQCos,
copyQSin,
copyKCos,
copyKSin};
}

template <class FragQccum, class TensorQ, class TensorK, class FragSrc>
Expand Down Expand Up @@ -406,7 +432,7 @@ struct FlashChunkPrefillMma<
SequenceLengthShape const& sequence_length_shape,
int const& l_coord,
int const& q_head_coord = 0) {
auto [batch, num_heads_q, num_heads_kv, head_size_qk, head_size_vo] = select<0, 1, 2, 6, 7>(problem_shape);
auto [batch, num_heads_q, num_heads_kv, seq_len_kv, head_size_qk, head_size_vo] = select<0, 1, 2, 4, 6, 7>(problem_shape);
auto [seq_len_qo, seq_len_kv_cache] = sequence_length_shape;
auto q_group_size = num_heads_q / num_heads_kv;
auto kv_head_coord = q_head_coord / q_group_size;
@@ -435,6 +461,11 @@ struct FlashChunkPrefillMma<
auto shape_q = make_shape(static_cast<int>(seq_len_qo), head_size_qk * num_heads_q, 1);
StrideQ stride_q = cutlass::make_cute_packed_stride(StrideQ{}, shape_q);

// Added for K (may need to change in the future).
auto shape_k = make_shape(static_cast<int>(seq_len_kv),
num_heads_kv * head_size_qk, 1);
StrideK stride_k = cutlass::make_cute_packed_stride(StrideK{}, shape_k);

auto shape_k_cache = make_shape(
static_cast<int>(PagedKV ? total_seq_len_kv_cache : seq_len_kv_cache), head_size_qk * num_heads_kv, 1);
StrideK stride_k_cache = cutlass::make_cute_packed_stride(StrideK{}, shape_k_cache);
@@ -446,9 +477,34 @@
make_tensor(make_gmem_ptr(k_cache_ptr + offset_k_cache), make_layout(shape_k_cache, stride_k_cache));
auto tensorV_cache =
make_tensor(make_gmem_ptr(v_cache_ptr + offset_v_cache), make_layout(shape_v_cache, stride_v_cache));

// for RoPE
auto q_traits_cos = static_cast<traits_load_Q const&>(params.gmem_tiled_copy_q_cos);
ElementQ* base_ptr_q_cos = (ElementQ*)q_traits_cos.base_ptr;

auto q_traits_sin = static_cast<traits_load_Q const&>(params.gmem_tiled_copy_q_sin);
ElementQ* base_ptr_q_sin = (ElementQ*)q_traits_sin.base_ptr;

auto k_traits_cos = static_cast<traits_load_K const&>(params.gmem_tiled_copy_k_cos);
ElementK* base_ptr_k_cos = (ElementK*)k_traits_cos.base_ptr;

auto k_traits_sin = static_cast<traits_load_K const&>(params.gmem_tiled_copy_k_sin);
ElementK* base_ptr_k_sin = (ElementK*)k_traits_sin.base_ptr;

auto tensorQCos = make_tensor(make_gmem_ptr(base_ptr_q_cos + offset_q), make_layout(shape_q, stride_q));
auto tensorQSin = make_tensor(make_gmem_ptr(base_ptr_q_sin + offset_q), make_layout(shape_q, stride_q));
auto tensorKCos = make_tensor(make_gmem_ptr(base_ptr_k_cos + offset_k), make_layout(shape_k, stride_k));
auto tensorKSin = make_tensor(make_gmem_ptr(base_ptr_k_sin + offset_k), make_layout(shape_k, stride_k));


XE_Copy_Q copyQ{XE_Copy_Q{}.with(tensorQ)};
XE_Copy_K copyK_cache{XE_Copy_K{}.with(tensorK_cache)};
XE_Copy_V copyV_cache{XE_Copy_V{}.with(tensorV_cache)};
XE_Copy_Q copyQCos{XE_Copy_Q{}.with(tensorQCos)};
XE_Copy_Q copyQSin{XE_Copy_Q{}.with(tensorQSin)};
XE_Copy_K copyKCos{XE_Copy_K{}.with(tensorKCos)};
XE_Copy_K copyKSin{XE_Copy_K{}.with(tensorKSin)};

return Params{
copyQ,
copyK_cache,
@@ -457,7 +513,12 @@
params.page_size,
params.max_num_pages_per_seq,
params.window_left,
params.window_right};
params.window_right,
copyQCos,
copyQSin,
copyKCos,
copyKSin
};
}
};

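Taken together, a host-side caller would enable RoPE by populating the two new pointer fields and instantiating the kernel with `RopeEmbedding = true`. A hedged sketch — only `ptr_cos`/`ptr_sin` are introduced by this diff, and the table names here are placeholders:

// Hypothetical wiring; the cos/sin tables are precomputed per (position, dim) pair.
MainloopArguments mainloop_args = /* existing Q/K/V arguments */;
mainloop_args.ptr_cos = rope_cos_table;  // interleaved cos values, one row per position
mainloop_args.ptr_sin = rope_sin_table;  // interleaved sin values, one row per position
// Then dispatch with the new flag, as in chunked_prefill.cpp above:
// FMHAConfig<..., PipelineStages, /*RopeEmbedding=*/true>::run(params);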