@@ -261,8 +261,8 @@ bool verify(ProblemShapeType problem_size, Options options) {
261261 int num_pages = paged_kv_cache.page_table .size ();
262262 std::vector<int > host_page_table (paged_kv_cache.page_table .size ());
263263 std::vector<int > host_num_pages_per_seq (paged_kv_cache.num_pages_per_seq .size ());
264- syclcompat ::memcpy<int >(host_page_table.data (), paged_kv_cache.page_table .get (), paged_kv_cache.page_table .size ());
265- syclcompat ::memcpy<int >(host_num_pages_per_seq.data (), paged_kv_cache.num_pages_per_seq .get (), paged_kv_cache.num_pages_per_seq .size ());
264+ compat ::memcpy<int >(host_page_table.data (), paged_kv_cache.page_table .get (), paged_kv_cache.page_table .size ());
265+ compat ::memcpy<int >(host_num_pages_per_seq.data (), paged_kv_cache.num_pages_per_seq .get (), paged_kv_cache.num_pages_per_seq .size ());
266266
267267 int curr_batch_pages = isVarLen ? host_num_pages_per_seq[b + 1 ] - host_num_pages_per_seq[b] : ceil_div (seq_len_kv_cache, paged_kv_cache.page_size );
268268 int batch_offset = isVarLen ? host_num_pages_per_seq[b] : b * curr_batch_pages;
@@ -272,57 +272,57 @@ bool verify(ProblemShapeType problem_size, Options options) {
272272 for (int p = 0 ; p < curr_batch_pages; p++) {
273273 int page_idx = host_page_table[batch_offset + p];
274274 // copy the page from KV cache to the concatenated buffer
275- syclcompat ::memcpy<ElementK>(
275+ compat ::memcpy<ElementK>(
276276 block_K_concat.get () + p * paged_kv_cache.page_size * num_heads_kv * head_size_qk,
277277 block_K_cache.get () + page_idx * paged_kv_cache.page_size * num_heads_kv * head_size_qk,
278278 paged_kv_cache.page_size * num_heads_kv * head_size_qk
279279 );
280- syclcompat ::memcpy<ElementV>(
280+ compat ::memcpy<ElementV>(
281281 block_V_concat.get () + p * paged_kv_cache.page_size * num_heads_kv * head_size_vo,
282282 block_V_cache.get () + page_idx * paged_kv_cache.page_size * num_heads_kv * head_size_vo,
283283 paged_kv_cache.page_size * num_heads_kv * head_size_vo
284284 );
285285 }
286286 if (seq_len_kv > 0 ) {
287- syclcompat ::memcpy<ElementK>(
287+ compat ::memcpy<ElementK>(
288288 // block_K_concat.get() + curr_batch_pages * paged_kv_cache.page_size * num_heads_kv * head_size_qk,
289289 block_K_concat.get () + seq_len_kv_cache * num_heads_kv * head_size_qk,
290290 block_K.get () + offset_k,
291291 seq_len_kv * num_heads_kv * head_size_qk
292292 );
293- syclcompat ::memcpy<ElementV>(
293+ compat ::memcpy<ElementV>(
294294 block_V_concat.get () + seq_len_kv_cache * num_heads_kv * head_size_vo,
295295 block_V.get () + offset_v,
296296 seq_len_kv * num_heads_kv * head_size_vo
297297 );
298298 }
299- syclcompat ::wait ();
299+ compat ::wait ();
300300 } else {
301301 block_K_concat.reset (seq_len_kv_total * num_heads_kv * head_size_qk);
302302 block_V_concat.reset (seq_len_kv_total * num_heads_kv * head_size_vo);
303303 // Concatenate K_cache and K
304- syclcompat ::memcpy<ElementK>(
304+ compat ::memcpy<ElementK>(
305305 block_K_concat.get (),
306306 block_K_cache.get () + offset_k_cache,
307307 seq_len_kv_cache * num_heads_kv * head_size_qk
308308 );
309- syclcompat ::memcpy<ElementK>(
309+ compat ::memcpy<ElementK>(
310310 block_K_concat.get () + seq_len_kv_cache * num_heads_kv * head_size_qk,
311311 block_K.get () + offset_k,
312312 seq_len_kv * num_heads_kv * head_size_qk
313313 );
314314 // Concatenate V_cache and V
315- syclcompat ::memcpy<ElementV>(
315+ compat ::memcpy<ElementV>(
316316 block_V_concat.get (),
317317 block_V_cache.get () + offset_v_cache,
318318 seq_len_kv_cache * num_heads_kv * head_size_vo
319319 );
320- syclcompat ::memcpy<ElementV>(
320+ compat ::memcpy<ElementV>(
321321 block_V_concat.get () + seq_len_kv_cache * num_heads_kv * head_size_vo,
322322 block_V.get () + offset_v,
323323 seq_len_kv * num_heads_kv * head_size_vo
324324 );
325- // syclcompat ::wait();
325+ // compat ::wait();
326326 }
327327 k_ptr = block_K_concat.get ();
328328 v_ptr = block_V_concat.get ();
@@ -350,9 +350,9 @@ bool verify(ProblemShapeType problem_size, Options options) {
350350 seq_len_qo * seq_len_kv_total, // batch_stride_S
351351 seq_len_qo * seq_len_kv_total // batch_stride_S
352352 );
353- syclcompat ::wait ();
353+ compat ::wait ();
354354 std::vector<ElementAccumulator> host_S (block_S.size ());
355- syclcompat ::memcpy<ElementAccumulator>(host_S.data (), block_S.get (), host_S.size ());
355+ compat ::memcpy<ElementAccumulator>(host_S.data (), block_S.get (), host_S.size ());
356356
357357 // delete this memory as it is no longer needed
358358 block_S.reset ();
@@ -427,7 +427,7 @@ bool verify(ProblemShapeType problem_size, Options options) {
427427 cutlass::DeviceAllocation<ElementV> block_P;
428428 block_P.reset (host_P.size ());
429429
430- syclcompat ::memcpy<ElementV>(block_P.get (), host_P.data (), host_P.size ());
430+ compat ::memcpy<ElementV>(block_P.get (), host_P.data (), host_P.size ());
431431
432432 cutlass::TensorRef ref_P (block_P.get (), LayoutQ::packed ({seq_len_qo, seq_len_kv_total}));
433433
@@ -445,12 +445,12 @@ bool verify(ProblemShapeType problem_size, Options options) {
445445 seq_len_qo * head_size_vo // batch_stride_O
446446 );
447447
448- syclcompat ::wait ();
448+ compat ::wait ();
449449 // delete this memory as it is no longer needed
450450 block_P.reset ();
451451
452452 std::vector<ElementAccumulator> vec_acc (block_acc.size ());
453- syclcompat ::memcpy<ElementAccumulator>(vec_acc.data (), block_acc.get (), vec_acc.size ());
453+ compat ::memcpy<ElementAccumulator>(vec_acc.data (), block_acc.get (), vec_acc.size ());
454454
455455 // delete this memory as it is no longer needed
456456 block_acc.reset ();
@@ -475,8 +475,8 @@ bool verify(ProblemShapeType problem_size, Options options) {
475475 offset_o += seq_len_qo * num_heads_q * head_size_vo;
476476 } // end of batch loop
477477
478- syclcompat ::wait ();
479- syclcompat ::memcpy<ElementOutput>(block_ref_O.get (), host_O.data (), host_O.size ());
478+ compat ::wait ();
479+ compat ::memcpy<ElementOutput>(block_ref_O.get (), host_O.data (), host_O.size ());
480480 // Check if output from CUTLASS kernel and reference kernel are equal or not
481481 bool passed = cutlass::reference::device::BlockCompareRelativelyEqual (block_ref_O.get (), block_O.get (),
482482 block_O.size (), ElementOutput{0.5 }, ElementOutput{0.5 });
@@ -623,10 +623,10 @@ bool verify(ProblemShapeType problem_size, Options options) {
623623 page_mapping[logical_idx] = physical_pages[blk];
624624 }
625625 }
626- syclcompat ::memcpy (paged_kv_cache.page_table .get (), page_mapping.data (), page_mapping.size () * sizeof (int ));
626+ compat ::memcpy (paged_kv_cache.page_table .get (), page_mapping.data (), page_mapping.size () * sizeof (int ));
627627
628628 paged_kv_cache.num_pages_per_seq .reset (num_pages_per_seq.size ());
629- syclcompat ::memcpy (paged_kv_cache.num_pages_per_seq .get (), num_pages_per_seq.data (), num_pages_per_seq.size () * sizeof (int ));
629+ compat ::memcpy (paged_kv_cache.num_pages_per_seq .get (), num_pages_per_seq.data (), num_pages_per_seq.size () * sizeof (int ));
630630
631631 block_K_cache.reset (num_pages * paged_kv_cache.page_size * num_heads_kv * head_size_qk);
632632 block_V_cache.reset (num_pages * paged_kv_cache.page_size * num_heads_kv * head_size_vo);
@@ -683,25 +683,25 @@ bool verify(ProblemShapeType problem_size, Options options) {
683683 // configure smem size and carveout
684684 int smem_size = FMHAChunkPrefillKernel::SharedStorageSize;
685685
686- const auto sycl_block = syclcompat ::dim3 (block.x , block.y , block.z );
687- const auto sycl_grid = syclcompat ::dim3 (grid.x , grid.y , grid.z );
686+ const auto sycl_block = compat ::dim3 (block.x , block.y , block.z );
687+ const auto sycl_grid = compat ::dim3 (grid.x , grid.y , grid.z );
688688
689689// Launch parameters depend on whether SYCL compiler supports work-group scratch memory extension
690690#if !defined(SYCL_EXT_ONEAPI_WORK_GROUP_SCRATCH_MEMORY)
691- using namespace syclcompat ::experimental;
691+ using namespace compat ::experimental;
692692 auto event = launch<cutlass::device_kernel<FMHAChunkPrefillKernel>>(
693693 launch_policy{sycl_grid, sycl_block, local_mem_size{static_cast <std::size_t >(smem_size)},
694694 kernel_properties{sycl_exp::sub_group_size<FMHAChunkPrefillKernel::DispatchPolicy::SubgroupSize>}},
695695 params);
696696#else
697- syclcompat ::experimental::launch_properties launch_props {
697+ compat ::experimental::launch_properties launch_props {
698698 sycl::ext::oneapi::experimental::work_group_scratch_size (smem_size),
699699 };
700- syclcompat ::experimental::kernel_properties kernel_props{
700+ compat ::experimental::kernel_properties kernel_props{
701701 sycl::ext::oneapi::experimental::sub_group_size<FMHAChunkPrefillKernel::DispatchPolicy::SubgroupSize>
702702 };
703- syclcompat ::experimental::launch_policy policy{sycl_grid, sycl_block, launch_props, kernel_props};
704- auto event = syclcompat ::experimental::launch<cutlass::device_kernel<FMHAChunkPrefillKernel>>(policy, params);
703+ compat ::experimental::launch_policy policy{sycl_grid, sycl_block, launch_props, kernel_props};
704+ auto event = compat ::experimental::launch<cutlass::device_kernel<FMHAChunkPrefillKernel>>(policy, params);
705705#endif
706706
707707 EventManager::getInstance ().addEvent (event);
@@ -748,7 +748,7 @@ bool verify(ProblemShapeType problem_size, Options options) {
748748 // Run the Flash Attention implementation.
749749 run (params);
750750
751- syclcompat ::wait ();
751+ compat ::wait ();
752752
753753 // Verify that the result is correct
754754 bool passed = verify (problem_size, options);
@@ -764,7 +764,7 @@ bool verify(ProblemShapeType problem_size, Options options) {
764764 for (int i = 0 ; i < options.iterations ; ++i) {
765765 run (params);
766766 }
767- syclcompat ::wait ();
767+ compat ::wait ();
768768
769769 auto offset = cute::min (options.seq_len_qo , options.seq_len_kv );
770770 auto discard_seq_coord = options.seq_len_qo - offset;
0 commit comments