@@ -89,24 +89,33 @@ static void concat_f32_sycl(const float *x, const float *y, float *dst,
8989 sycl::range<3 > gridDim (ne2, ne1, num_blocks);
9090 switch (dim) {
9191 case 0 :
92- sycl_parallel_for (stream,
93- sycl::nd_range<3 >(gridDim * sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE),
94- sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE)),
95- [=](sycl::nd_item<3 > item_ct1) { concat_f32_dim0 (x, y, dst, ne0, ne00, item_ct1); });
96- break ;
92+ stream->parallel_for (
93+ sycl::nd_range<3 >(gridDim *
94+ sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE),
95+ sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE)),
96+ [=](sycl::nd_item<3 > item_ct1) {
97+ concat_f32_dim0 (x, y, dst, ne0, ne00, item_ct1);
98+ });
99+ break ;
97100 case 1 :
98- sycl_parallel_for (stream,
99- sycl::nd_range<3 >(gridDim * sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE),
100- sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE)),
101- [=](sycl::nd_item<3 > item_ct1) { concat_f32_dim1 (x, y, dst, ne0, ne01, item_ct1); });
102- break ;
101+ stream->parallel_for (
102+ sycl::nd_range<3 >(gridDim *
103+ sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE),
104+ sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE)),
105+ [=](sycl::nd_item<3 > item_ct1) {
106+ concat_f32_dim1 (x, y, dst, ne0, ne01, item_ct1);
107+ });
108+ break ;
103109 // dim >=2 will be dispatched to the default path
104110 default :
105- sycl_parallel_for (stream,
106- sycl::nd_range<3 >(gridDim * sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE),
107- sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE)),
108- [=](sycl::nd_item<3 > item_ct1) { concat_f32_dim2 (x, y, dst, ne0, ne02, item_ct1); });
109- break ;
111+ stream->parallel_for (
112+ sycl::nd_range<3 >(gridDim *
113+ sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE),
114+ sycl::range<3 >(1 , 1 , SYCL_CONCAT_BLOCK_SIZE)),
115+ [=](sycl::nd_item<3 > item_ct1) {
116+ concat_f32_dim2 (x, y, dst, ne0, ne02, item_ct1);
117+ });
118+ break ;
110119 }
111120}
112121
@@ -120,29 +129,33 @@ static void concat_f32_sycl_non_cont(
120129 int64_t ne2, int64_t ne3, uint64_t nb0, uint64_t nb1, uint64_t nb2,
121130 uint64_t nb3, int32_t dim) {
122131 sycl::range<3 > gridDim (ne3, ne2, ne1);
123- sycl_parallel_for (stream, sycl::nd_range<3 >(gridDim, sycl::range<3 >(1 , 1 , 1 )), [=](sycl::nd_item<3 > item_ct1) {
124- int64_t i3 = item_ct1.get_group (0 );
125- int64_t i2 = item_ct1.get_group (1 );
126- int64_t i1 = item_ct1.get_group (2 );
132+ stream->parallel_for (
133+ sycl::nd_range<3 >(gridDim, sycl::range<3 >(1 , 1 , 1 )),
134+ [=](sycl::nd_item<3 > item_ct1) {
135+ int64_t i3 = item_ct1.get_group (0 );
136+ int64_t i2 = item_ct1.get_group (1 );
137+ int64_t i1 = item_ct1.get_group (2 );
127138
128- int64_t o[4 ] = { 0 , 0 , 0 , 0 };
129- o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
139+ int64_t o[4 ] = {0 , 0 , 0 , 0 };
140+ o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
130141
131- const float * x;
142+ const float *x;
132143
133- for (int i0 = item_ct1.get_local_id (2 ); i0 < ne0; i0 += item_ct1.get_local_range (2 )) {
144+ for (int i0 = item_ct1.get_local_id (2 ); i0 < ne0;
145+ i0 += item_ct1.get_local_range (2 )) {
134146 if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
135- x = (const float *) (src0 + (i3) *nb03 + (i2) *nb02 + (i1) *nb01 + (i0) *nb00);
147+ x = (const float *)(src0 + (i3)*nb03 + (i2)*nb02 + (i1)*nb01 +
148+ (i0)*nb00);
136149 } else {
137- x = (const float *) (src1 + (i3 - o[3 ]) * nb13 + (i2 - o[2 ]) * nb12 + (i1 - o[ 1 ]) * nb11 +
138- (i0 - o[0 ]) * nb10);
150+ x = (const float *)(src1 + (i3 - o[3 ]) * nb13 + (i2 - o[2 ]) * nb12 +
151+ (i1 - o[ 1 ]) * nb11 + (i0 - o[0 ]) * nb10);
139152 }
140153
141154 float *y = (float *)(dst + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0);
142155
143156 *y = *x;
144- }
145- });
157+ }
158+ });
146159}
147160
148161void ggml_sycl_op_concat (ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
0 commit comments