Skip to content

Commit 501b9c2

Browse files
committed
Rename cutlasscompat to compat
1 parent 33eebdc commit 501b9c2

File tree

124 files changed

+1034
-1034
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

124 files changed

+1034
-1034
lines changed

applications/dual_gemm/collective/xe_dual_gemm_mma.hpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -170,7 +170,7 @@ struct DualGemmMma<MainloopIntelXeXMX16<Stages, Schedule>, TileShape_, ElementA_
170170
TiledMma tiled_mma;
171171
// TODO(Codeplay): see if we can make this nicer
172172
// To make all work items in a subgroup have the same global tensors pass in the index of work item 0 in each subgroup
173-
auto sg = cutlasscompat::get_nd_item<1>().get_sub_group();
173+
auto sg = compat::get_nd_item<1>().get_sub_group();
174174
auto first_thread_in_sg_idx = sg.get_group_linear_id() * DispatchPolicy::SubgroupSize;
175175
auto thr_mma = tiled_mma.get_slice(first_thread_in_sg_idx);
176176

applications/flash_attention_v2/collective/xe_flash_attn_decode_epilogue.hpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -168,8 +168,8 @@ class FlashDecodeEpilogue<epilogue::IntelXeXMX16, MMAOp_, TileShapeOutput_, Subg
168168
using namespace cute;
169169
static constexpr bool is_var_len = cutlass::fmha::collective::is_variable_length_v<tuple_element_t<2, ProblemShape>>;
170170

171-
auto sg = cutlasscompat::get_nd_item<1>().get_sub_group();
172-
auto group = cutlasscompat::get_nd_item<1>().get_group();
171+
auto sg = compat::get_nd_item<1>().get_sub_group();
172+
auto group = compat::get_nd_item<1>().get_group();
173173
const int sg_local_id = sg.get_local_id()[0];
174174
const int sg_group_id = sg.get_group_id()[0];
175175

applications/flash_attention_v2/collective/xe_flash_attn_decode_mma.hpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -227,7 +227,7 @@ struct FlashDecodeMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeType_, Ele
227227
auto thr_copy_K = gmem_tiled_copy_k.get_slice(thread_idx);
228228
// Instantiate the MMA object
229229
TiledMmaQK tiled_mma;
230-
auto sg = cutlasscompat::get_nd_item<1>().get_sub_group();
230+
auto sg = compat::get_nd_item<1>().get_sub_group();
231231
auto first_thread_in_sg_idx = sg.get_group_id()[0] * DispatchPolicy::SubgroupSize;
232232
// For Normal Attention, K matrix tile_id = subgroup_id (cache and new both)
233233
// For Paged Attention, K matrix tile_id = page_table[subgroup_id] (cache, new keys follow normal attention)
@@ -316,7 +316,7 @@ struct FlashDecodeMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeType_, Ele
316316
int thread_idx = static_cast<int>(ThreadIdxX());
317317
// Instantiate the MMA object
318318
TiledMmaPV tiled_mma;
319-
auto sg = cutlasscompat::get_nd_item<1>().get_sub_group();
319+
auto sg = compat::get_nd_item<1>().get_sub_group();
320320
auto thread_mma = tiled_mma.get_slice(0);
321321
// convert X*512|1024 to 32*64*x*8|16 and use (_, sg.get_group_id()[0] / ATOM_N) to index in the (x,8|16) coordinate
322322
Tensor gV_ = take<0,3>(local_tile(gV, select<1,2>(SubgroupTileShapePV{}), make_coord(_, kv_tile_idx)));

applications/flash_attention_v2/collective/xe_flash_attn_decode_softmax_epilogue.hpp

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -107,7 +107,7 @@ class FlashDecodeSoftmaxEpilogue<CausalMask_, epilogue::IntelXeXMX16, Element_>
107107

108108
template <int FragsN, class FragAcc, class FragMax, class FragSum>
109109
CUTLASS_DEVICE void scale_exp_log2(FragAcc &frag_s, FragMax const &max, FragSum &sum) {
110-
auto sg = cutlasscompat::get_nd_item<1>().get_sub_group();
110+
auto sg = compat::get_nd_item<1>().get_sub_group();
111111
const auto max_scale = max * params.scale;
112112
const auto max_scale_bcast = group_broadcast(sg, max_scale, 0);
113113
CUTLASS_PRAGMA_UNROLL
@@ -120,8 +120,8 @@ class FlashDecodeSoftmaxEpilogue<CausalMask_, epilogue::IntelXeXMX16, Element_>
120120

121121
template <int Num_SGs, int FragsN, class FragSrc, class STensorMax>
122122
CUTLASS_DEVICE void reduce_max(FragSrc &src, STensorMax &stensor_max, Element& max_val) {
123-
auto sg = cutlasscompat::get_nd_item<1>().get_sub_group();
124-
auto group = cutlasscompat::get_nd_item<1>().get_group();
123+
auto sg = compat::get_nd_item<1>().get_sub_group();
124+
auto group = compat::get_nd_item<1>().get_group();
125125
const int sg_group_id = sg.get_group_id()[0];
126126
const int sg_local_id = sg.get_local_id()[0];
127127

@@ -163,7 +163,7 @@ class FlashDecodeSoftmaxEpilogue<CausalMask_, epilogue::IntelXeXMX16, Element_>
163163
reduce_max<Num_SGs,FragsNS>(frag_s, shmem_tensor_max, max_val);
164164

165165
if (!is_first) {
166-
auto sg = cutlasscompat::get_nd_item<1>().get_sub_group();
166+
auto sg = compat::get_nd_item<1>().get_sub_group();
167167
const int sg_group_id = sg.get_group_id()[0];
168168
const int sg_local_id = sg.get_local_id()[0];
169169
const int sg_size = sg.get_local_range()[0];

applications/flash_attention_v2/collective/xe_flash_attn_prefill_epilogue.hpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -163,7 +163,7 @@ class FlashPrefillEpilogue<epilogue::IntelXeXMX16, MMAOperation_, TileShapeOutp
163163
constexpr int FragsM = shape<1>(FragOutLayout{});
164164
constexpr int FragsN = size(select<2,3>(shape(FragOutLayout{})));
165165

166-
auto g = cutlasscompat::get_nd_item<1>().get_sub_group();
166+
auto g = compat::get_nd_item<1>().get_sub_group();
167167
auto out_reg = make_tensor(static_cast<decltype(out) &&>(out).data() , Shape<Int<Vec>, Int<FragsM>, Int<FragsN>>{});
168168

169169
CUTLASS_PRAGMA_UNROLL

applications/flash_attention_v2/collective/xe_flash_attn_prefill_epilogue_cachedKV.hpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -164,7 +164,7 @@ class FlashPrefillCachedEpilogue<epilogue::IntelXeXMX16, MMAOperation_, TileSha
164164
constexpr int FragsM = shape<1>(FragOutLayout{});
165165
constexpr int FragsN = size(select<2,3>(shape(FragOutLayout{})));
166166

167-
auto g = cutlasscompat::get_nd_item<1>().get_sub_group();
167+
auto g = compat::get_nd_item<1>().get_sub_group();
168168
auto out_reg = make_tensor(static_cast<decltype(out) &&>(out).data() , Shape<Int<Vec>, Int<FragsM>, Int<FragsN>>{});
169169

170170
CUTLASS_PRAGMA_UNROLL

applications/flash_attention_v2/collective/xe_flash_attn_prefill_mma.hpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -198,7 +198,7 @@ struct FlashPrefillMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeType_, El
198198
// Instantiate the MMA object
199199
TiledMmaQK tiled_mma;
200200
// To make all threads in a warp have the same global tensors pass in the index of thread 0 in each warp
201-
auto sg = cutlasscompat::get_nd_item<1>().get_sub_group();
201+
auto sg = compat::get_nd_item<1>().get_sub_group();
202202
auto first_thread_in_sg_idx = sg.get_group_id()[0] * DispatchPolicy::SubgroupSize;
203203
auto thread_mma_k = tiled_mma.get_slice(0);
204204
auto thread_mma_q = tiled_mma.get_slice(first_thread_in_sg_idx);
@@ -283,7 +283,7 @@ struct FlashPrefillMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeType_, El
283283
TiledMmaPV tiled_mma;
284284
// Tile GV to the shape of <64,64> and loop over the HeadSize/64 to avoid Register spill
285285
Tensor gV_ = take<0,3>(local_tile(gV, select<1,2>(TileShapePV{}), make_coord(_, _)));
286-
auto sg = cutlasscompat::get_nd_item<1>().get_sub_group();
286+
auto sg = compat::get_nd_item<1>().get_sub_group();
287287
auto first_thread_in_sg_idx = sg.get_group_id()[0] * DispatchPolicy::SubgroupSize;
288288
auto thread_mma = tiled_mma.get_slice(first_thread_in_sg_idx);
289289
Tensor tCgV = thread_mma.partition_B(gV_);

applications/flash_attention_v2/collective/xe_flash_attn_prefill_mma_cachedKV.hpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -218,7 +218,7 @@ struct FlashPrefillCachedMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeTyp
218218
// Instantiate the MMA object
219219
TiledMmaQK tiled_mma;
220220
// To make all threads in a warp have the same global tensors pass in the index of thread 0 in each warp
221-
auto sg = cutlasscompat::get_nd_item<1>().get_sub_group();
221+
auto sg = compat::get_nd_item<1>().get_sub_group();
222222
auto first_thread_in_sg_idx = sg.get_group_id()[0] * DispatchPolicy::SubgroupSize;
223223
auto thread_mma_k = tiled_mma.get_slice(0);
224224
auto thread_mma_q = tiled_mma.get_slice(first_thread_in_sg_idx);
@@ -286,7 +286,7 @@ struct FlashPrefillCachedMma<gemm::MainloopIntelXeXMX16<Stages>, ProblemShapeTyp
286286
TiledMmaPV tiled_mma;
287287
// Tile GV to the shape of <64,64> and loop over the HeadSize/64 to avoid Register spill
288288
Tensor gV_ = take<0,3>(local_tile(gV, select<1,2>(TileShapePV{}), make_coord(_, _)));
289-
auto sg = cutlasscompat::get_nd_item<1>().get_sub_group();
289+
auto sg = compat::get_nd_item<1>().get_sub_group();
290290
auto first_thread_in_sg_idx = sg.get_group_id()[0] * DispatchPolicy::SubgroupSize;
291291
auto thread_mma = tiled_mma.get_slice(first_thread_in_sg_idx);
292292
Tensor tCgV = thread_mma.partition_B(gV_);

applications/flash_attention_v2/collective/xe_flash_attn_prefill_softmax_epilogue.hpp

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -107,7 +107,7 @@ class FlashPrefillSoftmaxEpilogue<CausalMask_, epilogue::IntelXeXMX16, Element_>
107107

108108
template <int Vec, int FragsM, int FragsN, class FragAcc, class FragMax, class FragSum>
109109
CUTLASS_DEVICE void scale_exp_log2(FragAcc &frag_s, FragMax const &max, FragSum &sum) {
110-
auto g = cutlasscompat::get_nd_item<1>().get_sub_group();
110+
auto g = compat::get_nd_item<1>().get_sub_group();
111111
const auto max_scale = max * params.scale;
112112
CUTLASS_PRAGMA_UNROLL
113113
for (int indx = 0; indx < Vec * FragsM; indx++) {
@@ -124,7 +124,7 @@ class FlashPrefillSoftmaxEpilogue<CausalMask_, epilogue::IntelXeXMX16, Element_>
124124

125125
template <int Vec, int FragsM, int FragsN, class FragSrc, class FragMax>
126126
CUTLASS_DEVICE void reduce_max(FragSrc &src, FragMax &max) {
127-
auto g = cutlasscompat::get_nd_item<1>().get_sub_group();
127+
auto g = compat::get_nd_item<1>().get_sub_group();
128128
CUTLASS_PRAGMA_UNROLL
129129
for (int indx = 0; indx < Vec * FragsM; indx++) {
130130
auto maxptr = group_broadcast(g, max, indx);
@@ -153,7 +153,7 @@ class FlashPrefillSoftmaxEpilogue<CausalMask_, epilogue::IntelXeXMX16, Element_>
153153
reduce_max<Vec, FragsM, FragsNAcc>(frag_s, max);
154154
static_assert(Vec * FragsM % 8 ==0, " No. of attention rows per subgroup should be >= 1 MMA Atom worth of rows.");
155155
if (!is_first) {
156-
auto g = cutlasscompat::get_nd_item<1>().get_sub_group();
156+
auto g = compat::get_nd_item<1>().get_sub_group();
157157
Element max_scale{max * params.scale};
158158
Element exp_scale{sycl::native::exp2(max_prev * params.scale - max_scale)};
159159
CUTLASS_PRAGMA_UNROLL

applications/flash_attention_v2/kernel/tile_scheduler.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ struct XeFlashPersistentTileScheduler {
193193

194194
template <int Num_SGs>
195195
static dim3 get_grid_shape(Params const& params) {
196-
auto queue = cutlasscompat::get_default_queue();
196+
auto queue = compat::get_default_queue();
197197
auto dev = queue.get_device();
198198
const size_t maxSubgroups =
199199
dev.template get_info<sycl::info::device::max_num_sub_groups>();

0 commit comments

Comments (0)