Skip to content

Commit b92119c

Browse files
committed
complt
1 parent 403cccd commit b92119c

File tree

5 files changed

+37
-37
lines changed

5 files changed

+37
-37
lines changed

applications/flash_attention_v2/collective/xe_flash_attn_chunk_prefill_epilogue.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ class FlashChunkPrefillEpilogue<epilogue::IntelXeXMX16, MMAOperation_, TileShap
171171
constexpr int FragsM = shape<1>(FragOutLayout{});
172172
constexpr int FragsN = size(select<2,3>(shape(FragOutLayout{})));
173173

174-
auto sg = syclcompat::get_nd_item<1>().get_sub_group();
174+
auto sg = compat::get_nd_item<1>().get_sub_group();
175175
auto out_reg = make_tensor(static_cast<decltype(out) &&>(out).data() , Shape<Int<Vec>, Int<FragsM>, Int<FragsN>>{});
176176

177177
CUTLASS_PRAGMA_UNROLL

applications/flash_attention_v2/collective/xe_flash_attn_chunk_prefill_mma.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ struct FlashChunkPrefillMma<
273273
TiledMmaQK tiled_mma;
274274
// To make all threads in a warp have the same global tensors pass in the
275275
// index of thread 0 in each warp
276-
auto sg = syclcompat::get_nd_item<1>().get_sub_group();
276+
auto sg = compat::get_nd_item<1>().get_sub_group();
277277
auto first_thread_in_sg_idx =
278278
sg.get_group_id()[0] * DispatchPolicy::SubgroupSize;
279279
auto thread_mma_q = tiled_mma.get_slice(first_thread_in_sg_idx);
@@ -361,7 +361,7 @@ struct FlashChunkPrefillMma<
361361
// Register spill
362362
Tensor gV_ = take<0, 3>(
363363
local_tile(gV, select<1, 2>(TileShapePV{}), make_coord(_, _)));
364-
auto sg = syclcompat::get_nd_item<1>().get_sub_group();
364+
auto sg = compat::get_nd_item<1>().get_sub_group();
365365
auto first_thread_in_sg_idx =
366366
sg.get_group_id()[0] * DispatchPolicy::SubgroupSize;
367367
auto thread_mma = tiled_mma.get_slice(first_thread_in_sg_idx);

applications/flash_attention_v2/collective/xe_flash_attn_chunk_prefill_softmax_epilogue.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ class FlashChunkPrefillSoftmaxEpilogue<CausalMask_, LocalMask_, epilogue::IntelX
107107

108108
template <int Vec, int FragsM, int FragsN, class FragAcc, class FragMax, class FragSum>
109109
CUTLASS_DEVICE void scale_exp_log2(FragAcc &frag_s, FragMax const &max, FragSum &sum) {
110-
auto g = syclcompat::get_nd_item<1>().get_sub_group();
110+
auto g = compat::get_nd_item<1>().get_sub_group();
111111
const auto max_scale = max * params.scale;
112112
CUTLASS_PRAGMA_UNROLL
113113
for (int indx = 0; indx < Vec * FragsM; indx++) {
@@ -135,7 +135,7 @@ class FlashChunkPrefillSoftmaxEpilogue<CausalMask_, LocalMask_, epilogue::IntelX
135135

136136
template <int Vec, int FragsM, int FragsN, class FragSrc, class FragMax>
137137
CUTLASS_DEVICE void reduce_max(FragSrc &src, FragMax &max) {
138-
auto sg = syclcompat::get_nd_item<1>().get_sub_group();
138+
auto sg = compat::get_nd_item<1>().get_sub_group();
139139
CUTLASS_PRAGMA_UNROLL
140140
for (int indx = 0; indx < Vec * FragsM; indx++) {
141141
auto maxptr = group_broadcast(sg, max, indx);
@@ -164,7 +164,7 @@ class FlashChunkPrefillSoftmaxEpilogue<CausalMask_, LocalMask_, epilogue::IntelX
164164
reduce_max<Vec, FragsM, FragsNAcc>(frag_s, max);
165165
static_assert(Vec * FragsM % 8 == 0, " No. of attention rows per subgroup should be >= 1 MMA Atom worth of rows.");
166166
if (!is_first) {
167-
auto sg = syclcompat::get_nd_item<1>().get_sub_group();
167+
auto sg = compat::get_nd_item<1>().get_sub_group();
168168
Element max_scale{max * params.scale};
169169
Element exp_scale;
170170
if constexpr (LocalMask) {

applications/flash_attention_v2/kernel/tile_scheduler_chunk_prefill.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ struct XeFlashPersistentTileScheduler {
161161
}
162162

163163
template <int Num_SGs> static dim3 get_grid_shape(Params const &params) {
164-
auto queue = syclcompat::get_default_queue();
164+
auto queue = compat::get_default_queue();
165165
auto dev = queue.get_device();
166166
const size_t maxSubgroups =
167167
dev.template get_info<sycl::info::device::max_num_sub_groups>();

examples/06_bmg_flash_attention/bmg_flash_chunk_prefill_runner.hpp

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -261,8 +261,8 @@ bool verify(ProblemShapeType problem_size, Options options) {
261261
int num_pages = paged_kv_cache.page_table.size();
262262
std::vector<int> host_page_table(paged_kv_cache.page_table.size());
263263
std::vector<int> host_num_pages_per_seq(paged_kv_cache.num_pages_per_seq.size());
264-
syclcompat::memcpy<int>(host_page_table.data(), paged_kv_cache.page_table.get(), paged_kv_cache.page_table.size());
265-
syclcompat::memcpy<int>(host_num_pages_per_seq.data(), paged_kv_cache.num_pages_per_seq.get(), paged_kv_cache.num_pages_per_seq.size());
264+
compat::memcpy<int>(host_page_table.data(), paged_kv_cache.page_table.get(), paged_kv_cache.page_table.size());
265+
compat::memcpy<int>(host_num_pages_per_seq.data(), paged_kv_cache.num_pages_per_seq.get(), paged_kv_cache.num_pages_per_seq.size());
266266

267267
int curr_batch_pages = isVarLen ? host_num_pages_per_seq[b + 1] - host_num_pages_per_seq[b] : ceil_div(seq_len_kv_cache, paged_kv_cache.page_size);
268268
int batch_offset = isVarLen ? host_num_pages_per_seq[b] : b * curr_batch_pages;
@@ -272,57 +272,57 @@ bool verify(ProblemShapeType problem_size, Options options) {
272272
for (int p = 0; p < curr_batch_pages; p++) {
273273
int page_idx = host_page_table[batch_offset + p];
274274
// copy the page from KV cache to the concatenated buffer
275-
syclcompat::memcpy<ElementK>(
275+
compat::memcpy<ElementK>(
276276
block_K_concat.get() + p * paged_kv_cache.page_size * num_heads_kv * head_size_qk,
277277
block_K_cache.get() + page_idx * paged_kv_cache.page_size * num_heads_kv * head_size_qk,
278278
paged_kv_cache.page_size * num_heads_kv * head_size_qk
279279
);
280-
syclcompat::memcpy<ElementV>(
280+
compat::memcpy<ElementV>(
281281
block_V_concat.get() + p * paged_kv_cache.page_size * num_heads_kv * head_size_vo,
282282
block_V_cache.get() + page_idx * paged_kv_cache.page_size * num_heads_kv * head_size_vo,
283283
paged_kv_cache.page_size * num_heads_kv * head_size_vo
284284
);
285285
}
286286
if (seq_len_kv > 0) {
287-
syclcompat::memcpy<ElementK>(
287+
compat::memcpy<ElementK>(
288288
// block_K_concat.get() + curr_batch_pages * paged_kv_cache.page_sze * num_heads_kv *head_size_qk,
289289
block_K_concat.get() + seq_len_kv_cache * num_heads_kv * head_size_qk,
290290
block_K.get() + offset_k,
291291
seq_len_kv * num_heads_kv * head_size_qk
292292
);
293-
syclcompat::memcpy<ElementV>(
293+
compat::memcpy<ElementV>(
294294
block_V_concat.get() + seq_len_kv_cache * num_heads_kv * head_size_vo,
295295
block_V.get() + offset_v,
296296
seq_len_kv * num_heads_kv * head_size_vo
297297
);
298298
}
299-
syclcompat::wait();
299+
compat::wait();
300300
} else {
301301
block_K_concat.reset(seq_len_kv_total * num_heads_kv * head_size_qk);
302302
block_V_concat.reset(seq_len_kv_total * num_heads_kv * head_size_vo);
303303
// Concatenate K_cache and K
304-
syclcompat::memcpy<ElementK>(
304+
compat::memcpy<ElementK>(
305305
block_K_concat.get(),
306306
block_K_cache.get() + offset_k_cache,
307307
seq_len_kv_cache * num_heads_kv * head_size_qk
308308
);
309-
syclcompat::memcpy<ElementK>(
309+
compat::memcpy<ElementK>(
310310
block_K_concat.get() + seq_len_kv_cache * num_heads_kv * head_size_qk,
311311
block_K.get() + offset_k,
312312
seq_len_kv * num_heads_kv * head_size_qk
313313
);
314314
// Concatenate V_cache and V
315-
syclcompat::memcpy<ElementV>(
315+
compat::memcpy<ElementV>(
316316
block_V_concat.get(),
317317
block_V_cache.get() + offset_v_cache,
318318
seq_len_kv_cache * num_heads_kv * head_size_vo
319319
);
320-
syclcompat::memcpy<ElementV>(
320+
compat::memcpy<ElementV>(
321321
block_V_concat.get() + seq_len_kv_cache * num_heads_kv * head_size_vo,
322322
block_V.get() + offset_v,
323323
seq_len_kv * num_heads_kv * head_size_vo
324324
);
325-
// syclcompat::wait();
325+
// compat::wait();
326326
}
327327
k_ptr = block_K_concat.get();
328328
v_ptr = block_V_concat.get();
@@ -350,9 +350,9 @@ bool verify(ProblemShapeType problem_size, Options options) {
350350
seq_len_qo * seq_len_kv_total, // batch_stride_S
351351
seq_len_qo * seq_len_kv_total // batch_stride_S
352352
);
353-
syclcompat::wait();
353+
compat::wait();
354354
std::vector<ElementAccumulator> host_S(block_S.size());
355-
syclcompat::memcpy<ElementAccumulator>(host_S.data(), block_S.get(), host_S.size());
355+
compat::memcpy<ElementAccumulator>(host_S.data(), block_S.get(), host_S.size());
356356

357357
// delete this memory as it is no longer needed
358358
block_S.reset();
@@ -427,7 +427,7 @@ bool verify(ProblemShapeType problem_size, Options options) {
427427
cutlass::DeviceAllocation<ElementV> block_P;
428428
block_P.reset(host_P.size());
429429

430-
syclcompat::memcpy<ElementV>(block_P.get(), host_P.data(), host_P.size());
430+
compat::memcpy<ElementV>(block_P.get(), host_P.data(), host_P.size());
431431

432432
cutlass::TensorRef ref_P(block_P.get(), LayoutQ::packed({seq_len_qo, seq_len_kv_total}));
433433

@@ -445,12 +445,12 @@ bool verify(ProblemShapeType problem_size, Options options) {
445445
seq_len_qo * head_size_vo // batch_stride_O
446446
);
447447

448-
syclcompat::wait();
448+
compat::wait();
449449
// delete this memory as it is no longer needed
450450
block_P.reset();
451451

452452
std::vector<ElementAccumulator> vec_acc(block_acc.size());
453-
syclcompat::memcpy<ElementAccumulator>(vec_acc.data(), block_acc.get(), vec_acc.size());
453+
compat::memcpy<ElementAccumulator>(vec_acc.data(), block_acc.get(), vec_acc.size());
454454

455455
// delete this memory as it is no longer needed
456456
block_acc.reset();
@@ -475,8 +475,8 @@ bool verify(ProblemShapeType problem_size, Options options) {
475475
offset_o += seq_len_qo * num_heads_q * head_size_vo;
476476
} // end of batch loop
477477

478-
syclcompat::wait();
479-
syclcompat::memcpy<ElementOutput>(block_ref_O.get(), host_O.data(), host_O.size());
478+
compat::wait();
479+
compat::memcpy<ElementOutput>(block_ref_O.get(), host_O.data(), host_O.size());
480480
// Check if output from CUTLASS kernel and reference kernel are equal or not
481481
bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_O.get(), block_O.get(),
482482
block_O.size(), ElementOutput{0.5}, ElementOutput{0.5});
@@ -623,10 +623,10 @@ bool verify(ProblemShapeType problem_size, Options options) {
623623
page_mapping[logical_idx] = physical_pages[blk];
624624
}
625625
}
626-
syclcompat::memcpy(paged_kv_cache.page_table.get(), page_mapping.data(), page_mapping.size() * sizeof(int));
626+
compat::memcpy(paged_kv_cache.page_table.get(), page_mapping.data(), page_mapping.size() * sizeof(int));
627627

628628
paged_kv_cache.num_pages_per_seq.reset(num_pages_per_seq.size());
629-
syclcompat::memcpy(paged_kv_cache.num_pages_per_seq.get(), num_pages_per_seq.data(), num_pages_per_seq.size() * sizeof(int));
629+
compat::memcpy(paged_kv_cache.num_pages_per_seq.get(), num_pages_per_seq.data(), num_pages_per_seq.size() * sizeof(int));
630630

631631
block_K_cache.reset(num_pages * paged_kv_cache.page_size * num_heads_kv * head_size_qk);
632632
block_V_cache.reset(num_pages * paged_kv_cache.page_size * num_heads_kv * head_size_vo);
@@ -683,25 +683,25 @@ bool verify(ProblemShapeType problem_size, Options options) {
683683
// configure smem size and carveout
684684
int smem_size = FMHAChunkPrefillKernel::SharedStorageSize;
685685

686-
const auto sycl_block = syclcompat::dim3(block.x, block.y, block.z);
687-
const auto sycl_grid = syclcompat::dim3(grid.x, grid.y, grid.z);
686+
const auto sycl_block = compat::dim3(block.x, block.y, block.z);
687+
const auto sycl_grid = compat::dim3(grid.x, grid.y, grid.z);
688688

689689
// Launch parameters depend on whether SYCL compiler supports work-group scratch memory extension
690690
#if !defined(SYCL_EXT_ONEAPI_WORK_GROUP_SCRATCH_MEMORY)
691-
using namespace syclcompat::experimental;
691+
using namespace compat::experimental;
692692
auto event = launch<cutlass::device_kernel<FMHAChunkPrefillKernel>>(
693693
launch_policy{sycl_grid, sycl_block, local_mem_size{static_cast<std::size_t>(smem_size)},
694694
kernel_properties{sycl_exp::sub_group_size<FMHAChunkPrefillKernel::DispatchPolicy::SubgroupSize>}},
695695
params);
696696
#else
697-
syclcompat::experimental::launch_properties launch_props {
697+
compat::experimental::launch_properties launch_props {
698698
sycl::ext::oneapi::experimental::work_group_scratch_size(smem_size),
699699
};
700-
syclcompat::experimental::kernel_properties kernel_props{
700+
compat::experimental::kernel_properties kernel_props{
701701
sycl::ext::oneapi::experimental::sub_group_size<FMHAChunkPrefillKernel::DispatchPolicy::SubgroupSize>
702702
};
703-
syclcompat::experimental::launch_policy policy{sycl_grid, sycl_block, launch_props, kernel_props};
704-
auto event = syclcompat::experimental::launch<cutlass::device_kernel<FMHAChunkPrefillKernel>>(policy, params);
703+
compat::experimental::launch_policy policy{sycl_grid, sycl_block, launch_props, kernel_props};
704+
auto event = compat::experimental::launch<cutlass::device_kernel<FMHAChunkPrefillKernel>>(policy, params);
705705
#endif
706706

707707
EventManager::getInstance().addEvent(event);
@@ -748,7 +748,7 @@ bool verify(ProblemShapeType problem_size, Options options) {
748748
// Run the Flash Attention implementation.
749749
run(params);
750750

751-
syclcompat::wait();
751+
compat::wait();
752752

753753
// Verify that the result is correct
754754
bool passed = verify(problem_size, options);
@@ -764,7 +764,7 @@ bool verify(ProblemShapeType problem_size, Options options) {
764764
for (int i = 0; i < options.iterations; ++i) {
765765
run(params);
766766
}
767-
syclcompat::wait();
767+
compat::wait();
768768

769769
auto offset = cute::min(options.seq_len_qo, options.seq_len_kv);
770770
auto discard_seq_coord = options.seq_len_qo - offset;

0 commit comments

Comments (0)