@@ -404,6 +404,72 @@ void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float *
404404 }
405405}
406406
407+ ggml_float ggml_vec_cvar_f32 (const int n, float * y, const float * x, const float mean) {
408+ int i = 0 ;
409+ ggml_float sum = 0 ;
410+ // TODO: optimize to process the remaining elements in groups using the smaller vector sizes from AVX2 and SSE
411+ // ref: https://github.com/ggml-org/llama.cpp/pull/15953#pullrequestreview-3310928344
412+ #if defined(__AVX512F__) && defined(__AVX512DQ__)
413+ for (; i + 15 < n; i += 16 ) {
414+ __m512 val = _mm512_sub_ps (_mm512_loadu_ps (x + i),
415+ _mm512_set1_ps (mean));
416+ _mm512_storeu_ps (y + i, val);
417+ sum += (ggml_float)_mm512_reduce_add_ps (_mm512_mul_ps (val, val));
418+ }
419+ #elif defined(__AVX2__) && defined(__FMA__)
420+ for (; i + 7 < n; i += 8 ) {
421+ __m256 val = _mm256_sub_ps (_mm256_loadu_ps (x + i),
422+ _mm256_set1_ps (mean));
423+ _mm256_storeu_ps (y + i, val);
424+ val = _mm256_mul_ps (val,val);
425+ __m128 val2 = _mm_add_ps (_mm256_extractf128_ps (val, 1 ),
426+ _mm256_castps256_ps128 (val));
427+ val2 = _mm_add_ps (val2, _mm_movehl_ps (val2, val2));
428+ val2 = _mm_add_ss (val2, _mm_movehdup_ps (val2));
429+ sum += (ggml_float)_mm_cvtss_f32 (val2);
430+ }
431+ #elif defined(__SSE2__)
432+ for (; i + 3 < n; i += 4 ) {
433+ __m128 val = _mm_sub_ps (_mm_loadu_ps (x + i),
434+ _mm_set1_ps (mean));
435+ _mm_storeu_ps (y + i, val);
436+ val = _mm_mul_ps (val, val);
437+ #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
438+ val = _mm_add_ps (val, _mm_movehl_ps (val, val));
439+ val = _mm_add_ss (val, _mm_movehdup_ps (val));
440+ #else
441+ __m128 tmp = _mm_shuffle_ps (val, val, _MM_SHUFFLE (2 , 3 , 0 , 1 ));
442+ val = _mm_add_ps (val, tmp);
443+ tmp = _mm_movehl_ps (tmp, val);
444+ val = _mm_add_ss (val, tmp);
445+ #endif // __AVX__ || __AVX2__ || __AVX512F__
446+ sum += (ggml_float)_mm_cvtss_f32 (val);
447+ }
448+ #elif defined(__ARM_NEON) && defined(__aarch64__)
449+ for (; i + 3 < n; i += 4 ) {
450+ float32x4_t val = vsubq_f32 (vld1q_f32 (x + i),
451+ vdupq_n_f32 (mean));
452+ vst1q_f32 (y + i, val);
453+ val = vmulq_f32 (val, val);
454+ sum += (ggml_float)vaddvq_f32 (val);
455+ }
456+ #elif defined(__VXE__) || defined(__VXE2__)
457+ for (; i + 3 < n; i += 4 ) {
458+ float32x4_t val = vec_sub (vec_xl (0 , x + i), vec_splats (mean));
459+ vec_xst (val, 0 , y + i);
460+ val = vec_mul (val, val);
461+ sum += (ggml_float)vec_hsum_f32x4 (val);
462+ }
463+ #endif
464+ for (; i < n; ++i) {
465+ float val = x[i] - mean;
466+ val *= val;
467+ sum += (ggml_float)val;
468+ y[i] = val;
469+ }
470+ return sum/n;
471+ }
472+
407473ggml_float ggml_vec_soft_max_f32 (const int n, float * y, const float * x, float max) {
408474 int i = 0 ;
409475 ggml_float sum = 0 ;
0 commit comments