diff --git a/src/sycl/chunked_prefill.cpp b/src/sycl/chunked_prefill.cpp deleted file mode 100644 index 14e201e..0000000 --- a/src/sycl/chunked_prefill.cpp +++ /dev/null @@ -1,858 +0,0 @@ -#include -#include -#include -#include - -#include - -#include "Utils.h" -#include "comm/common.h" -#include "cutlass/epilogue/collective/default_epilogue.hpp" -#include "cutlass/util/device_memory.h" -#include "cutlass/util/packed_stride.hpp" -#include "cutlass/util/sycl_event_manager.hpp" -#include "kernels/chunk_prefill/fmha_fusion.hpp" -#include "kernels/chunk_prefill/tile_scheduler_chunk_prefill.hpp" -#include "kernels/chunk_prefill/xe_chunk_prefill.hpp" -#include "kernels/chunk_prefill/xe_flash_attn_chunk_prefill_epilogue.hpp" -#include "kernels/chunk_prefill/xe_flash_attn_chunk_prefill_softmax_epilogue.hpp" - -using namespace cute; - -struct Flash_fwd_params { - using index_t = int64_t; - - // The QKV matrices. - void* __restrict__ q_ptr; - void* __restrict__ k_ptr; - void* __restrict__ v_ptr; - - // The stride between rows of the Q, K and V matrices. - index_t q_batch_stride; - index_t k_batch_stride; - index_t v_batch_stride; - index_t q_row_stride; - index_t k_row_stride; - index_t v_row_stride; - index_t q_head_stride; - index_t k_head_stride; - index_t v_head_stride; - index_t v_dim_stride; - - // The number of heads. - int h, h_k; - - // The O matrix (output). - void* __restrict__ o_ptr; - void* __restrict__ oaccum_ptr; - - // The stride between rows of O. - index_t o_batch_stride; - index_t o_row_stride; - index_t o_head_stride; - - // The pointer to the softmax sum. - void* __restrict__ softmax_lse_ptr; - void* __restrict__ softmax_lseaccum_ptr; - - // The dimensions. - int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim; - int total_q, total_k; - int total_knew = 0; - int b_k; // When having KV cache and with cache_batch_idx, K & V might have larger batch size than Q - int dv, dv_rounded; // For the case where V headdim is different from Q/K headdim - - // The scaling factors for the kernel. - float scale_softmax; - void* sink_softmax; - float softcap; - - // array of length b+1 holding starting offset of each sequence. - int* __restrict__ cu_seqlens_q; - int* __restrict__ cu_seqlens_k; - int* __restrict__ cu_seqlens_knew; - int* __restrict__ leftpad_k; - - // If provided, the actual length of each q/k sequence. - int* __restrict__ seqused_q; - int* __restrict__ seqused_k; - - // The stride between rows of Oaccum. - index_t oaccum_split_stride; - index_t oaccum_batch_stride; - index_t oaccum_row_stride; - index_t oaccum_head_stride; - - // The stride between rows of LSEaccum. - index_t lseaccum_split_stride; - index_t lseaccum_batch_stride; - index_t lseaccum_head_stride; - - // The K_new and V_new matrices. - void* __restrict__ knew_ptr; - void* __restrict__ vnew_ptr; - - // The stride between rows of the Q, K and V matrices. - index_t knew_batch_stride; - index_t vnew_batch_stride; - index_t knew_row_stride; - index_t vnew_row_stride; - index_t knew_head_stride; - index_t vnew_head_stride; - - void* __restrict__ qv_ptr; - index_t qv_batch_stride; - index_t qv_row_stride; - index_t qv_head_stride; - - // The cos and sin matrices for rotary embedding. - void* __restrict__ rotary_cos_ptr; - void* __restrict__ rotary_sin_ptr; - int* __restrict__ seqlens_rotary; - - // The indices to index into the KV cache. 
- int* __restrict__ kv_batch_idx; - - // Paged KV cache - int* __restrict__ page_table; - int max_num_pages_per_seq; - index_t page_table_batch_stride; - int page_size; - int num_pages; - bool pagedkv_tma; - - // The dropout probability (probability of keeping an activation). - float p_dropout; - // uint32_t p_dropout_in_uint; - // uint16_t p_dropout_in_uint16_t; - uint8_t p_dropout_in_uint8_t; - - // Scale factor of 1 / (1 - p_dropout). - float rp_dropout; - - // Local window size - int window_size_left, window_size_right; - - // Pointer to the RNG seed (idx 0) and offset (idx 1). - uint64_t* rng_state; - - bool is_bf16; - bool is_fp32; - bool is_e4m3; - bool is_causal; - bool is_local; - - bool is_rotary_interleaved; - - int num_splits; // For split-KV version - bool pack_gqa; - - int* __restrict__ tile_count_semaphore; - // int * __restrict__ num_m_blocks_ptr; - // int * __restrict__ num_n_blocks_ptr; - int* __restrict__ num_splits_dynamic_ptr; - bool skip_scheduler_metadata_computation; - - int arch; - int num_sm; -}; - -template -class KernelCur {}; - -// Flash Attention takes 3 input matrices: Keys, Queries and Values. -using LayoutQ = cutlass::layout::RowMajor; -using LayoutK = cutlass::layout::ColumnMajor; -using LayoutV = cutlass::layout::RowMajor; -using LayoutO = cutlass::layout::RowMajor; - -template -struct KernelRunner { - using StrideQ = typename FMHAChunkPrefillKernel::StrideQ; - using StrideK = typename FMHAChunkPrefillKernel::StrideK; - using StrideV = typename FMHAChunkPrefillKernel::StrideV; - using StrideO = typename FMHAChunkPrefillKernel::StrideO; - - using ElementQ = typename FMHAChunkPrefillKernel::ElementQ; - using ElementK = typename FMHAChunkPrefillKernel::ElementK; - using ElementV = typename FMHAChunkPrefillKernel::ElementV; - using ElementAcc = typename FMHAChunkPrefillKernel::ElementAccumulator; - using ElementSink = typename FMHAChunkPrefillKernel::ElementSink; - - using CollectiveEpilogue = typename FMHAChunkPrefillKernel::CollectiveEpilogue; - using ElementOutput = typename CollectiveEpilogue::ElementOutput; - using ElementCompute = typename CollectiveEpilogue::ElementCompute; - using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; - - using ProblemShapeType = typename FMHAChunkPrefillKernel::ProblemShape; - - // - // Data members - // - - /// Initialization - StrideQ stride_Q; - StrideK stride_K; - StrideV stride_V; - StrideK stride_K_cache; - StrideV stride_V_cache; - StrideO stride_O; - - template - auto initialize_varlen(const Flash_fwd_params& params, ProblemShape& problem_size) { - ProblemShape problem_size_for_init = problem_size; - get<0>(problem_size_for_init) = 1; // concentrated batch - get<3>(problem_size_for_init) = params.total_q; - get<4>(problem_size_for_init) = params.total_knew; - get<5>(problem_size_for_init) = params.total_k; - - ProblemShapeType problem_size_for_launch; - - get<0>(problem_size_for_launch) = get<0>(problem_size); - get<1>(problem_size_for_launch) = get<1>(problem_size); - get<2>(problem_size_for_launch) = get<2>(problem_size); - get<3>(problem_size_for_launch) = cutlass::fmha::collective::VariableLength{params.seqlen_q, params.total_q}; - get<4>(problem_size_for_launch) = cutlass::fmha::collective::VariableLength{params.seqlen_knew, params.total_knew}; - get<5>(problem_size_for_launch) = cutlass::fmha::collective::VariableLength{params.seqlen_k, params.total_k}; - get<6>(problem_size_for_launch) = get<6>(problem_size); - get<7>(problem_size_for_launch) = get<7>(problem_size); - - return 
cute::make_tuple(problem_size_for_init, problem_size_for_launch); - } - - /// Initialize operands to be used in the GEMM and reference GEMM - ProblemShapeType initialize(const Flash_fwd_params& params) { - auto problem_shape_in = cute::make_tuple( - params.b, // batch - params.h, // num_heads_q - params.h_k, // num_heads_kv - params.seqlen_q, - params.seqlen_knew, - params.seqlen_k, - params.d, - params.dv); - - ProblemShapeType problem_shape; - decltype(problem_shape_in) problem_size; - - if constexpr (isVarLen) { - auto [problem_shape_init, problem_shape_launch] = initialize_varlen(params, problem_shape_in); - problem_size = problem_shape_init; - problem_shape = problem_shape_launch; - } else { - problem_size = problem_shape_in; - problem_shape = problem_shape_in; - } - - auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache, head_size_qk, head_size_vo] = - problem_size; - auto group_q_size = num_heads_q / num_heads_kv; - auto group_q_num = num_heads_q / group_q_size; - - stride_Q = - cutlass::make_cute_packed_stride(StrideQ{}, cute::make_shape(seq_len_qo, num_heads_q * head_size_qk, batch)); - stride_K = - cutlass::make_cute_packed_stride(StrideK{}, cute::make_shape(seq_len_kv, num_heads_kv * head_size_qk, batch)); - stride_V = - cutlass::make_cute_packed_stride(StrideV{}, cute::make_shape(head_size_vo * num_heads_kv, seq_len_kv, batch)); - - stride_K_cache = cutlass::make_cute_packed_stride( - StrideK{}, cute::make_shape(seq_len_kv_cache, num_heads_kv * head_size_qk, batch)); - stride_V_cache = cutlass::make_cute_packed_stride( - StrideV{}, cute::make_shape(head_size_vo * head_size_qk, seq_len_kv_cache, batch * num_heads_kv)); - stride_O = cutlass::make_cute_packed_stride( - StrideQ{}, cute::make_shape(seq_len_qo * group_q_size, group_q_num * head_size_vo, batch)); - - if constexpr (isVarLen) { - get<3>(problem_shape).cumulative_length = params.cu_seqlens_q; - get<4>(problem_shape).cumulative_length = params.cu_seqlens_knew; - get<5>(problem_shape).cumulative_length = params.cu_seqlens_k; - } - - return problem_shape; - } - - // Note that the GemmUniversalAdapter currently doesn't support flash attention, which is why this - // secondary `run` function is required to launch the kernel. 
- static void run(typename FMHAChunkPrefillKernel::Params params) { - dim3 const block = FMHAChunkPrefillKernel::get_block_shape(); - dim3 const grid = FMHAChunkPrefillKernel::get_grid_shape(params); - - // configure smem size and carveout - int smem_size = FMHAChunkPrefillKernel::SharedStorageSize; - - const auto sycl_block = compat::dim3(block.x, block.y, block.z); - const auto sycl_grid = compat::dim3(grid.x, grid.y, grid.z); - - using namespace compat::experimental; - compat::experimental::launch_properties launch_props{ - sycl::ext::oneapi::experimental::work_group_scratch_size(smem_size), - }; - compat::experimental::kernel_properties kernel_props{ - sycl::ext::oneapi::experimental::sub_group_size}; - compat::experimental::launch_policy policy{sycl_grid, sycl_block, launch_props, kernel_props}; - - sycl::ext::oneapi::experimental::launch_config config(policy.get_range(), policy.get_launch_properties()); - auto cgf = [&](::sycl::handler& cgh) { - auto KernelFunctor = - compat::experimental::detail::build_kernel_functor>( - cgh, policy, params); - sycl::ext::oneapi::experimental::detail:: - LaunchConfigAccess, decltype(policy.get_launch_properties())> - ConfigAccess(config); - cgh.parallel_for>( - ConfigAccess.getRange(), ConfigAccess.getProperties(), KernelFunctor); - }; - auto stream = at::xpu::getCurrentXPUStream(); - auto q = stream.queue(); - q.submit(cgf); - } - - cutlass::Status run(const Flash_fwd_params& params, const cutlass::KernelHardwareInfo& hw_info) { - ProblemShapeType problem_size = initialize(params); - - typename FMHAChunkPrefillKernel::Arguments arguments{ - cutlass::gemm::GemmUniversalMode::kGemm, - problem_size, - {// static_cast(params.q_ptr), - static_cast(params.q_ptr), - stride_Q, - // static_cast(params.knew_ptr), - // stride_K, - // static_cast(params.vnew_ptr), - // stride_V, - static_cast(params.k_ptr), - stride_K_cache, - static_cast(params.v_ptr), - stride_V_cache, - params.page_table, - params.page_size, - params.max_num_pages_per_seq, - params.window_size_left, - params.window_size_right}, - {(ElementQ)params.scale_softmax}, - {static_cast(params.o_ptr), - stride_O, - static_cast(params.sink_softmax)}, - hw_info}; - - // Define device-global scratch memory - size_t workspace_size = FMHAChunkPrefillKernel::get_workspace_size(arguments); - cutlass::device_memory::allocation workspace(workspace_size); - - if (!FMHAChunkPrefillKernel::can_implement(arguments)) { - return cutlass::Status::kErrorInvalidProblem; - } - - // Initialize the workspace - (FMHAChunkPrefillKernel::initialize_workspace(arguments, workspace.get())); - - // Convert host-side arguments to device-side arguments to be passed to the kernel - auto params_kernel = FMHAChunkPrefillKernel::to_underlying_arguments(arguments, workspace.get()); - - // Run the Flash Attention implementation. 
- run(params_kernel); - return cutlass::Status::kSuccess; - } -}; - -// the default value used for the case BF16 -template < - typename TileShapeQK, - typename TileShapePV, - typename TileShapeOutput, - typename SubgroupLayout, - int PipelineStages, - bool Causal = false, - bool LocalMask = false, - bool Sink = false, - typename ElementInputQ = bfloat16_t, - typename ElementInputKV = bfloat16_t, - typename MMAOperation = XE_8x16x16_F32BF16BF16F32_TT, - typename GmemTiledCopyQ = XE_2D_U16x8x32_LD_N, - typename GmemTiledCopyK = XE_2D_U16x16x16_LD_T, // _T designates a transposed block load operation - typename GmemTiledCopyV = XE_2D_U16x16x32_LD_V, - typename ElementAccumulator = float, - typename ElementComputeEpilogue = float, - typename ElementOutput = bfloat16_t, - typename ElementSink = bfloat16_t, - typename GmemTiledCopyStore = XE_2D_U16x8x16_ST_N> -struct FMHAConfig { - template - static int run(const Flash_fwd_params& params) { - // The KernelHardwareInfo struct holds the number of EUs on the GPU with a given device ID. This - // information is used by the underlying kernel. - cutlass::KernelHardwareInfo hw_info; - - using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16; - using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; - using CollectiveEpilogue = cutlass::flash_attention::collective::FlashChunkPrefillEpilogue< - Sink, - EpilogueDispatchPolicy, - MMAOperation, - TileShapeOutput, - SubgroupLayout, - ElementComputeEpilogue, - ElementOutput, - cutlass::gemm::TagToStrideC_t, - ElementOutput, - GmemTiledCopyStore, - ElementSink>; - using CollectiveSoftmaxEpilogue = cutlass::flash_attention::collective:: - FlashChunkPrefillSoftmaxEpilogue; - - using ProblemShapeRegular = cute::tuple; - using namespace cutlass::fmha::collective; - using ProblemShapeVarlen = cute::tuple; - using ProblemShapeType = std::conditional_t; - - // Mainloop - using CollectiveMainloop = cutlass::flash_attention::collective::FlashChunkPrefillMma< - GEMMDispatchPolicy, - ProblemShapeType, - ElementInputQ, - cutlass::gemm::TagToStrideA_t, - ElementInputKV, - cutlass::gemm::TagToStrideB_t, - ElementInputKV, - cutlass::gemm::TagToStrideB_t, - MMAOperation, - TileShapeQK, - TileShapePV, - SubgroupLayout, - GmemTiledCopyQ, // Q - GmemTiledCopyK, // K - GmemTiledCopyV, // V, - Causal, - LocalMask, - PagedKV>; - - using FMHAChunkPrefillKernel = cutlass::flash_attention::kernel::FMHAPrefillChunk< - ProblemShapeType, - CollectiveMainloop, - CollectiveSoftmaxEpilogue, - CollectiveEpilogue, - Scheduler>; - - KernelRunner runner; - - (runner.run(params, hw_info)); - return 0; - } - - static int run(const Flash_fwd_params& params) { - // only support varlen and paged kv now - if (params.page_table != nullptr && params.cu_seqlens_k != nullptr) { - return run(params); - } else { - return 0; - } - } -}; - -inline int round_up_headdim(int head_size) { - if (head_size <= 64) { - return 64; - } - if (head_size <= 96) { - return 96; - } - if (head_size <= 128) { - return 128; - } - if (head_size <= 192) { - return 192; - } - if (head_size <= 256) { - return 256; - } - return 256; -} - -std::vector mha_fwd( - at::Tensor& q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q - const at::Tensor& k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, - // h_k, d) if there is page_table. - const at::Tensor& v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, - // page_size, h_k, dv) if there is page_table. 
- std::optional& q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q - const at::Tensor& cu_seqlens_q, // b+1 - const at::Tensor& cu_seqlens_k, // b+1 - int max_seqlen_q, - const at::Tensor& page_table, // (b_k, max_num_pages_per_seq) - std::optional& kv_batch_idx_, // b. indices to index into the KV cache - std::optional& leftpad_k_, // b - std::optional& rotary_cos_, // seqlen_ro x (rotary_dim / 2) - std::optional& rotary_sin_, // seqlen_ro x (rotary_dim / 2) - std::optional& seqlens_rotary_, // b - std::optional& q_descale_, // (b, h_k), not (b, h) - std::optional& k_descale_, // (b, h_k) - std::optional& v_descale_, // (b, h_k) - const float softmax_scale_, - std::optional& softmax_sink_, - bool is_causal, - int window_size_left, - int window_size_right, - float const softcap, - bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 - std::optional& scheduler_metadata_, // (b + 1) - int num_splits, - std::optional pack_gqa_, - int const sm_margin) { - // TODO: check GPU support - // auto dprops = at::cuda::getCurrentDeviceProperties(); - // TORCH_CHECK(drops->name.find("B580") != std::string::npos, "sgl_kernel_xpu only supports BMG+"); - - auto q_type = q.scalar_type(); - TORCH_CHECK( - q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, - "SGL Kernel XPU only supports fp16 and bf16 type"); - - TORCH_CHECK(k.scalar_type() == q_type, "query and key must have the same dtype"); - TORCH_CHECK(v.scalar_type() == q_type, "query and value must have the same dtype"); - - CHECK_DEVICE(q); - CHECK_DEVICE(k); - CHECK_DEVICE(v); - - TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); - - CHECK_DEVICE(page_table); - TORCH_CHECK(page_table.dtype() == torch::kInt32, "page_table must have dtype torch.int32"); - TORCH_CHECK(page_table.stride(-1) == 1, "page_table must have contiguous last dimension"); - - TORCH_CHECK(q.dim() == 3, "query must be in ragged format"); - CHECK_DEVICE(cu_seqlens_q); - CHECK_CONTIGUOUS(cu_seqlens_q); - TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype torch.int32"); - - CHECK_DEVICE(cu_seqlens_k); - CHECK_CONTIGUOUS(cu_seqlens_k); - TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype torch.int32"); - - auto const sizes = q.sizes(); - const int batch_size = cu_seqlens_q.size(0) - 1; - int seqlen_q = max_seqlen_q; - int total_q = q.size(0); - int num_heads = q.size(-2); - int const head_size = q.size(-1); - int const head_size_v = v.size(-1); - int const max_num_pages_per_seq = page_table.size(1); - int const num_pages = k.size(0); - int const page_size = k.size(1); - int const seqlen_k = max_num_pages_per_seq * page_size; - int const total_k = num_pages * page_size; - int const num_heads_k = k.size(-2); - int const batch_size_k = page_table.size(0); - float softmax_scale = softmax_scale_; - - if (!kv_batch_idx_.has_value()) { - TORCH_CHECK(batch_size == batch_size_k, "batch_size must be equal to batch_size_k"); - } - - // Currently only support head dims <= 256 - static constexpr int max_headdim = 256; - TORCH_CHECK( - head_size <= max_headdim, - "FlashAttention forward only supports head dimension at most " + std::to_string(max_headdim)); - TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide 
number of heads in query"); - - // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM - // TODO: check this - - if (window_size_left >= seqlen_k - 1) { - window_size_left = -1; - } - window_size_right = min(window_size_right, seqlen_q); - // causal=true is the same as causal=false in this case - if (is_causal) { - window_size_right = 0; - } - - CHECK_SHAPE(k, num_pages, page_size, num_heads_k, head_size); - CHECK_SHAPE(v, num_pages, page_size, num_heads_k, head_size_v); - CHECK_SHAPE(page_table, batch_size_k, max_num_pages_per_seq); - - if (leftpad_k_.has_value()) { - auto leftpad_k = leftpad_k_.value(); - TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32"); - CHECK_DEVICE(leftpad_k); - CHECK_CONTIGUOUS(leftpad_k); - CHECK_SHAPE(leftpad_k, batch_size); - } - - static constexpr int alignment = 8; - TORCH_CHECK(head_size % alignment == 0, "head_size should be a multiple of " + std::to_string(alignment)); - TORCH_CHECK(head_size_v % alignment == 0, "head_size_v should be a multiple of " + std::to_string(alignment)); - - auto opts = q.options(); - at::Tensor out; - out = torch::empty({total_q, num_heads, head_size_v}, opts); - - auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; - int const head_size_rounded = round_up_headdim(head_size); - int const head_size_v_rounded = head_size_v == head_size ? head_size_rounded : round_up_headdim(head_size_v); - int const seqlen_q_rounded = round_multiple(seqlen_q, 128); - int const seqlen_k_rounded = round_multiple(seqlen_k, 128); - - // Otherwise the kernel will be launched from cuda:0 device - // Cast to char to avoid compiler warning about narrowing - c10::DeviceGuard device_guard(q.device()); - - at::Tensor softmax_lse; - softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat)); - - // align with FA3 - Flash_fwd_params params; - params.is_bf16 = q.dtype() == torch::kBFloat16; - - // Set the pointers and strides. - params.q_ptr = q.data_ptr(); - params.k_ptr = k.data_ptr(); - params.v_ptr = v.data_ptr(); - // All stride are in elements, not bytes. - params.q_row_stride = q.stride(-3); - params.k_row_stride = k.stride(-3); - params.v_row_stride = v.stride(-3); - params.q_head_stride = q.stride(-2); - params.k_head_stride = k.stride(-2); - params.v_head_stride = v.stride(-2); - params.v_dim_stride = v.stride(-1); - params.o_ptr = out.data_ptr(); - params.o_row_stride = out.stride(-3); - params.o_head_stride = out.stride(-2); - - params.cu_seqlens_q = cu_seqlens_q.data_ptr(); - params.cu_seqlens_k = cu_seqlens_k.data_ptr(); - - // Softmax sum - params.softmax_lse_ptr = softmax_lse.data_ptr(); - - // Set the dimensions. - params.b = batch_size; - params.h = num_heads; - params.h_k = num_heads_k; - params.seqlen_q = seqlen_q; - params.seqlen_k = seqlen_k; - params.seqlen_q_rounded = seqlen_q_rounded; - params.seqlen_k_rounded = seqlen_k_rounded; - params.d = head_size; - params.d_rounded = head_size_rounded; - - // Set the different scale values. - params.scale_softmax = softmax_scale; - bool use_sink = softmax_sink_.has_value(); - params.sink_softmax = use_sink ? softmax_sink_.value().data_ptr() : nullptr; - - params.softcap = softcap; - - // Set this to probability of keeping an element to simplify things. - params.p_dropout = 1.f; - - // Causal is the special case where window_size_right == 0 and window_size_left < 0. - // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. 
- params.is_causal = window_size_left < 0 && window_size_right == 0; - params.is_local = (window_size_left >= 0 || window_size_right >= 0) && !params.is_causal; - - // TODO: check this - if (window_size_left < 0) { - window_size_left = seqlen_k - 1; - } - if (window_size_right < 0) { - window_size_right = seqlen_q - 1; - } - params.window_size_left = window_size_left; - params.window_size_right = window_size_right; - params.total_q = total_q; - params.total_k = total_k; - params.b_k = batch_size_k; - params.dv = head_size_v; - params.page_table = page_table.data_ptr(); - params.page_table_batch_stride = page_table.stride(0); - params.max_num_pages_per_seq = max_num_pages_per_seq; - params.page_size = page_size; - params.num_pages = num_pages; - - if (q_v_.has_value()) { - TORCH_CHECK(head_size <= 64, "q_v is only supported for head_size <= 64"); - TORCH_CHECK( - q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, - "q_v is only supported for fp16 and bf16 data type"); - TORCH_CHECK(params.arch == 90, "q_v is only supported for Hopper GPUs"); - at::Tensor q_v = q_v_.value(); - TORCH_CHECK(q_v.dtype() == q_type, "q_v must have the same dtype as query"); - CHECK_DEVICE(q_v); - TORCH_CHECK(q_v.stride(-1) == 1, "q_v tensor must have contiguous last dimension"); - CHECK_SHAPE(q_v, total_q, num_heads, head_size_v); - params.qv_ptr = q_v.data_ptr(); - // All stride are in elements, not bytes. - params.qv_row_stride = q_v.stride(-3); - params.qv_head_stride = q_v.stride(-2); - } - - if (rotary_cos_.has_value()) { - auto rotary_cos = rotary_cos_.value(); - CHECK_DEVICE(rotary_cos); - CHECK_CONTIGUOUS(rotary_cos); - params.rotary_dim = rotary_cos.size(1) * 2; - TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim"); - TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported"); - const int seqlen_ro = rotary_cos.size(0); - TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache"); - CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2); - TORCH_CHECK(rotary_cos.scalar_type() == q_type, "rotary_cos must have the same dtype as query"); - - TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided"); - auto rotary_sin = rotary_sin_.value(); - CHECK_DEVICE(rotary_sin); - CHECK_CONTIGUOUS(rotary_sin); - CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2); - TORCH_CHECK(rotary_sin.scalar_type() == q_type, "rotary_cos must have the same dtype as query"); - params.rotary_cos_ptr = rotary_cos.data_ptr(); - params.rotary_sin_ptr = rotary_sin.data_ptr(); - params.is_rotary_interleaved = is_rotary_interleaved; - if (seqlens_rotary_.has_value()) { - at::Tensor seqlens_rotary = seqlens_rotary_.value(); - CHECK_DEVICE(seqlens_rotary); - CHECK_CONTIGUOUS(seqlens_rotary); - TORCH_CHECK(seqlens_rotary.dtype() == torch::kInt32, "seqlens_rotary must have dtype torch.int32"); - CHECK_SHAPE(seqlens_rotary, batch_size); - params.seqlens_rotary = seqlens_rotary.data_ptr(); - } - } else { - params.rotary_dim = 0; - } - - if (kv_batch_idx_.has_value()) { - auto kv_batch_idx = kv_batch_idx_.value(); - CHECK_DEVICE(kv_batch_idx); - CHECK_CONTIGUOUS(kv_batch_idx); - TORCH_CHECK(kv_batch_idx.scalar_type() == torch::kInt32, "kv_batch_idx must have dtype int32"); - params.kv_batch_idx = reinterpret_cast(kv_batch_idx.data_ptr()); - } - - at::Tensor out_accum, softmax_lse_accum; - auto outaccum_type = at::ScalarType::Float; - - constexpr int PipelineStages = 2; - 
switch (params.d) { - case 64: - AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { - if (params.is_causal) { - FMHAConfig< - cute::Shape<_128, _64, _64>, - cute::Shape<_128, _32, _64>, - cute::Shape<_128, _64, _64>, - cute::Layout, cute::Stride<_1, _1, _1>>, - PipelineStages, - true, - false, - Sink>::run(params); - } else { - AT_DISPATCH_BOOL_NO_RETURN( - params.is_local, - LocalMask, - FMHAConfig< - cute::Shape<_128, _64, _64>, - cute::Shape<_128, _32, _64>, - cute::Shape<_128, _64, _64>, - cute::Layout, cute::Stride<_1, _1, _1>>, - PipelineStages, - false, - LocalMask, - Sink>::run(params)) - } - }) - break; - case 96: - AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { - if (params.is_causal) { - FMHAConfig< - cute::Shape<_128, _64, _32>, - cute::Shape<_128, _32, _64>, - cute::Shape<_128, _96, _64>, - cute::Layout, cute::Stride<_1, _1, _1>>, - PipelineStages, - true, - false, - Sink>::run(params); - - } else { - AT_DISPATCH_BOOL_NO_RETURN( - params.is_local, - LocalMask, - FMHAConfig< - cute::Shape<_128, _64, _32>, - cute::Shape<_128, _32, _64>, - cute::Shape<_128, _96, _64>, - cute::Layout, cute::Stride<_1, _1, _1>>, - PipelineStages, - false, - LocalMask, - Sink>::run(params)) - } - }) - break; - case 128: - AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { - if (params.is_causal) { - FMHAConfig< - cute::Shape<_128, _64, _64>, - cute::Shape<_128, _32, _64>, - cute::Shape<_128, _128, _64>, - cute::Layout, cute::Stride<_1, _1, _1>>, - PipelineStages, - true, - false, - Sink>::run(params); - } else { - AT_DISPATCH_BOOL_NO_RETURN( - params.is_local, - LocalMask, - FMHAConfig< - cute::Shape<_128, _64, _64>, - cute::Shape<_128, _32, _64>, - cute::Shape<_128, _128, _64>, - cute::Layout, cute::Stride<_1, _1, _1>>, - PipelineStages, - false, - LocalMask, - Sink>::run(params)) - } - }) - break; - case 192: - AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { - if (params.is_causal) { - FMHAConfig< - cute::Shape<_256, _64, _64>, - cute::Shape<_256, _32, _64>, - cute::Shape<_256, _192, _64>, - cute::Layout, cute::Stride<_1, _1, _1>>, - PipelineStages, - true, - false, - Sink>::run(params); - } else { - AT_DISPATCH_BOOL_NO_RETURN( - params.is_local, - LocalMask, - FMHAConfig< - cute::Shape<_256, _64, _64>, - cute::Shape<_256, _32, _64>, - cute::Shape<_256, _192, _64>, - cute::Layout, cute::Stride<_1, _1, _1>>, - PipelineStages, - false, - LocalMask, - Sink>::run(params)) - } - }) - break; - default: - TORCH_CHECK(false, "Unsupported head size for causal attention"); - } - return {out, softmax_lse, out_accum, softmax_lse_accum}; -} diff --git a/src/sycl/flash_attn_interface.cpp b/src/sycl/flash_attn_interface.cpp new file mode 100644 index 0000000..a9a274a --- /dev/null +++ b/src/sycl/flash_attn_interface.cpp @@ -0,0 +1,558 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief various Flash Attention kernel interface and params +*/ +#include +#include +#include +#include +#include +#include "Utils.h" +#include "comm/common.h" +#include "cutlass/util/device_memory.h" +#include "flash_attn_runner.hpp" + +using namespace cute; + +/// @brief Create a structure to hold parameters for Flash Attention forward pass. +struct Flash_fwd_params { + using index_t = int64_t; + + // The QKV matrices. + void* __restrict__ q_ptr; + void* __restrict__ k_ptr; + void* __restrict__ v_ptr; + + // The stride between rows of the Q, K and V matrices. + index_t q_batch_stride; + index_t k_batch_stride; + index_t v_batch_stride; + index_t q_row_stride; + index_t k_row_stride; + index_t v_row_stride; + index_t q_head_stride; + index_t k_head_stride; + index_t v_head_stride; + index_t v_dim_stride; + + // The number of heads. + int h, h_k; + + // The O matrix (output). + void* __restrict__ o_ptr; + void* __restrict__ oaccum_ptr; + + // The stride between rows of O. + index_t o_batch_stride; + index_t o_row_stride; + index_t o_head_stride; + + // The pointer to the softmax sum. + void* __restrict__ softmax_lse_ptr; + void* __restrict__ softmax_lseaccum_ptr; + + // The dimensions. + int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim; + int total_q, total_k; + int total_knew = 0; + int b_k; // When having KV cache and with cache_batch_idx, K & V might have larger batch size than Q + int dv, dv_rounded; // For the case where V headdim is different from Q/K headdim + + // The scaling factors for the kernel. + float scale_softmax; + void* sink_softmax; + float softcap; + + // array of length b+1 holding starting offset of each sequence. + int* __restrict__ cu_seqlens_q; + int* __restrict__ cu_seqlens_k; + int* __restrict__ cu_seqlens_knew; + int* __restrict__ leftpad_k; + + // If provided, the actual length of each q/k sequence. + int* __restrict__ seqused_q; + int* __restrict__ seqused_k; + + // The stride between rows of Oaccum. + index_t oaccum_split_stride; + index_t oaccum_batch_stride; + index_t oaccum_row_stride; + index_t oaccum_head_stride; + + // The stride between rows of LSEaccum. + index_t lseaccum_split_stride; + index_t lseaccum_batch_stride; + index_t lseaccum_head_stride; + + // The K_new and V_new matrices. + void* __restrict__ knew_ptr; + void* __restrict__ vnew_ptr; + + // The stride between rows of the Q, K and V matrices. 
+ index_t knew_batch_stride; + index_t vnew_batch_stride; + index_t knew_row_stride; + index_t vnew_row_stride; + index_t knew_head_stride; + index_t vnew_head_stride; + + void* __restrict__ qv_ptr; + index_t qv_batch_stride; + index_t qv_row_stride; + index_t qv_head_stride; + + // The cos and sin matrices for rotary embedding. + void* __restrict__ rotary_cos_ptr; + void* __restrict__ rotary_sin_ptr; + int* __restrict__ seqlens_rotary; + + // The indices to index into the KV cache. + int* __restrict__ kv_batch_idx; + + // Paged KV cache + int* __restrict__ page_table; + int max_num_pages_per_seq; + index_t page_table_batch_stride; + int page_size; + int num_pages; + bool pagedkv_tma; + + // The dropout probability (probability of keeping an activation). + float p_dropout; + // uint32_t p_dropout_in_uint; + // uint16_t p_dropout_in_uint16_t; + uint8_t p_dropout_in_uint8_t; + + // Scale factor of 1 / (1 - p_dropout). + float rp_dropout; + + // Local window size + int window_size_left, window_size_right; + + // Pointer to the RNG seed (idx 0) and offset (idx 1). + uint64_t* rng_state; + + bool is_bf16; + bool is_fp32; + bool is_e4m3; + bool is_causal; + bool is_local; + + bool is_rotary_interleaved; + + int num_splits; // For split-KV version + bool pack_gqa; + + int* __restrict__ tile_count_semaphore; + // int * __restrict__ num_m_blocks_ptr; + // int * __restrict__ num_n_blocks_ptr; + int* __restrict__ num_splits_dynamic_ptr; + bool skip_scheduler_metadata_computation; + + int arch; + int num_sm; +}; + +/// @brief Dispatch kernel implementation for Flash Attention. +template +void dispatch_kernel_impl(const Flash_fwd_params& params) { + using MMAOperation = typename runner::flash_attention::MMAOperationSelector::type; + using Shape_h = typename runner::flash_attention::ShapeSelector::type; + using Kernel = typename runner::flash_attention::XE_Flash_Attention< + ElementType, float, ElementType, + typename Shape_h::ShapeQK, typename Shape_h::ShapePV, + typename Shape_h::ShapeOutput, typename Shape_h::SubgroupLayout, + MMAOperation, PipelineStages, CausalMask, VarLen, PagedKV, LocalMask, Sink>::Kernel; + runner::flash_attention::RunFlashAttention(params); +} + +/// @brief Dispatch kernel based on varlen, causal, and local for Flash Attention. +template +void dispatch_kernel(const Flash_fwd_params& params) { + // Determine if variable length is needed + bool is_varlen = (params.cu_seqlens_q != nullptr) || (params.cu_seqlens_k != nullptr); + bool use_sink = params.sink_softmax != nullptr; + + // Dispatch based on varlen first, then the other boolean combinations + if (is_varlen) { + AT_DISPATCH_BOOL_NO_RETURN(use_sink, Sink, { + if (params.is_causal) { + if (params.page_table != nullptr) dispatch_kernel_impl(params); + else dispatch_kernel_impl(params); + } else { + AT_DISPATCH_BOOL_NO_RETURN(params.is_local, Local, { + if (params.page_table != nullptr) dispatch_kernel_impl(params); + // currently when paged_kv=false, causal=false, Local=true is causing issue for + // `compiled SIMD16 allocated 256 regs and spilled around 96` + // else dispatch_kernel_impl(params); + }); + } + }); + } else { + /* + // currently we don't support non-varlen kernels + */ + } +} + +inline int round_up_headdim(int head_size) { + if (head_size <= 64) { + return 64; + } + if (head_size <= 96) { + return 96; + } + if (head_size <= 128) { + return 128; + } + if (head_size <= 192) { + return 192; + } + if (head_size <= 256) { + return 256; + } + return 256; +} + + +/// @brief Dispatch kernel implementation for mha_fwd. 
+std::vector mha_fwd( + at::Tensor& q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + const at::Tensor& k, // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size, + // h_k, d) if there is page_table. + const at::Tensor& v, // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, + // page_size, h_k, dv) if there is page_table. + std::optional& q_v_, // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q + const at::Tensor& cu_seqlens_q, // b+1 + const at::Tensor& cu_seqlens_k, // b+1 + int max_seqlen_q, + const at::Tensor& page_table, // (b_k, max_num_pages_per_seq) + std::optional& kv_batch_idx_, // b. indices to index into the KV cache + std::optional& leftpad_k_, // b + std::optional& rotary_cos_, // seqlen_ro x (rotary_dim / 2) + std::optional& rotary_sin_, // seqlen_ro x (rotary_dim / 2) + std::optional& seqlens_rotary_, // b + std::optional& q_descale_, // (b, h_k), not (b, h) + std::optional& k_descale_, // (b, h_k) + std::optional& v_descale_, // (b, h_k) + const float softmax_scale_, + std::optional& softmax_sink_, + bool is_causal, + int window_size_left, + int window_size_right, + float const softcap, + bool const is_rotary_interleaved, // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2 + std::optional& scheduler_metadata_, // (b + 1) + int num_splits, + std::optional pack_gqa_, + int const sm_margin) { + // TODO: check GPU support + // auto dprops = at::cuda::getCurrentDeviceProperties(); + // TORCH_CHECK(drops->name.find("B580") != std::string::npos, "sgl_kernel_xpu only supports BMG+"); + + auto q_type = q.scalar_type(); + TORCH_CHECK( + q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, + "SGL Kernel XPU only supports fp16 and bf16 type"); + + TORCH_CHECK(k.scalar_type() == q_type, "query and key must have the same dtype"); + TORCH_CHECK(v.scalar_type() == q_type, "query and value must have the same dtype"); + + CHECK_DEVICE(q); + CHECK_DEVICE(k); + CHECK_DEVICE(v); + + TORCH_CHECK(q.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(k.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + TORCH_CHECK(v.stride(-1) == 1, "Input tensor must have contiguous last dimension"); + + CHECK_DEVICE(page_table); + TORCH_CHECK(page_table.dtype() == torch::kInt32, "page_table must have dtype torch.int32"); + TORCH_CHECK(page_table.stride(-1) == 1, "page_table must have contiguous last dimension"); + + TORCH_CHECK(q.dim() == 3, "query must be in ragged format"); + CHECK_DEVICE(cu_seqlens_q); + CHECK_CONTIGUOUS(cu_seqlens_q); + TORCH_CHECK(cu_seqlens_q.dtype() == torch::kInt32, "cu_seqlens_q must have dtype torch.int32"); + + CHECK_DEVICE(cu_seqlens_k); + CHECK_CONTIGUOUS(cu_seqlens_k); + TORCH_CHECK(cu_seqlens_k.dtype() == torch::kInt32, "cu_seqlens_k must have dtype torch.int32"); + + auto const sizes = q.sizes(); + const int batch_size = cu_seqlens_q.size(0) - 1; + int seqlen_q = max_seqlen_q; + int total_q = q.size(0); + int num_heads = q.size(-2); + int const head_size = q.size(-1); + int const head_size_v = v.size(-1); + int const max_num_pages_per_seq = page_table.size(1); + int const num_pages = k.size(0); + int const page_size = k.size(1); + int const seqlen_k = max_num_pages_per_seq * page_size; + int const total_k = num_pages * page_size; + int const num_heads_k = k.size(-2); + int const batch_size_k = page_table.size(0); + float softmax_scale = softmax_scale_; + + if 
(!kv_batch_idx_.has_value()) { + TORCH_CHECK(batch_size == batch_size_k, "batch_size must be equal to batch_size_k"); + } + + // Currently only support head dims <= 256 + static constexpr int max_headdim = 256; + TORCH_CHECK( + head_size <= max_headdim, + "FlashAttention forward only supports head dimension at most " + std::to_string(max_headdim)); + TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"); + + // This needs to go before kBlockM & kBlockN since we rely on the correct window_size and is_causal to set kBlockM + // TODO: check this + + if (window_size_left >= seqlen_k - 1) { + window_size_left = -1; + } + window_size_right = min(window_size_right, seqlen_q); + // causal=true is the same as causal=false in this case + if (is_causal) { + window_size_right = 0; + } + + CHECK_SHAPE(k, num_pages, page_size, num_heads_k, head_size); + CHECK_SHAPE(v, num_pages, page_size, num_heads_k, head_size_v); + CHECK_SHAPE(page_table, batch_size_k, max_num_pages_per_seq); + + if (leftpad_k_.has_value()) { + auto leftpad_k = leftpad_k_.value(); + TORCH_CHECK(leftpad_k.dtype() == torch::kInt32, "leftpad_k must have dtype int32"); + CHECK_DEVICE(leftpad_k); + CHECK_CONTIGUOUS(leftpad_k); + CHECK_SHAPE(leftpad_k, batch_size); + } + + static constexpr int alignment = 8; + TORCH_CHECK(head_size % alignment == 0, "head_size should be a multiple of " + std::to_string(alignment)); + TORCH_CHECK(head_size_v % alignment == 0, "head_size_v should be a multiple of " + std::to_string(alignment)); + + auto opts = q.options(); + at::Tensor out; + out = torch::empty({total_q, num_heads, head_size_v}, opts); + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + int const head_size_rounded = round_up_headdim(head_size); + int const head_size_v_rounded = head_size_v == head_size ? head_size_rounded : round_up_headdim(head_size_v); + int const seqlen_q_rounded = round_multiple(seqlen_q, 128); + int const seqlen_k_rounded = round_multiple(seqlen_k, 128); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + c10::DeviceGuard device_guard(q.device()); + + at::Tensor softmax_lse; + softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat)); + + // align with FA3 + Flash_fwd_params params; + params.is_bf16 = q.dtype() == torch::kBFloat16; + + // Set the pointers and strides. + params.q_ptr = q.data_ptr(); + params.k_ptr = k.data_ptr(); + params.v_ptr = v.data_ptr(); + // All stride are in elements, not bytes. + params.q_row_stride = q.stride(-3); + params.k_row_stride = k.stride(-3); + params.v_row_stride = v.stride(-3); + params.q_head_stride = q.stride(-2); + params.k_head_stride = k.stride(-2); + params.v_head_stride = v.stride(-2); + params.v_dim_stride = v.stride(-1); + params.o_ptr = out.data_ptr(); + params.o_row_stride = out.stride(-3); + params.o_head_stride = out.stride(-2); + + params.cu_seqlens_q = cu_seqlens_q.data_ptr(); + params.cu_seqlens_k = cu_seqlens_k.data_ptr(); + + // Softmax sum + params.softmax_lse_ptr = softmax_lse.data_ptr(); + + // Set the dimensions. + params.b = batch_size; + params.h = num_heads; + params.h_k = num_heads_k; + params.seqlen_q = seqlen_q; + params.seqlen_k = seqlen_k; + params.seqlen_q_rounded = seqlen_q_rounded; + params.seqlen_k_rounded = seqlen_k_rounded; + params.d = head_size; + params.d_rounded = head_size_rounded; + + // Set the different scale values. 
+ params.scale_softmax = softmax_scale; + bool use_sink = softmax_sink_.has_value(); + params.sink_softmax = use_sink ? softmax_sink_.value().data_ptr() : nullptr; + + params.softcap = softcap; + + // Set this to probability of keeping an element to simplify things. + params.p_dropout = 1.f; + + // Causal is the special case where window_size_right == 0 and window_size_left < 0. + // Local is the more general case where window_size_right >= 0 or window_size_left >= 0. + params.is_causal = window_size_left < 0 && window_size_right == 0; + params.is_local = (window_size_left >= 0 || window_size_right >= 0) && !params.is_causal; + + // TODO: check this + if (window_size_left < 0) { + window_size_left = seqlen_k - 1; + } + if (window_size_right < 0) { + window_size_right = seqlen_q - 1; + } + params.window_size_left = window_size_left; + params.window_size_right = window_size_right; + params.total_q = total_q; + params.total_k = total_k; + params.b_k = batch_size_k; + params.dv = head_size_v; + params.page_table = page_table.data_ptr(); + params.page_table_batch_stride = page_table.stride(0); + params.max_num_pages_per_seq = max_num_pages_per_seq; + params.page_size = page_size; + params.num_pages = num_pages; + + if (q_v_.has_value()) { + TORCH_CHECK(head_size <= 64, "q_v is only supported for head_size <= 64"); + TORCH_CHECK( + q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16, + "q_v is only supported for fp16 and bf16 data type"); + TORCH_CHECK(params.arch == 90, "q_v is only supported for Hopper GPUs"); + at::Tensor q_v = q_v_.value(); + TORCH_CHECK(q_v.dtype() == q_type, "q_v must have the same dtype as query"); + CHECK_DEVICE(q_v); + TORCH_CHECK(q_v.stride(-1) == 1, "q_v tensor must have contiguous last dimension"); + CHECK_SHAPE(q_v, total_q, num_heads, head_size_v); + params.qv_ptr = q_v.data_ptr(); + // All stride are in elements, not bytes. 
+ params.qv_row_stride = q_v.stride(-3); + params.qv_head_stride = q_v.stride(-2); + } + + if (rotary_cos_.has_value()) { + auto rotary_cos = rotary_cos_.value(); + CHECK_DEVICE(rotary_cos); + CHECK_CONTIGUOUS(rotary_cos); + params.rotary_dim = rotary_cos.size(1) * 2; + TORCH_CHECK(params.rotary_dim <= head_size, "rotary_dim must be <= headdim"); + TORCH_CHECK(params.rotary_dim % 16 == 0, "Only rotary dimensions divisible by 16 are currently supported"); + const int seqlen_ro = rotary_cos.size(0); + TORCH_CHECK(seqlen_ro >= seqlen_k, "cos/sin seqlen must be at least the seqlen of KV cache"); + CHECK_SHAPE(rotary_cos, seqlen_ro, params.rotary_dim / 2); + TORCH_CHECK(rotary_cos.scalar_type() == q_type, "rotary_cos must have the same dtype as query"); + + TORCH_CHECK(rotary_sin_.has_value(), "If rotary cos is provided, rotary sin must also be provided"); + auto rotary_sin = rotary_sin_.value(); + CHECK_DEVICE(rotary_sin); + CHECK_CONTIGUOUS(rotary_sin); + CHECK_SHAPE(rotary_sin, seqlen_ro, params.rotary_dim / 2); + TORCH_CHECK(rotary_sin.scalar_type() == q_type, "rotary_cos must have the same dtype as query"); + params.rotary_cos_ptr = rotary_cos.data_ptr(); + params.rotary_sin_ptr = rotary_sin.data_ptr(); + params.is_rotary_interleaved = is_rotary_interleaved; + if (seqlens_rotary_.has_value()) { + at::Tensor seqlens_rotary = seqlens_rotary_.value(); + CHECK_DEVICE(seqlens_rotary); + CHECK_CONTIGUOUS(seqlens_rotary); + TORCH_CHECK(seqlens_rotary.dtype() == torch::kInt32, "seqlens_rotary must have dtype torch.int32"); + CHECK_SHAPE(seqlens_rotary, batch_size); + params.seqlens_rotary = seqlens_rotary.data_ptr(); + } + } else { + params.rotary_dim = 0; + } + + if (kv_batch_idx_.has_value()) { + auto kv_batch_idx = kv_batch_idx_.value(); + CHECK_DEVICE(kv_batch_idx); + CHECK_CONTIGUOUS(kv_batch_idx); + TORCH_CHECK(kv_batch_idx.scalar_type() == torch::kInt32, "kv_batch_idx must have dtype int32"); + params.kv_batch_idx = reinterpret_cast(kv_batch_idx.data_ptr()); + } + + at::Tensor out_accum, softmax_lse_accum; + auto outaccum_type = at::ScalarType::Float; + + constexpr int PipelineStages = 2; + + if (q_type == at::ScalarType::BFloat16) { + switch (head_size) { + case 64: dispatch_kernel(params); break; + case 96: dispatch_kernel(params); break; + case 128: dispatch_kernel(params); break; + case 192: dispatch_kernel(params); break; + default: TORCH_CHECK(false, "Unsupported head size for BFloat16: " + std::to_string(head_size)); + } + } else if (q_type == at::ScalarType::Half) { + switch (head_size) { + case 64: dispatch_kernel(params); break; + case 96: dispatch_kernel(params); break; + case 128: dispatch_kernel(params); break; + case 192: dispatch_kernel(params); break; + default: TORCH_CHECK(false, "Unsupported head size for Half: " + std::to_string(head_size)); + } + } else { + TORCH_CHECK(false, "Unsupported data type"); + } + return {out, softmax_lse, out_accum, softmax_lse_accum}; +} + +int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count, int64_t num_kv_splits) { + constexpr int PipelineStages = 2; + constexpr bool CausalMask = false; + constexpr bool VarLen = true; + constexpr bool PagedKV = false; + using MMAOperation = typename runner::flash_attention::MMAOperationSelector::type; + using Shape_h = typename runner::flash_attention::ShapeSelector<64>::type; + using MlaFlashAttnType = typename runner::flash_attention::XE_Flash_Attention; + + cutlass::KernelHardwareInfo hw_info; + typename MlaFlashAttnType::Kernel::Arguments arguments; + 
arguments.hw_info = hw_info; + // need to change these parameters to match the actual use case + // will do later + + return MlaFlashAttnType::Kernel::get_workspace_size(arguments); +} diff --git a/src/sycl/flash_attn_runner.hpp b/src/sycl/flash_attn_runner.hpp new file mode 100644 index 0000000..4b02f44 --- /dev/null +++ b/src/sycl/flash_attn_runner.hpp @@ -0,0 +1,538 @@ +/*************************************************************************************************** + * Copyright (C) 2025 Intel Corporation, All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Flash Attention execution engine +*/ + +#pragma once + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "kernels/chunk_prefill/fmha_fusion.hpp" +#include "kernels/chunk_prefill/tile_scheduler_chunk_prefill.hpp" +#include "cutlass/util/packed_stride.hpp" +#include "kernels/chunk_prefill/xe_chunk_prefill.hpp" +#include "kernels/chunk_prefill/xe_flash_attn_chunk_prefill_epilogue.hpp" +#include "kernels/chunk_prefill/xe_flash_attn_chunk_prefill_softmax_epilogue.hpp" +#include "cutlass/util/GPU_Clock.hpp" +#include "cutlass/util/sycl_event_manager.hpp" +#include "cutlass/util/initialize_block.hpp" + +#include +#include + +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/reference/device/sycl_tensor_fill.h" + + +namespace runner { +namespace flash_attention { + +using namespace cute; + +using MMAOperationBF16 = cute::XE_8x16x16_F32BF16BF16F32_TT; +using MMAOperationFP16 = cute::XE_8x16x16_F32F16F16F32_TT; + +struct Shape_h64 { + using ShapeQK = Shape<_128, _64, _64>; + using ShapePV = Shape<_128, _32, _64>; + using ShapeOutput = Shape<_128, _64, _64>; + using SubgroupLayout = Layout, Stride<_1, _1, _1>>; +}; + +struct Shape_h96 { + using ShapeQK = Shape<_128, _64, _32>; + using ShapePV = Shape<_128, _32, _64>; + using ShapeOutput = Shape<_128, _96, _64>; + using SubgroupLayout = Layout, Stride<_1, _1, _1>>; +}; + +struct Shape_h128 { + using ShapeQK = Shape<_128, _64, _64>; + using ShapePV = Shape<_128, _32, _64>; + using ShapeOutput = Shape<_128, _128, _64>; + using SubgroupLayout = Layout, Stride<_1, _1, _1>>; +}; + +struct Shape_h192 { + using ShapeQK = Shape<_256, _64, _64>; + using ShapePV = Shape<_256, _32, _64>; + using ShapeOutput = Shape<_256, _192, _64>; + using SubgroupLayout = Layout, Stride<_1, _1, _1>>; +}; + +// Add a template-based selector for MMA operations +template +struct MMAOperationSelector; + +template<> +struct MMAOperationSelector { + using type = MMAOperationBF16; +}; + +template<> +struct MMAOperationSelector { + using type = MMAOperationFP16; +}; + +// Add a template-based selector for shapes +template +struct ShapeSelector; + +template<> +struct ShapeSelector<64> { + using type = Shape_h64; +}; + +template<> +struct ShapeSelector<96> { + using type = Shape_h96; +}; + +template<> +struct ShapeSelector<128> { + using type = Shape_h128; +}; + +template<> +struct ShapeSelector<192> { + using type = Shape_h192; +}; + + +///////////////////////////////////////////////////////////////////// + template struct TiledCopyConfig; + + template <> struct TiledCopyConfig<8, 32> { + using GmemTiledCopyQ = cute::XE_2D_U8x8x32_LD_N; + using GmemTiledCopyK = cute::XE_2D_U8x16x16_LD_T; + using GmemTiledCopyV = cute::XE_2D_U8x32x32_LD_V; + using GmemTiledCopyO = cute::XE_2D_U32x8x16_ST_N; + }; + + template <> struct TiledCopyConfig<8, 8> { + using GmemTiledCopyQ = cute::XE_2D_U8x8x32_LD_N; + using GmemTiledCopyK = cute::XE_2D_U8x16x16_LD_T; + using GmemTiledCopyV = cute::XE_2D_U8x32x32_LD_V; + using GmemTiledCopyO = cute::XE_2D_U8x8x16_ST_N; + }; + + template <> struct TiledCopyConfig<16, 32> { + using GmemTiledCopyQ = cute::XE_2D_U16x8x32_LD_N; + using GmemTiledCopyK = cute::XE_2D_U16x16x16_LD_T; + using GmemTiledCopyV = cute::XE_2D_U16x16x32_LD_V; + using GmemTiledCopyO = 
cute::XE_2D_U32x8x16_ST_N; + }; + + template <> struct TiledCopyConfig<16, 16> { + using GmemTiledCopyQ = cute::XE_2D_U16x8x32_LD_N; + using GmemTiledCopyK = cute::XE_2D_U16x16x16_LD_T; + using GmemTiledCopyV = cute::XE_2D_U16x16x32_LD_V; + using GmemTiledCopyO = cute::XE_2D_U16x8x16_ST_N; + }; + + template class convert_fp8_to_fp16_name; + + template + void convert_fp8_to_fp16(const SrcT* d_src, DstT* d_dst, size_t size) { + compat::get_default_queue().parallel_for>(size, [=](auto indx) { + d_dst[indx] = static_cast(d_src[indx]); + }).wait(); + } + + +///////////////////////////////////////////////////////////////////// + +using LayoutQ = cutlass::layout::RowMajor; +using LayoutK = cutlass::layout::ColumnMajor; +using LayoutV = cutlass::layout::RowMajor; +using LayoutO = cutlass::layout::RowMajor; + +///////////////////////////////////////////////////////////////////// + +template +struct XE_Flash_Attention { + using ElementAccumulator = ElementAccumulatorType; + using ElementComputeEpilogue = ElementAccumulatorType; + using ElementInputQ = ElementInputType; + using ElementInputKV = ElementInputType; + using ElementOutput = ElementOutputType; + + using ProblemShapeRegular = cute::tuple; + using ProblemShapeVarlen = cute::tuple; + using ProblemShapeType = std::conditional_t; + + using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16; + using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16; + + using GmemTiledCopyQ = typename TiledCopyConfig, cute::sizeof_bits_v>::GmemTiledCopyQ; + using GmemTiledCopyK = typename TiledCopyConfig, cute::sizeof_bits_v>::GmemTiledCopyK; + using GmemTiledCopyV = typename TiledCopyConfig, cute::sizeof_bits_v>::GmemTiledCopyV; + using GmemTiledCopyStore = typename TiledCopyConfig, cute::sizeof_bits_v>::GmemTiledCopyO; + using CollectiveEpilogue = cutlass::flash_attention::collective::FlashChunkPrefillEpilogue< + Sink, + EpilogueDispatchPolicy, + MMAOperation, + TileShapeOutput, + SubgroupLayout, + ElementComputeEpilogue, + ElementOutput, + cutlass::gemm::TagToStrideC_t, + ElementOutput, + GmemTiledCopyStore, + ElementSink>; + using CollectiveSoftmaxEpilogue = cutlass::flash_attention::collective::FlashChunkPrefillSoftmaxEpilogue< + CausalMask, LocalMask, EpilogueDispatchPolicy, ElementAccumulator>; + + // Mainloop + using CollectiveMainloop = cutlass::flash_attention::collective::FlashChunkPrefillMma< + GEMMDispatchPolicy, + ProblemShapeType, + ElementInputQ, + cutlass::gemm::TagToStrideA_t, + ElementInputKV, + cutlass::gemm::TagToStrideB_t, + ElementInputKV, + cutlass::gemm::TagToStrideB_t, + MMAOperation, + TileShapeQK, + TileShapePV, + SubgroupLayout, + GmemTiledCopyQ, // Q + GmemTiledCopyK, // K + GmemTiledCopyV, // V, + CausalMask, + LocalMask, + PagedKV>; + + using Kernel = cutlass::flash_attention::kernel::FMHAPrefillChunk; +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +class KernelCur {}; + +namespace detail { + +template +struct EngineImpl { + + using StrideQ = typename FlashAttentionKernel::StrideQ; + using StrideK = typename FlashAttentionKernel::StrideK; + using StrideV = typename FlashAttentionKernel::StrideV; + using StrideO = typename FlashAttentionKernel::StrideO; + + using ElementQ = typename FlashAttentionKernel::ElementQ; + using ElementK = typename FlashAttentionKernel::ElementK; + using ElementV = typename FlashAttentionKernel::ElementV; + using ElementAcc = typename FlashAttentionKernel::ElementAccumulator; + using ElementSink = typename 
+/////////////////////////////////////////////////////////////////////
+
+using LayoutQ = cutlass::layout::RowMajor;
+using LayoutK = cutlass::layout::ColumnMajor;
+using LayoutV = cutlass::layout::RowMajor;
+using LayoutO = cutlass::layout::RowMajor;
+
+/////////////////////////////////////////////////////////////////////
+
+template
+struct XE_Flash_Attention {
+  using ElementAccumulator = ElementAccumulatorType;
+  using ElementComputeEpilogue = ElementAccumulatorType;
+  using ElementInputQ = ElementInputType;
+  using ElementInputKV = ElementInputType;
+  using ElementOutput = ElementOutputType;
+
+  using ProblemShapeRegular = cute::tuple;
+  using ProblemShapeVarlen = cute::tuple;
+  using ProblemShapeType = std::conditional_t;
+
+  using GEMMDispatchPolicy = cutlass::gemm::MainloopIntelXeXMX16;
+  using EpilogueDispatchPolicy = cutlass::epilogue::IntelXeXMX16;
+
+  using GmemTiledCopyQ =
+      typename TiledCopyConfig<cute::sizeof_bits_v<ElementInputQ>, cute::sizeof_bits_v<ElementOutput>>::GmemTiledCopyQ;
+  using GmemTiledCopyK =
+      typename TiledCopyConfig<cute::sizeof_bits_v<ElementInputKV>, cute::sizeof_bits_v<ElementOutput>>::GmemTiledCopyK;
+  using GmemTiledCopyV =
+      typename TiledCopyConfig<cute::sizeof_bits_v<ElementInputKV>, cute::sizeof_bits_v<ElementOutput>>::GmemTiledCopyV;
+  using GmemTiledCopyStore =
+      typename TiledCopyConfig<cute::sizeof_bits_v<ElementInputQ>, cute::sizeof_bits_v<ElementOutput>>::GmemTiledCopyO;
+
+  using CollectiveEpilogue = cutlass::flash_attention::collective::FlashChunkPrefillEpilogue<
+      Sink,
+      EpilogueDispatchPolicy,
+      MMAOperation,
+      TileShapeOutput,
+      SubgroupLayout,
+      ElementComputeEpilogue,
+      ElementOutput,
+      cutlass::gemm::TagToStrideC_t<LayoutO>,
+      ElementOutput,
+      GmemTiledCopyStore,
+      ElementSink>;
+  using CollectiveSoftmaxEpilogue = cutlass::flash_attention::collective::FlashChunkPrefillSoftmaxEpilogue<
+      CausalMask, LocalMask, EpilogueDispatchPolicy, ElementAccumulator>;
+
+  // Mainloop
+  using CollectiveMainloop = cutlass::flash_attention::collective::FlashChunkPrefillMma<
+      GEMMDispatchPolicy,
+      ProblemShapeType,
+      ElementInputQ,
+      cutlass::gemm::TagToStrideA_t<LayoutQ>,
+      ElementInputKV,
+      cutlass::gemm::TagToStrideB_t<LayoutK>,
+      ElementInputKV,
+      cutlass::gemm::TagToStrideB_t<LayoutV>,
+      MMAOperation,
+      TileShapeQK,
+      TileShapePV,
+      SubgroupLayout,
+      GmemTiledCopyQ,  // Q
+      GmemTiledCopyK,  // K
+      GmemTiledCopyV,  // V
+      CausalMask,
+      LocalMask,
+      PagedKV>;
+
+  using Kernel = cutlass::flash_attention::kernel::FMHAPrefillChunk;
+};
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename KernelName>
+class KernelCur {};
+
+namespace detail {
+
+template <typename FlashAttentionKernel>
+struct EngineImpl {
+  using StrideQ = typename FlashAttentionKernel::StrideQ;
+  using StrideK = typename FlashAttentionKernel::StrideK;
+  using StrideV = typename FlashAttentionKernel::StrideV;
+  using StrideO = typename FlashAttentionKernel::StrideO;
+
+  using ElementQ = typename FlashAttentionKernel::ElementQ;
+  using ElementK = typename FlashAttentionKernel::ElementK;
+  using ElementV = typename FlashAttentionKernel::ElementV;
+  using ElementAcc = typename FlashAttentionKernel::ElementAccumulator;
+  using ElementSink = typename FlashAttentionKernel::ElementSink;
+
+  using CollectiveMainloop = typename FlashAttentionKernel::CollectiveMainloop;
+  using CollectiveEpilogue = typename FlashAttentionKernel::CollectiveEpilogue;
+  using ElementOutput = typename CollectiveEpilogue::ElementOutput;
+  using ElementCompute = typename CollectiveEpilogue::ElementCompute;
+  using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator;
+
+  using ProblemShapeType = typename FlashAttentionKernel::ProblemShape;
+  static constexpr bool HasCausalMask = CollectiveMainloop::CausalMask;
+  static constexpr bool isVarLen = CollectiveMainloop::is_var_len;
+
+  StrideQ stride_Q;
+  StrideK stride_K;
+  StrideV stride_V;
+  StrideK stride_K_cache;
+  StrideV stride_V_cache;
+  StrideO stride_O;
+
+  std::vector<int> cumulative_seqlen_q;
+  std::vector<int> cumulative_seqlen_kv;
+  cutlass::DeviceAllocation<int> device_cumulative_seqlen_q;
+  cutlass::DeviceAllocation<int> device_cumulative_seqlen_kv;
+  cutlass::DeviceAllocation<ElementQ> block_Q;
+  cutlass::DeviceAllocation<ElementK> block_K;
+  cutlass::DeviceAllocation<ElementV> block_V;
+  cutlass::DeviceAllocation<ElementOutput> block_O;
+  cutlass::DeviceAllocation<ElementOutput> block_ref_O;
+
+  //
+  // Methods
+  //
+  template <typename T>
+  static constexpr bool is_fp8_v = cute::is_any_of_v<T, cutlass::float_e4m3_t, cutlass::float_e5m2_t>;
+
+  template <typename Tin>
+  inline auto in_memory(cutlass::DeviceAllocation<Tin>& in) {
+    using outType = cute::conditional_t<is_fp8_v<Tin>, half_t, Tin>;
+    if constexpr (is_fp8_v<Tin>) {
+      cutlass::DeviceAllocation<outType> out(in.size());
+      convert_fp8_to_fp16(in.get(), out.get(), in.size());
+      return out;
+    } else {
+      return in;
+    }
+  }
+
+  template <typename Params, typename ProblemShape>
+  auto initialize_varlen(const Params& params, ProblemShape& problem_size) {
+    ProblemShape problem_size_for_init = problem_size;
+    get<0>(problem_size_for_init) = 1;  // concentrated batch
+    get<3>(problem_size_for_init) = params.total_q;
+    get<4>(problem_size_for_init) = params.total_knew;
+    get<5>(problem_size_for_init) = params.total_k;
+
+    ProblemShapeType problem_size_for_launch;
+
+    get<0>(problem_size_for_launch) = get<0>(problem_size);
+    get<1>(problem_size_for_launch) = get<1>(problem_size);
+    get<2>(problem_size_for_launch) = get<2>(problem_size);
+    get<3>(problem_size_for_launch) = cutlass::fmha::collective::VariableLength{params.seqlen_q, params.total_q};
+    get<4>(problem_size_for_launch) = cutlass::fmha::collective::VariableLength{params.seqlen_knew, params.total_knew};
+    get<5>(problem_size_for_launch) = cutlass::fmha::collective::VariableLength{params.seqlen_k, params.total_k};
+    get<6>(problem_size_for_launch) = get<6>(problem_size);
+    get<7>(problem_size_for_launch) = get<7>(problem_size);
+
+    return cute::make_tuple(problem_size_for_init, problem_size_for_launch);
+  }
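+
+  // Illustrative example (assumed convention, mirroring the VariableLength fields used
+  // above): for a batch of two query sequences of lengths 3 and 5, params.seqlen_q holds
+  // the max length (5), params.total_q the sum (8), and params.cu_seqlens_q a device
+  // array of b + 1 prefix sums {0, 3, 8}. The same scheme applies to the KV and new-KV
+  // lengths packed into tuple slots <3>, <4> and <5>, e.g.
+  //   cutlass::fmha::collective::VariableLength{/*max=*/5, /*total=*/8}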
+
+  /// Initialize operands to be used in the GEMM and reference GEMM
+  template <typename Params>
+  ProblemShapeType initialize(const Params& params) {
+    auto problem_shape_in = cute::make_tuple(
+        params.b,    // batch
+        params.h,    // num_heads_q
+        params.h_k,  // num_heads_kv
+        params.seqlen_q,
+        params.seqlen_knew,
+        params.seqlen_k,
+        params.d,
+        params.dv);
+
+    ProblemShapeType problem_shape;
+    decltype(problem_shape_in) problem_size;
+
+    if constexpr (isVarLen) {
+      auto [problem_shape_init, problem_shape_launch] = initialize_varlen(params, problem_shape_in);
+      problem_size = problem_shape_init;
+      problem_shape = problem_shape_launch;
+    } else {
+      problem_size = problem_shape_in;
+      problem_shape = problem_shape_in;
+    }
+
+    auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache, head_size_qk, head_size_vo] =
+        problem_size;
+    auto group_q_size = num_heads_q / num_heads_kv;
+    auto group_q_num = num_heads_q / group_q_size;
+
+    stride_Q =
+        cutlass::make_cute_packed_stride(StrideQ{}, cute::make_shape(seq_len_qo, num_heads_q * head_size_qk, batch));
+    stride_K =
+        cutlass::make_cute_packed_stride(StrideK{}, cute::make_shape(seq_len_kv, num_heads_kv * head_size_qk, batch));
+    stride_V =
+        cutlass::make_cute_packed_stride(StrideV{}, cute::make_shape(head_size_vo * num_heads_kv, seq_len_kv, batch));
+
+    stride_K_cache = cutlass::make_cute_packed_stride(
+        StrideK{}, cute::make_shape(seq_len_kv_cache, num_heads_kv * head_size_qk, batch));
+    stride_V_cache = cutlass::make_cute_packed_stride(
+        StrideV{}, cute::make_shape(head_size_vo * head_size_qk, seq_len_kv_cache, batch * num_heads_kv));
+    stride_O = cutlass::make_cute_packed_stride(
+        StrideQ{}, cute::make_shape(seq_len_qo * group_q_size, group_q_num * head_size_vo, batch));
+
+    if constexpr (isVarLen) {
+      get<3>(problem_shape).cumulative_length = params.cu_seqlens_q;
+      get<4>(problem_shape).cumulative_length = params.cu_seqlens_knew;
+      get<5>(problem_shape).cumulative_length = params.cu_seqlens_k;
+    }
+
+    return problem_shape;
+  }
+
+  bool sufficient() {
+    // check device properties
+    // Currently, we assume that all Intel Xe devices support Flash Attention
+    return true;
+  }
+
+  template <typename Params>
+  bool run(Params params) {
+    // Fail test if insufficient device
+    if (!sufficient()) {
+      CUTLASS_TRACE_HOST("EngineImpl::run: Test failed due to insufficient device");
+      std::cout << "Test failed due to insufficient device." << std::endl;
+      return false;
+    }
+    ProblemShapeType problem_size = this->initialize(params);
+
+    //
+    // Initialize the Flash attention operator
+    //
+    cutlass::KernelHardwareInfo hw_info;
+
+    typename FlashAttentionKernel::Arguments arguments{
+        cutlass::gemm::GemmUniversalMode::kGemm,
+        problem_size,
+        {// static_cast<ElementQ*>(params.q_ptr),
+         static_cast<ElementQ*>(params.q_ptr),
+         stride_Q,
+         // static_cast<ElementK*>(params.knew_ptr),
+         // stride_K,
+         // static_cast<ElementV*>(params.vnew_ptr),
+         // stride_V,
+         static_cast<ElementK*>(params.k_ptr),
+         stride_K_cache,
+         static_cast<ElementV*>(params.v_ptr),
+         stride_V_cache,
+         params.page_table,
+         params.page_size,
+         params.max_num_pages_per_seq,
+         params.window_size_left,
+         params.window_size_right},
+        {(ElementQ)params.scale_softmax},
+        {static_cast<ElementOutput*>(params.o_ptr),
+         stride_O,
+         static_cast<ElementSink*>(params.sink_softmax)},
+        hw_info};
+
+    size_t workspace_size = FlashAttentionKernel::get_workspace_size(arguments);
+    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
+
+    if (!FlashAttentionKernel::can_implement(arguments)) {
+      std::cerr << "This case is not supported." << "\n";
+      return false;
+    }
<< "\n"; + return false; + } + + // + // Run Flash attention + // + auto params_kernel = FlashAttentionKernel::to_underlying_arguments(arguments, workspace.get()); + auto const block = FlashAttentionKernel::get_block_shape(); + auto const grid = FlashAttentionKernel::get_grid_shape(params_kernel); + + // configure smem size and carveout + int smem_size = FlashAttentionKernel::SharedStorageSize; + + const auto sycl_block = compat::dim3(block.x, block.y, block.z); + const auto sycl_grid = compat::dim3(grid.x, grid.y, grid.z); + + using namespace compat::experimental; + compat::experimental::launch_properties launch_props{ + sycl::ext::oneapi::experimental::work_group_scratch_size(smem_size), + }; + compat::experimental::kernel_properties kernel_props{ + sycl::ext::oneapi::experimental::sub_group_size + }; + compat::experimental::launch_policy policy{sycl_grid, sycl_block, launch_props, kernel_props}; + + sycl::ext::oneapi::experimental::launch_config config(policy.get_range(), policy.get_launch_properties()); + auto cgf = [&](::sycl::handler& cgh) { + auto KernelFunctor = + compat::experimental::detail::build_kernel_functor>( + cgh, policy, params_kernel); + sycl::ext::oneapi::experimental::detail:: + LaunchConfigAccess, decltype(policy.get_launch_properties())> + ConfigAccess(config); + cgh.parallel_for>( + ConfigAccess.getRange(), ConfigAccess.getProperties(), KernelFunctor); + }; + auto stream = at::xpu::getCurrentXPUStream(); + auto q = stream.queue(); + q.submit(cgf); + + try { + compat::wait_and_throw(); + } catch (std::exception const &e) { + std::cerr << "Error at Kernel Sync: " << e.what() << "\n"; + return false; + } + return true; + } +}; + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename FlashAttention +> +struct Engine3x { + detail::EngineImpl impl_; + + // + // Methods + // + Engine3x() : impl_() {} + + template + bool run(Params params) { + return impl_.run(params); + } +}; + +template +void RunFlashAttention(Params params, std::string config="default") { + Engine3x engine; + bool passed = true; + try { + passed = engine.run(params); + } + catch (std::exception const& e) { + std::cerr << "Executing: engine.run {" + << "batch: " << params.b << ", num_heads_q: " << params.h << ", num_heads_kv: " << params.h_k + << ", seq_len_qo: " << params.seqlen_q << ", seq_len_kv: " << params.seqlen_k << ", seq_len_knew: " << params.seqlen_knew + << ", head_size_vo: " << params.dv << ", head_size_qk: " << params.d + << "} threw an exception: " << e.what() << "\n"; + throw; + } + catch (...) { + std::cerr << "Executing: engine.run {" + << "batch: " << params.b << ", num_heads_q: " << params.h << ", num_heads_kv: " << params.h_k + << ", seq_len_qo: " << params.seqlen_q << ", seq_len_kv: " << params.seqlen_k << ", seq_len_knew: " << params.seqlen_knew + << ", head_size_vo: " << params.dv << ", head_size_qk: " << params.d + << "} threw an exception (unknown)" << "\n"; + throw; + } + return; +} + +} // namespace flash_attention +} // namespace runner + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/src/torch_extension_sycl.cc b/src/torch_extension_sycl.cc index 70ba202..202ae5d 100644 --- a/src/torch_extension_sycl.cc +++ b/src/torch_extension_sycl.cc @@ -93,6 +93,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { " bool? 
pack_gqa," " int sm_margin) -> Tensor[]"); m.impl("fwd", torch::kXPU, make_pytorch_shim(&mha_fwd)); + + m.def("cutlass_mla_get_workspace_size", &cutlass_mla_get_workspace_size); } REGISTER_EXTENSION(common_ops)