Skip to content

Commit 403cccd

Browse files
leonling-ll authored and rolandschulz committed
Import SYCLCompat as Compat (#514)
This change imports `SYCLCompat` into the cutlass-sycl repo as `compat`. Previous dependencies on `syclcompat` are changed to `compat`. This PR also fixes some failures of `SYCLCompat` in oneAPI 2025.2. --------- Co-authored-by: Roland Schulz <[email protected]>
1 parent 560f2ac commit 403cccd

File tree

125 files changed

+10418
-770
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

125 files changed

+10418
-770
lines changed

applications/dual_gemm/collective/xe_dual_gemm_mma.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
/***************************************************************************************************
22
* Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved.
3+
* Copyright (C) 2025 Intel Corporation, All rights reserved.
34
* SPDX-License-Identifier: BSD-3-Clause
45
*
56
* Redistribution and use in source and binary forms, with or without
@@ -169,7 +170,7 @@ struct DualGemmMma<MainloopIntelXeXMX16<Stages, Schedule>, TileShape_, ElementA_
169170
TiledMma tiled_mma;
170171
// TODO(Codeplay): see if we can make this nicer
171172
// To make all work items in a subgroup have the same global tensors pass in the index of work item 0 in each subgroup
172-
auto sg = syclcompat::get_nd_item<1>().get_sub_group();
173+
auto sg = compat::get_nd_item<1>().get_sub_group();
173174
auto first_thread_in_sg_idx = sg.get_group_linear_id() * DispatchPolicy::SubgroupSize;
174175
auto thr_mma = tiled_mma.get_slice(first_thread_in_sg_idx);
175176

applications/flash_attention_v2/collective/xe_flash_attn_decode_epilogue.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
/***************************************************************************************************
22
* Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved.
3+
* Copyright (C) 2025 Intel Corporation, All rights reserved.
34
* SPDX-License-Identifier: BSD-3-Clause
45
*
56
* Redistribution and use in source and binary forms, with or without
@@ -167,8 +168,8 @@ class FlashDecodeEpilogue<epilogue::IntelXeXMX16, MMAOp_, TileShapeOutput_, Subg
167168
using namespace cute;
168169
static constexpr bool is_var_len = cutlass::fmha::collective::is_variable_length_v<tuple_element_t<2, ProblemShape>>;
169170

170-
auto sg = syclcompat::get_nd_item<1>().get_sub_group();
171-
auto group = syclcompat::get_nd_item<1>().get_group();
171+
auto sg = compat::get_nd_item<1>().get_sub_group();
172+
auto group = compat::get_nd_item<1>().get_group();
172173
const int sg_local_id = sg.get_local_id()[0];
173174
const int sg_group_id = sg.get_group_id()[0];
174175

applications/flash_attention_v2/collective/xe_flash_attn_decode_mma.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
/***************************************************************************************************
22
* Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved.
3+
* Copyright (C) 2025 Intel Corporation, All rights reserved.
34
* SPDX-License-Identifier: BSD-3-Clause
45
*
56
* Redistribution and use in source and binary forms, with or without
@@ -226,7 +227,7 @@ struct FlashDecodeMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeType_, Ele
226227
auto thr_copy_K = gmem_tiled_copy_k.get_slice(thread_idx);
227228
// Instantiate the MMA object
228229
TiledMmaQK tiled_mma;
229-
auto sg = syclcompat::get_nd_item<1>().get_sub_group();
230+
auto sg = compat::get_nd_item<1>().get_sub_group();
230231
auto first_thread_in_sg_idx = sg.get_group_id()[0] * DispatchPolicy::SubgroupSize;
231232
// For Normal Attention, K matrix tile_id = subgroup_id (cache and new both)
232233
// For Paged Attention, K matrix tile_id = page_table[subgroup_id] (cache, new keys follow normal attention)
@@ -315,7 +316,7 @@ struct FlashDecodeMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeType_, Ele
315316
int thread_idx = static_cast<int>(ThreadIdxX());
316317
// Instantiate the MMA object
317318
TiledMmaPV tiled_mma;
318-
auto sg = syclcompat::get_nd_item<1>().get_sub_group();
319+
auto sg = compat::get_nd_item<1>().get_sub_group();
319320
auto thread_mma = tiled_mma.get_slice(0);
320321
// convert X*512|1024 to 32*64*x*8|16 and use (_, sg.get_group_id()[0] / ATOM_N) to index in the (x,8|16) coordinate
321322
Tensor gV_ = take<0,3>(local_tile(gV, select<1,2>(SubgroupTileShapePV{}), make_coord(_, kv_tile_idx)));

applications/flash_attention_v2/collective/xe_flash_attn_decode_softmax_epilogue.hpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
/***************************************************************************************************
22
* Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved.
3+
* Copyright (C) 2025 Intel Corporation, All rights reserved.
34
* SPDX-License-Identifier: BSD-3-Clause
45
*
56
* Redistribution and use in source and binary forms, with or without
@@ -106,7 +107,7 @@ class FlashDecodeSoftmaxEpilogue<CausalMask_, epilogue::IntelXeXMX16, Element_>
106107

107108
template <int FragsN, class FragAcc, class FragMax, class FragSum>
108109
CUTLASS_DEVICE void scale_exp_log2(FragAcc &frag_s, FragMax const &max, FragSum &sum) {
109-
auto sg = syclcompat::get_nd_item<1>().get_sub_group();
110+
auto sg = compat::get_nd_item<1>().get_sub_group();
110111
const auto max_scale = max * params.scale;
111112
const auto max_scale_bcast = group_broadcast(sg, max_scale, 0);
112113
CUTLASS_PRAGMA_UNROLL
@@ -119,8 +120,8 @@ class FlashDecodeSoftmaxEpilogue<CausalMask_, epilogue::IntelXeXMX16, Element_>
119120

120121
template <int Num_SGs, int FragsN, class FragSrc, class STensorMax>
121122
CUTLASS_DEVICE void reduce_max(FragSrc &src, STensorMax &stensor_max, Element& max_val) {
122-
auto sg = syclcompat::get_nd_item<1>().get_sub_group();
123-
auto group = syclcompat::get_nd_item<1>().get_group();
123+
auto sg = compat::get_nd_item<1>().get_sub_group();
124+
auto group = compat::get_nd_item<1>().get_group();
124125
const int sg_group_id = sg.get_group_id()[0];
125126
const int sg_local_id = sg.get_local_id()[0];
126127

@@ -162,7 +163,7 @@ class FlashDecodeSoftmaxEpilogue<CausalMask_, epilogue::IntelXeXMX16, Element_>
162163
reduce_max<Num_SGs,FragsNS>(frag_s, shmem_tensor_max, max_val);
163164

164165
if (!is_first) {
165-
auto sg = syclcompat::get_nd_item<1>().get_sub_group();
166+
auto sg = compat::get_nd_item<1>().get_sub_group();
166167
const int sg_group_id = sg.get_group_id()[0];
167168
const int sg_local_id = sg.get_local_id()[0];
168169
const int sg_size = sg.get_local_range()[0];

applications/flash_attention_v2/collective/xe_flash_attn_prefill_epilogue.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
/***************************************************************************************************
22
* Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved.
3+
* Copyright (C) 2025 Intel Corporation, All rights reserved.
34
* SPDX-License-Identifier: BSD-3-Clause
45
*
56
* Redistribution and use in source and binary forms, with or without
@@ -162,7 +163,7 @@ class FlashPrefillEpilogue<epilogue::IntelXeXMX16, MMAOperation_, TileShapeOutp
162163
constexpr int FragsM = shape<1>(FragOutLayout{});
163164
constexpr int FragsN = size(select<2,3>(shape(FragOutLayout{})));
164165

165-
auto g = syclcompat::get_nd_item<1>().get_sub_group();
166+
auto g = compat::get_nd_item<1>().get_sub_group();
166167
auto out_reg = make_tensor(static_cast<decltype(out) &&>(out).data() , Shape<Int<Vec>, Int<FragsM>, Int<FragsN>>{});
167168

168169
CUTLASS_PRAGMA_UNROLL

applications/flash_attention_v2/collective/xe_flash_attn_prefill_epilogue_cachedKV.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
/***************************************************************************************************
22
* Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved.
3+
* Copyright (C) 2025 Intel Corporation, All rights reserved.
34
* SPDX-License-Identifier: BSD-3-Clause
45
*
56
* Redistribution and use in source and binary forms, with or without
@@ -163,7 +164,7 @@ class FlashPrefillCachedEpilogue<epilogue::IntelXeXMX16, MMAOperation_, TileSha
163164
constexpr int FragsM = shape<1>(FragOutLayout{});
164165
constexpr int FragsN = size(select<2,3>(shape(FragOutLayout{})));
165166

166-
auto g = syclcompat::get_nd_item<1>().get_sub_group();
167+
auto g = compat::get_nd_item<1>().get_sub_group();
167168
auto out_reg = make_tensor(static_cast<decltype(out) &&>(out).data() , Shape<Int<Vec>, Int<FragsM>, Int<FragsN>>{});
168169

169170
CUTLASS_PRAGMA_UNROLL

applications/flash_attention_v2/collective/xe_flash_attn_prefill_mma.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
/***************************************************************************************************
22
* Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved.
3+
* Copyright (C) 2025 Intel Corporation, All rights reserved.
34
* SPDX-License-Identifier: BSD-3-Clause
45
*
56
* Redistribution and use in source and binary forms, with or without
@@ -197,7 +198,7 @@ struct FlashPrefillMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeType_, El
197198
// Instantiate the MMA object
198199
TiledMmaQK tiled_mma;
199200
// To make all threads in a warp have the same global tensors pass in the index of thread 0 in each warp
200-
auto sg = syclcompat::get_nd_item<1>().get_sub_group();
201+
auto sg = compat::get_nd_item<1>().get_sub_group();
201202
auto first_thread_in_sg_idx = sg.get_group_id()[0] * DispatchPolicy::SubgroupSize;
202203
auto thread_mma_k = tiled_mma.get_slice(0);
203204
auto thread_mma_q = tiled_mma.get_slice(first_thread_in_sg_idx);
@@ -282,7 +283,7 @@ struct FlashPrefillMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeType_, El
282283
TiledMmaPV tiled_mma;
283284
// Tile GV to the shape of <64,64> and loop over the HeadSize/64 to avoid Register spill
284285
Tensor gV_ = take<0,3>(local_tile(gV, select<1,2>(TileShapePV{}), make_coord(_, _)));
285-
auto sg = syclcompat::get_nd_item<1>().get_sub_group();
286+
auto sg = compat::get_nd_item<1>().get_sub_group();
286287
auto first_thread_in_sg_idx = sg.get_group_id()[0] * DispatchPolicy::SubgroupSize;
287288
auto thread_mma = tiled_mma.get_slice(first_thread_in_sg_idx);
288289
Tensor tCgV = thread_mma.partition_B(gV_);

applications/flash_attention_v2/collective/xe_flash_attn_prefill_mma_cachedKV.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
/***************************************************************************************************
22
* Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved.
3+
* Copyright (C) 2025 Intel Corporation, All rights reserved.
34
* SPDX-License-Identifier: BSD-3-Clause
45
*
56
* Redistribution and use in source and binary forms, with or without
@@ -217,7 +218,7 @@ struct FlashPrefillCachedMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeTyp
217218
// Instantiate the MMA object
218219
TiledMmaQK tiled_mma;
219220
// To make all threads in a warp have the same global tensors pass in the index of thread 0 in each warp
220-
auto sg = syclcompat::get_nd_item<1>().get_sub_group();
221+
auto sg = compat::get_nd_item<1>().get_sub_group();
221222
auto first_thread_in_sg_idx = sg.get_group_id()[0] * DispatchPolicy::SubgroupSize;
222223
auto thread_mma_k = tiled_mma.get_slice(0);
223224
auto thread_mma_q = tiled_mma.get_slice(first_thread_in_sg_idx);
@@ -285,7 +286,7 @@ struct FlashPrefillCachedMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeTyp
285286
TiledMmaPV tiled_mma;
286287
// Tile GV to the shape of <64,64> and loop over the HeadSize/64 to avoid Register spill
287288
Tensor gV_ = take<0,3>(local_tile(gV, select<1,2>(TileShapePV{}), make_coord(_, _)));
288-
auto sg = syclcompat::get_nd_item<1>().get_sub_group();
289+
auto sg = compat::get_nd_item<1>().get_sub_group();
289290
auto first_thread_in_sg_idx = sg.get_group_id()[0] * DispatchPolicy::SubgroupSize;
290291
auto thread_mma = tiled_mma.get_slice(first_thread_in_sg_idx);
291292
Tensor tCgV = thread_mma.partition_B(gV_);

applications/flash_attention_v2/collective/xe_flash_attn_prefill_softmax_epilogue.hpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
/***************************************************************************************************
22
* Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved.
3+
* Copyright (C) 2025 Intel Corporation, All rights reserved.
34
* SPDX-License-Identifier: BSD-3-Clause
45
*
56
* Redistribution and use in source and binary forms, with or without
@@ -106,7 +107,7 @@ class FlashPrefillSoftmaxEpilogue<CausalMask_, epilogue::IntelXeXMX16, Element_>
106107

107108
template <int Vec, int FragsM, int FragsN, class FragAcc, class FragMax, class FragSum>
108109
CUTLASS_DEVICE void scale_exp_log2(FragAcc &frag_s, FragMax const &max, FragSum &sum) {
109-
auto g = syclcompat::get_nd_item<1>().get_sub_group();
110+
auto g = compat::get_nd_item<1>().get_sub_group();
110111
const auto max_scale = max * params.scale;
111112
CUTLASS_PRAGMA_UNROLL
112113
for (int indx = 0; indx < Vec * FragsM; indx++) {
@@ -123,7 +124,7 @@ class FlashPrefillSoftmaxEpilogue<CausalMask_, epilogue::IntelXeXMX16, Element_>
123124

124125
template <int Vec, int FragsM, int FragsN, class FragSrc, class FragMax>
125126
CUTLASS_DEVICE void reduce_max(FragSrc &src, FragMax &max) {
126-
auto g = syclcompat::get_nd_item<1>().get_sub_group();
127+
auto g = compat::get_nd_item<1>().get_sub_group();
127128
CUTLASS_PRAGMA_UNROLL
128129
for (int indx = 0; indx < Vec * FragsM; indx++) {
129130
auto maxptr = group_broadcast(g, max, indx);
@@ -152,7 +153,7 @@ class FlashPrefillSoftmaxEpilogue<CausalMask_, epilogue::IntelXeXMX16, Element_>
152153
reduce_max<Vec, FragsM, FragsNAcc>(frag_s, max);
153154
static_assert(Vec * FragsM % 8 ==0, " No. of attention rows per subgroup should be >= 1 MMA Atom worth of rows.");
154155
if (!is_first) {
155-
auto g = syclcompat::get_nd_item<1>().get_sub_group();
156+
auto g = compat::get_nd_item<1>().get_sub_group();
156157
Element max_scale{max * params.scale};
157158
Element exp_scale{sycl::native::exp2(max_prev * params.scale - max_scale)};
158159
CUTLASS_PRAGMA_UNROLL

applications/flash_attention_v2/kernel/tile_scheduler.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
/***************************************************************************************************
22
* Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved.
3+
* Copyright (C) 2025 Intel Corporation, All rights reserved.
34
* SPDX-License-Identifier: BSD-3-Clause
45
*
56
* Redistribution and use in source and binary forms, with or without
@@ -192,7 +193,7 @@ struct XeFlashPersistentTileScheduler {
192193

193194
template <int Num_SGs>
194195
static dim3 get_grid_shape(Params const& params) {
195-
auto queue = syclcompat::get_default_queue();
196+
auto queue = compat::get_default_queue();
196197
auto dev = queue.get_device();
197198
const size_t maxSubgroups =
198199
dev.template get_info<sycl::info::device::max_num_sub_groups>();

0 commit comments

Comments
 (0)