@@ -1695,7 +1695,7 @@ static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, con
16951695 dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
16961696}
16971697
1698- static void scale_f32 (const float * x, float * dst, const float scale, const int k,
1698+ static void scale_f32 (const float * x, float * dst, const float scale, const float bias, const int k,
16991699 const sycl::nd_item<3 > &item_ct1) {
17001700 const int i = item_ct1.get_local_range (2 ) * item_ct1.get_group (2 ) +
17011701 item_ct1.get_local_id (2 );
@@ -1704,7 +1704,7 @@ static void scale_f32(const float * x, float * dst, const float scale, const int
17041704 return ;
17051705 }
17061706
1707- dst[i] = scale * x[i];
1707+ dst[i] = scale * x[i] + bias ;
17081708}
17091709
17101710
@@ -1842,15 +1842,15 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl(
18421842
18431843
18441844
1845- static void scale_f32_sycl (const float *x, float *dst, const float scale,
1845+ static void scale_f32_sycl (const float *x, float *dst, const float scale, const float bias,
18461846 const int k, queue_ptr stream) {
18471847 const int num_blocks = (k + SYCL_SCALE_BLOCK_SIZE - 1 ) / SYCL_SCALE_BLOCK_SIZE;
18481848 stream->parallel_for (
18491849 sycl::nd_range<3 >(sycl::range<3 >(1 , 1 , num_blocks) *
18501850 sycl::range<3 >(1 , 1 , SYCL_SCALE_BLOCK_SIZE),
18511851 sycl::range<3 >(1 , 1 , SYCL_SCALE_BLOCK_SIZE)),
18521852 [=](sycl::nd_item<3 > item_ct1) {
1853- scale_f32 (x, dst, scale, k, item_ct1);
1853+ scale_f32 (x, dst, scale, bias, k, item_ct1);
18541854 });
18551855}
18561856
@@ -2319,9 +2319,11 @@ inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor * ds
23192319 float * dst_dd = static_cast <float *>(dst->data );
23202320
23212321 float scale;
2322- memcpy (&scale, dst->op_params , sizeof (float ));
2322+ float bias;
2323+ memcpy (&scale, (float *) dst->op_params + 0 , sizeof (float ));
2324+ memcpy (&bias, (float *) dst->op_params + 1 , sizeof (float ));
23232325
2324- scale_f32_sycl (src0_dd, dst_dd, scale, ggml_nelements (dst->src [0 ]), main_stream);
2326+ scale_f32_sycl (src0_dd, dst_dd, scale, bias, ggml_nelements (dst->src [0 ]), main_stream);
23252327 /*
23262328 DPCT1010:87: SYCL uses exceptions to report errors and does not use the
23272329 error codes. The call was replaced with 0. You need to rewrite this code.
0 commit comments