@@ -6,46 +6,49 @@ static constexpr bool is_arithmetic_v() {
66    return  std::is_arithmetic_v<T> || std::is_same_v<T, sycl::half> || std::is_same_v<T, sycl::ext::oneapi::bfloat16>;
77}
88}
9+ 
910template <typename  TIn, typename  TOut>
1011static  inline  std::enable_if_t <utils::is_arithmetic_v<TIn>() && utils::is_arithmetic_v<TOut>(), void >
1112convert  (const  char * src, char * dst) {
1213    auto  src_val = *reinterpret_cast <const  TIn*>(src);
1314    auto  dst_val = sycl::vec<TIn, 1 >(src_val).template  convert <TOut, sycl::rounding_mode::automatic>()[0 ];
14-    *reinterpret_cast <TOut*>(dst) = dst_val;; 
15+    *reinterpret_cast <TOut*>(dst) = dst_val;
1516}
1617
1718template <typename  TIn, typename  TOut>
1819static  void  k_set_rows (
1920        const  char  * __restrict__ src0, const  int64_t  * __restrict__ src1, char  * __restrict__ dst,
20-         const  int64_t  ne00, const  int64_t  ne01, const  int64_t  ne11, const  int64_t  ne12,
21+         const  int64_t  ne00, const  int64_t  ne01, const  int64_t  ne02,
22+         const  int64_t  ne11, const  int64_t  ne12,
2123        const  size_t  nb01, const  size_t  nb02, const  size_t  nb03,
2224        const  size_t  nb10, const  size_t  nb11, const  size_t  nb12,
2325        const  size_t  nb1, const  size_t  nb2, const  size_t  nb3,
2426        const  size_t  src_type_size, const  size_t  dst_type_size,
25-         const  sycl::nd_item<3 > & item_ct1) {
26- 
27-     const  int  i03 = item_ct1.get_group (0 );
28-     const  int  i02 = item_ct1.get_group (1 );
29-     const  int  i01 = item_ct1.get_group (2 ) * item_ct1.get_local_range (1 ) + item_ct1.get_local_id (1 );  //  Row index
27+         const  int64_t  total_elements,
28+         const  sycl::nd_item<1 > & item_ct1) {
3029
31-     if  (i01 >= ne01) {
30+     const  int64_t  i = item_ct1.get_global_linear_id ();
31+     if  (i >= total_elements) {
3232        return ;
3333    }
3434
35-     const  int  i12 = i03 % ne12;
36-     const  int  i11 = i02 % ne11;
37-     const  int  i10 = i01;
35+     const  int64_t  i03 = i / (ne00 * ne01 * ne02);
36+     const  int64_t  i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
37+     const  int64_t  i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00;
38+     const  int64_t  i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00;
39+ 
40+     const  int64_t  i12 = i03 % ne12;
41+     const  int64_t  i11 = i02 % ne11;
42+     const  int64_t  i10 = i01;
3843
3944    const  int64_t  dst_row = *(const  int64_t  *)((const  char  *)src1 + calculate_offset<3 >({nb10, nb11, nb12}, {i10, i11, i12}));
4045
4146    const  char  * src0_row = src0 + calculate_offset<3 >({nb01, nb02, nb03}, {i01, i02, i03});
42-     char  * dst_row_ptr    = dst + dst_row*nb1 + i02*nb2 + i03*nb3;
47+     const  char  * src_elem = src0_row + i00 * src_type_size;
48+     char  * dst_row_ptr = dst + dst_row*nb1 + i02*nb2 + i03*nb3;
49+     char  * dst_elem = dst_row_ptr + i00 * dst_type_size;
4350
44-     for  (int  col = item_ct1.get_local_id (0 ); col < ne00; col += item_ct1.get_local_range (0 )) {
45-         const  char  * src_elem = src0_row + col * src_type_size;
46-         char  * dst_elem       = dst_row_ptr + col * dst_type_size;
47-         convert<TIn, TOut>(src_elem, dst_elem);
48-     }
51+     convert<TIn, TOut>(src_elem, dst_elem);
4952}
5053
5154template <typename  TIn, typename  TOut>
@@ -58,32 +61,29 @@ static void set_rows_sycl(
5861        const  size_t  src_type_size, const  size_t  dst_type_size,
5962        queue_ptr stream) {
6063
61-     constexpr  int  max_threads_per_row = 64 ; //  KEEPING 64 for now
62-     const  int  threads_per_row     = std::min ((int )ne00, max_threads_per_row);
63- 
64-     constexpr  int  max_threads_per_block = 64 ;
65-     const  int  rows_per_block        = std::max (1 , max_threads_per_block / threads_per_row);
66- 
67-     const  sycl::range<3 > block_size (1 , rows_per_block, threads_per_row);
68-     const  sycl::range<3 > grid_size (ne03, ne02, (ne01 + rows_per_block - 1 ) / rows_per_block);
69- 
70-         sycl_parallel_for (
71-             stream,
72-             sycl::nd_range<3 >(grid_size * block_size, block_size),
73-             [=](sycl::nd_item<3 > item_ct1) {
74-                 k_set_rows<TIn, TOut>(
75-                     src0_d, src1_d, dst_d,
76-                     ne00, ne01, ne11, ne12,
77-                     nb01, nb02, nb03,
78-                     nb10, nb11, nb12,
79-                     nb1, nb2, nb3,
80-                     src_type_size, dst_type_size,
81-                     item_ct1
82-                 );
83-             }
84-         );
85- }
64+     const  int64_t  total_elements = ne00 * ne01 * ne02 * ne03;
8665
66+     constexpr  int  block_size = 64 ;
67+     const  int64_t  grid_size = ceil_div (total_elements, block_size);
68+ 
69+     sycl_parallel_for (
70+         stream,
71+         sycl::nd_range<1 >(grid_size * block_size, block_size),
72+         [=](sycl::nd_item<1 > item_ct1) {
73+             k_set_rows<TIn, TOut>(
74+                 src0_d, src1_d, dst_d,
75+                 ne00, ne01, ne02,
76+                 ne11, ne12,
77+                 nb01, nb02, nb03,
78+                 nb10, nb11, nb12,
79+                 nb1, nb2, nb3,
80+                 src_type_size, dst_type_size,
81+                 total_elements,
82+                 item_ct1
83+             );
84+         }
85+     );
86+ }
8787
8888void  ggml_sycl_op_set_rows (ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
8989    scope_op_debug_print scope_dbg_print (__func__, dst, /* num_src=*/ 2 );
@@ -122,7 +122,7 @@ void ggml_sycl_op_set_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
122122                nb1, nb2, nb3,
123123                sizeof (float ), sizeof (sycl::half),
124124                stream
125-         );
125+              );
126126            break ;
127127        default :
128128            GGML_ABORT (" Unsupported tensor type!" 
0 commit comments