Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion applications/dual_gemm/collective/xe_dual_gemm_mma.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved.
* Copyright (C) 2025 Intel Corporation, All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
Expand Down Expand Up @@ -169,7 +170,7 @@ struct DualGemmMma<MainloopIntelXeXMX16<Stages, Schedule>, TileShape_, ElementA_
TiledMma tiled_mma;
// TODO(Codeplay): see if we can make this nicer
// To make all work items in a subgroup have the same global tensors pass in the index of work item 0 in each subgroup
auto sg = syclcompat::get_nd_item<1>().get_sub_group();
auto sg = compat::get_nd_item<1>().get_sub_group();
auto first_thread_in_sg_idx = sg.get_group_linear_id() * DispatchPolicy::SubgroupSize;
auto thr_mma = tiled_mma.get_slice(first_thread_in_sg_idx);

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved.
* Copyright (C) 2025 Intel Corporation, All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
Expand Down Expand Up @@ -167,8 +168,8 @@ class FlashDecodeEpilogue<epilogue::IntelXeXMX16, MMAOp_, TileShapeOutput_, Subg
using namespace cute;
static constexpr bool is_var_len = cutlass::fmha::collective::is_variable_length_v<tuple_element_t<2, ProblemShape>>;

auto sg = syclcompat::get_nd_item<1>().get_sub_group();
auto group = syclcompat::get_nd_item<1>().get_group();
auto sg = compat::get_nd_item<1>().get_sub_group();
auto group = compat::get_nd_item<1>().get_group();
const int sg_local_id = sg.get_local_id()[0];
const int sg_group_id = sg.get_group_id()[0];

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved.
* Copyright (C) 2025 Intel Corporation, All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
Expand Down Expand Up @@ -226,7 +227,7 @@ struct FlashDecodeMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeType_, Ele
auto thr_copy_K = gmem_tiled_copy_k.get_slice(thread_idx);
// Instantiate the MMA object
TiledMmaQK tiled_mma;
auto sg = syclcompat::get_nd_item<1>().get_sub_group();
auto sg = compat::get_nd_item<1>().get_sub_group();
auto first_thread_in_sg_idx = sg.get_group_id()[0] * DispatchPolicy::SubgroupSize;
// For Normal Attention, K matrix tile_id = subgroup_id (cache and new both)
// For Paged Attention, K matrix tile_id = page_table[subgroup_id] (cache, new keys follow normal attention)
Expand Down Expand Up @@ -315,7 +316,7 @@ struct FlashDecodeMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeType_, Ele
int thread_idx = static_cast<int>(ThreadIdxX());
// Instantiate the MMA object
TiledMmaPV tiled_mma;
auto sg = syclcompat::get_nd_item<1>().get_sub_group();
auto sg = compat::get_nd_item<1>().get_sub_group();
auto thread_mma = tiled_mma.get_slice(0);
// convert X*512|1024 to 32*64*x*8|16 and use (_, sg.get_group_id()[0] / ATOM_N) to index in the (x,8|16) coordinate
Tensor gV_ = take<0,3>(local_tile(gV, select<1,2>(SubgroupTileShapePV{}), make_coord(_, kv_tile_idx)));
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved.
* Copyright (C) 2025 Intel Corporation, All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
Expand Down Expand Up @@ -106,7 +107,7 @@ class FlashDecodeSoftmaxEpilogue<CausalMask_, epilogue::IntelXeXMX16, Element_>

template <int FragsN, class FragAcc, class FragMax, class FragSum>
CUTLASS_DEVICE void scale_exp_log2(FragAcc &frag_s, FragMax const &max, FragSum &sum) {
auto sg = syclcompat::get_nd_item<1>().get_sub_group();
auto sg = compat::get_nd_item<1>().get_sub_group();
const auto max_scale = max * params.scale;
const auto max_scale_bcast = group_broadcast(sg, max_scale, 0);
CUTLASS_PRAGMA_UNROLL
Expand All @@ -119,8 +120,8 @@ class FlashDecodeSoftmaxEpilogue<CausalMask_, epilogue::IntelXeXMX16, Element_>

template <int Num_SGs, int FragsN, class FragSrc, class STensorMax>
CUTLASS_DEVICE void reduce_max(FragSrc &src, STensorMax &stensor_max, Element& max_val) {
auto sg = syclcompat::get_nd_item<1>().get_sub_group();
auto group = syclcompat::get_nd_item<1>().get_group();
auto sg = compat::get_nd_item<1>().get_sub_group();
auto group = compat::get_nd_item<1>().get_group();
const int sg_group_id = sg.get_group_id()[0];
const int sg_local_id = sg.get_local_id()[0];

Expand Down Expand Up @@ -162,7 +163,7 @@ class FlashDecodeSoftmaxEpilogue<CausalMask_, epilogue::IntelXeXMX16, Element_>
reduce_max<Num_SGs,FragsNS>(frag_s, shmem_tensor_max, max_val);

if (!is_first) {
auto sg = syclcompat::get_nd_item<1>().get_sub_group();
auto sg = compat::get_nd_item<1>().get_sub_group();
const int sg_group_id = sg.get_group_id()[0];
const int sg_local_id = sg.get_local_id()[0];
const int sg_size = sg.get_local_range()[0];
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved.
* Copyright (C) 2025 Intel Corporation, All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
Expand Down Expand Up @@ -162,7 +163,7 @@ class FlashPrefillEpilogue<epilogue::IntelXeXMX16, MMAOperation_, TileShapeOutp
constexpr int FragsM = shape<1>(FragOutLayout{});
constexpr int FragsN = size(select<2,3>(shape(FragOutLayout{})));

auto g = syclcompat::get_nd_item<1>().get_sub_group();
auto g = compat::get_nd_item<1>().get_sub_group();
auto out_reg = make_tensor(static_cast<decltype(out) &&>(out).data() , Shape<Int<Vec>, Int<FragsM>, Int<FragsN>>{});

CUTLASS_PRAGMA_UNROLL
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved.
* Copyright (C) 2025 Intel Corporation, All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
Expand Down Expand Up @@ -163,7 +164,7 @@ class FlashPrefillCachedEpilogue<epilogue::IntelXeXMX16, MMAOperation_, TileSha
constexpr int FragsM = shape<1>(FragOutLayout{});
constexpr int FragsN = size(select<2,3>(shape(FragOutLayout{})));

auto g = syclcompat::get_nd_item<1>().get_sub_group();
auto g = compat::get_nd_item<1>().get_sub_group();
auto out_reg = make_tensor(static_cast<decltype(out) &&>(out).data() , Shape<Int<Vec>, Int<FragsM>, Int<FragsN>>{});

CUTLASS_PRAGMA_UNROLL
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved.
* Copyright (C) 2025 Intel Corporation, All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
Expand Down Expand Up @@ -197,7 +198,7 @@ struct FlashPrefillMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeType_, El
// Instantiate the MMA object
TiledMmaQK tiled_mma;
// To make all threads in a warp have the same global tensors pass in the index of thread 0 in each warp
auto sg = syclcompat::get_nd_item<1>().get_sub_group();
auto sg = compat::get_nd_item<1>().get_sub_group();
auto first_thread_in_sg_idx = sg.get_group_id()[0] * DispatchPolicy::SubgroupSize;
auto thread_mma_k = tiled_mma.get_slice(0);
auto thread_mma_q = tiled_mma.get_slice(first_thread_in_sg_idx);
Expand Down Expand Up @@ -282,7 +283,7 @@ struct FlashPrefillMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeType_, El
TiledMmaPV tiled_mma;
// Tile GV to the shape of <64,64> and loop over the HeadSize/64 to avoid Register spill
Tensor gV_ = take<0,3>(local_tile(gV, select<1,2>(TileShapePV{}), make_coord(_, _)));
auto sg = syclcompat::get_nd_item<1>().get_sub_group();
auto sg = compat::get_nd_item<1>().get_sub_group();
auto first_thread_in_sg_idx = sg.get_group_id()[0] * DispatchPolicy::SubgroupSize;
auto thread_mma = tiled_mma.get_slice(first_thread_in_sg_idx);
Tensor tCgV = thread_mma.partition_B(gV_);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved.
* Copyright (C) 2025 Intel Corporation, All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
Expand Down Expand Up @@ -217,7 +218,7 @@ struct FlashPrefillCachedMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeTyp
// Instantiate the MMA object
TiledMmaQK tiled_mma;
// To make all threads in a warp have the same global tensors pass in the index of thread 0 in each warp
auto sg = syclcompat::get_nd_item<1>().get_sub_group();
auto sg = compat::get_nd_item<1>().get_sub_group();
auto first_thread_in_sg_idx = sg.get_group_id()[0] * DispatchPolicy::SubgroupSize;
auto thread_mma_k = tiled_mma.get_slice(0);
auto thread_mma_q = tiled_mma.get_slice(first_thread_in_sg_idx);
Expand Down Expand Up @@ -285,7 +286,7 @@ struct FlashPrefillCachedMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeTyp
TiledMmaPV tiled_mma;
// Tile GV to the shape of <64,64> and loop over the HeadSize/64 to avoid Register spill
Tensor gV_ = take<0,3>(local_tile(gV, select<1,2>(TileShapePV{}), make_coord(_, _)));
auto sg = syclcompat::get_nd_item<1>().get_sub_group();
auto sg = compat::get_nd_item<1>().get_sub_group();
auto first_thread_in_sg_idx = sg.get_group_id()[0] * DispatchPolicy::SubgroupSize;
auto thread_mma = tiled_mma.get_slice(first_thread_in_sg_idx);
Tensor tCgV = thread_mma.partition_B(gV_);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved.
* Copyright (C) 2025 Intel Corporation, All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
Expand Down Expand Up @@ -106,7 +107,7 @@ class FlashPrefillSoftmaxEpilogue<CausalMask_, epilogue::IntelXeXMX16, Element_>

template <int Vec, int FragsM, int FragsN, class FragAcc, class FragMax, class FragSum>
CUTLASS_DEVICE void scale_exp_log2(FragAcc &frag_s, FragMax const &max, FragSum &sum) {
auto g = syclcompat::get_nd_item<1>().get_sub_group();
auto g = compat::get_nd_item<1>().get_sub_group();
const auto max_scale = max * params.scale;
CUTLASS_PRAGMA_UNROLL
for (int indx = 0; indx < Vec * FragsM; indx++) {
Expand All @@ -123,7 +124,7 @@ class FlashPrefillSoftmaxEpilogue<CausalMask_, epilogue::IntelXeXMX16, Element_>

template <int Vec, int FragsM, int FragsN, class FragSrc, class FragMax>
CUTLASS_DEVICE void reduce_max(FragSrc &src, FragMax &max) {
auto g = syclcompat::get_nd_item<1>().get_sub_group();
auto g = compat::get_nd_item<1>().get_sub_group();
CUTLASS_PRAGMA_UNROLL
for (int indx = 0; indx < Vec * FragsM; indx++) {
auto maxptr = group_broadcast(g, max, indx);
Expand Down Expand Up @@ -152,7 +153,7 @@ class FlashPrefillSoftmaxEpilogue<CausalMask_, epilogue::IntelXeXMX16, Element_>
reduce_max<Vec, FragsM, FragsNAcc>(frag_s, max);
static_assert(Vec * FragsM % 8 ==0, " No. of attention rows per subgroup should be >= 1 MMA Atom worth of rows.");
if (!is_first) {
auto g = syclcompat::get_nd_item<1>().get_sub_group();
auto g = compat::get_nd_item<1>().get_sub_group();
Element max_scale{max * params.scale};
Element exp_scale{sycl::native::exp2(max_prev * params.scale - max_scale)};
CUTLASS_PRAGMA_UNROLL
Expand Down
3 changes: 2 additions & 1 deletion applications/flash_attention_v2/kernel/tile_scheduler.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved.
* Copyright (C) 2025 Intel Corporation, All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
Expand Down Expand Up @@ -192,7 +193,7 @@ struct XeFlashPersistentTileScheduler {

template <int Num_SGs>
static dim3 get_grid_shape(Params const& params) {
auto queue = syclcompat::get_default_queue();
auto queue = compat::get_default_queue();
auto dev = queue.get_device();
const size_t maxSubgroups =
dev.template get_info<sycl::info::device::max_num_sub_groups>();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved.
* Copyright (C) 2025 Intel Corporation, All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
Expand Down Expand Up @@ -141,7 +142,7 @@ struct XeFlashPersistentTileScheduler {

template <int Num_SGs>
static dim3 get_grid_shape(Params const& params) {
auto queue = syclcompat::get_default_queue();
auto queue = compat::get_default_queue();
auto dev = queue.get_device();
const size_t maxSubgroups =
dev.template get_info<sycl::info::device::max_num_sub_groups>();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved.
* Copyright (C) 2025 Intel Corporation, All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
Expand Down Expand Up @@ -343,7 +344,7 @@ class FMHADecode {
Tensor out_reg = make_tensor<ElementAccumulator>(AccumShape{});
clear(out_reg);

auto smem = syclcompat::local_mem<ElementAccumulator[((Int<size(AccumShape{}) + 1>{}) * Num_SGs * SubgroupSize)]>();
auto smem = compat::local_mem<ElementAccumulator[((Int<size(AccumShape{}) + 1>{}) * Num_SGs * SubgroupSize)]>();
Tensor shmem_max_tensor = make_tensor(make_smem_ptr(smem), make_shape(Int<Num_SGs * FragsM>{}));

bool is_KV_cache = seq_len_kv_cache != 0;
Expand Down Expand Up @@ -459,7 +460,7 @@ class FMHADecode {
collective_mma.template mmaPV<VSlicer>(out_reg, tSr, gV, out_reg, mainloop_params, false, curr_kv_tile_idx);

// need to apply barrier here to avoid race condition
auto group = syclcompat::get_nd_item<1>().get_group();
auto group = compat::get_nd_item<1>().get_group();
sycl::group_barrier(group);

Tensor shmem_out_tensor = make_tensor(make_smem_ptr(smem), make_shape(Int<(size(AccumShape{})) * SubgroupSize * Num_SGs>{}));
Expand Down
5 changes: 3 additions & 2 deletions benchmarks/common.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved.
* Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved.
* Copyright (C) 2025 Intel Corporation, All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
Expand Down Expand Up @@ -39,7 +40,7 @@
namespace cutlass {
static inline std::size_t get_llc_size() {
#if defined(CUTLASS_ENABLE_SYCL)
return syclcompat::get_default_queue().get_device().get_info<sycl::info::device::global_mem_cache_size>();
return compat::get_default_queue().get_device().get_info<sycl::info::device::global_mem_cache_size>();
#else
cudaDeviceProp prop_struct;
auto result = cudaGetDeviceProperties(&prop_struct, 0);
Expand Down
Loading
Loading