Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 37 additions & 53 deletions ggml/src/ggml-cpu/vec.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -273,67 +273,51 @@ void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * G

#if defined(GGML_SIMD)
#if defined(__ARM_FEATURE_SVE)
const int sve_register_length = svcntb() * 8; //get vector length
const int ggml_f16_epr = sve_register_length / 16; // running when 16
const int ggml_f16_step = 8 * ggml_f16_epr; // choose 8 SVE registers

const int np= (n & ~(ggml_f16_step - 1));
svfloat16_t sum1 = svdup_n_f16(0.0f);
svfloat16_t sum2 = svdup_n_f16(0.0f);
svfloat16_t sum3 = svdup_n_f16(0.0f);
svfloat16_t sum4 = svdup_n_f16(0.0f);

svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
for (int i = 0; i < np; i += ggml_f16_step) {
ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0);
ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0);
sum1 = GGML_F16x_VEC_FMA(sum1, ax1, ay1);

ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1);
ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1);
sum2 = GGML_F16x_VEC_FMA(sum2, ax2, ay2);

ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2);
ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2);
sum3 = GGML_F16x_VEC_FMA(sum3, ax3, ay3);

ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3);
ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);
sum4 = GGML_F16x_VEC_FMA(sum4, ax4, ay4);

ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4);
ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4);
sum1 = GGML_F16x_VEC_FMA(sum1, ax5, ay5);
const int ggml_f16_epr = svcnth();
const int ggml_f16_step = 8 * ggml_f16_epr;
const int np = n - (n % ggml_f16_step);
const int np2 = n - (n % ggml_f16_epr);

svfloat32_t sum1_lo = svdup_n_f32(0.0f);
svfloat32_t sum1_hi = svdup_n_f32(0.0f);
svfloat32_t sum2_lo = svdup_n_f32(0.0f);
svfloat32_t sum2_hi = svdup_n_f32(0.0f);
svfloat32_t sum3_lo = svdup_n_f32(0.0f);
svfloat32_t sum3_hi = svdup_n_f32(0.0f);
svfloat32_t sum4_lo = svdup_n_f32(0.0f);
svfloat32_t sum4_hi = svdup_n_f32(0.0f);

ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5);
ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5);
sum2 = GGML_F16x_VEC_FMA(sum2, ax6, ay6);

ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6);
ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6);
sum3 = GGML_F16x_VEC_FMA(sum3, ax7, ay7);

ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7);
ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7);
sum4 = GGML_F16x_VEC_FMA(sum4, ax8, ay8);
for (int i = 0; i < np; i += ggml_f16_step) {
ggml_sve_f16_fma_widened(&sum1_lo, &sum1_hi, GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0), GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0));
ggml_sve_f16_fma_widened(&sum2_lo, &sum2_hi, GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1), GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1));
ggml_sve_f16_fma_widened(&sum3_lo, &sum3_hi, GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2), GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2));
ggml_sve_f16_fma_widened(&sum4_lo, &sum4_hi, GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3), GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3));
ggml_sve_f16_fma_widened(&sum1_lo, &sum1_hi, GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4), GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4));
ggml_sve_f16_fma_widened(&sum2_lo, &sum2_hi, GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5), GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5));
ggml_sve_f16_fma_widened(&sum3_lo, &sum3_hi, GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6), GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6));
ggml_sve_f16_fma_widened(&sum4_lo, &sum4_hi, GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7), GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7));
}

const int np2 = (n & ~(ggml_f16_epr - 1)); // round down to multiple of 8
for (int k = np; k < np2; k += ggml_f16_epr) {
svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0);
svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0);
sum1 = GGML_F16x_VEC_FMA(sum1, rx, ry);
for (int i = np; i < np2; i += ggml_f16_epr) {
ggml_sve_f16_fma_widened(&sum1_lo, &sum1_hi, GGML_F16x_VEC_LOAD(x + i, 0), GGML_F16x_VEC_LOAD(y + i, 0));
}

if (np2 < n) {
svbool_t pg = svwhilelt_b16(np2, n);
svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2));
svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));
const svbool_t pg = svwhilelt_b16(np2, n);
const svfloat16_t rx = svld1_f16(pg, (const __fp16 *)(x + np2));
const svfloat16_t ry = svld1_f16(pg, (const __fp16 *)(y + np2));

sum1 = svmad_f16_x(pg, hx, hy, sum1);
ggml_sve_f16_fma_widened(&sum1_lo, &sum1_hi, rx, ry);
}
GGML_F16x_VEC_REDUCE(sumf, sum1, sum2, sum3, sum4);

sum1_lo = svadd_f32_m(DEFAULT_PG32, sum1_lo, sum2_lo);
sum1_hi = svadd_f32_m(DEFAULT_PG32, sum1_hi, sum2_hi);
sum3_lo = svadd_f32_m(DEFAULT_PG32, sum3_lo, sum4_lo);
sum3_hi = svadd_f32_m(DEFAULT_PG32, sum3_hi, sum4_hi);
sum1_lo = svadd_f32_m(DEFAULT_PG32, sum1_lo, sum3_lo);
sum1_hi = svadd_f32_m(DEFAULT_PG32, sum1_hi, sum3_hi);

sumf = ggml_sve_sum_f32x2(sum1_lo, sum1_hi);
#elif defined(__riscv_v_intrinsic)
#if defined(__riscv_zvfh)
int vl = __riscv_vsetvlmax_e32m2();
Expand Down
158 changes: 70 additions & 88 deletions ggml/src/ggml-cpu/vec.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,35 @@
// floating point type used to accumulate sums
typedef double ggml_float;

#if defined(__ARM_FEATURE_SVE)
// Multiply two f16 vectors lane by lane and accumulate the products in f32.
//
//   acc_lo - in/out: f32 accumulator that receives the products of the
//            even-indexed f16 lanes of x and y
//   acc_hi - in/out: f32 accumulator that receives the products of the
//            odd-indexed f16 lanes of x and y
//   x, y   - f16 input vectors
//
// Both accumulators are updated in place; nothing is returned.
inline static void ggml_sve_f16_fma_widened(
        svfloat32_t * acc_lo,
        svfloat32_t * acc_hi,
        svfloat16_t   x,
        svfloat16_t   y) {
#if defined(__ARM_FEATURE_SVE2)
    // SVE2: widening multiply-add on the bottom (even) / top (odd) f16 lanes
    *acc_lo = svmlalb_f32(*acc_lo, x, y);
    *acc_hi = svmlalt_f32(*acc_hi, x, y);
#else
    // SVE without SVE2: there is no widening multiply-add, so convert to f32
    // first and use a regular f32 multiply-add.
    const svbool_t all_lanes = svptrue_b32();

    // trn1/trn2 with the same vector as both operands replicate the even
    // (resp. odd) f16 lanes, so the f16 -> f32 conversion below yields the
    // even (resp. odd) values.
    const svfloat32_t x_even_f32 = svcvt_f32_f16_x(all_lanes, svtrn1_f16(x, x));
    const svfloat32_t y_even_f32 = svcvt_f32_f16_x(all_lanes, svtrn1_f16(y, y));
    const svfloat32_t x_odd_f32  = svcvt_f32_f16_x(all_lanes, svtrn2_f16(x, x));
    const svfloat32_t y_odd_f32  = svcvt_f32_f16_x(all_lanes, svtrn2_f16(y, y));

    *acc_lo = svmla_f32_x(all_lanes, *acc_lo, x_even_f32, y_even_f32);
    *acc_hi = svmla_f32_x(all_lanes, *acc_hi, x_odd_f32,  y_odd_f32);
#endif
}

// Horizontally reduce two f32 vectors and return the combined total.
// Each vector is summed across all of its lanes; the two partial sums are
// added together and the result is converted to ggml_float.
inline static ggml_float ggml_sve_sum_f32x2(svfloat32_t sum_lo, svfloat32_t sum_hi) {
    const float lo_total = svaddv_f32(svptrue_b32(), sum_lo);
    const float hi_total = svaddv_f32(svptrue_b32(), sum_hi);

    return (ggml_float) (lo_total + hi_total);
}
#endif

#define GGML_GELU_FP16
#define GGML_GELU_QUICK_FP16

Expand Down Expand Up @@ -122,108 +151,61 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG
#if defined(GGML_SIMD)
#if defined(__ARM_FEATURE_SVE)

const int sve_register_length = svcntb() * 8;
const int ggml_f16_epr = sve_register_length / 16; // running when 16
const int ggml_f16_step = 8 * ggml_f16_epr; // choose 8 SVE registers

int np = (n & ~(ggml_f16_step - 1));

svfloat16_t sum_00 = svdup_n_f16(0.0f);
svfloat16_t sum_01 = svdup_n_f16(0.0f);
svfloat16_t sum_02 = svdup_n_f16(0.0f);
svfloat16_t sum_03 = svdup_n_f16(0.0f);
const int ggml_f16_epr = svcnth();
const int ggml_f16_step = 2 * ggml_f16_epr;
int np = n - (n % ggml_f16_step);
int np2 = n - (n % ggml_f16_epr);

svfloat16_t sum_10 = svdup_n_f16(0.0f);
svfloat16_t sum_11 = svdup_n_f16(0.0f);
svfloat16_t sum_12 = svdup_n_f16(0.0f);
svfloat16_t sum_13 = svdup_n_f16(0.0f);

svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
svfloat32_t sum_0_0_lo = svdup_n_f32(0.0f);
svfloat32_t sum_0_0_hi = svdup_n_f32(0.0f);
svfloat32_t sum_0_1_lo = svdup_n_f32(0.0f);
svfloat32_t sum_0_1_hi = svdup_n_f32(0.0f);
svfloat32_t sum_1_0_lo = svdup_n_f32(0.0f);
svfloat32_t sum_1_0_hi = svdup_n_f32(0.0f);
svfloat32_t sum_1_1_lo = svdup_n_f32(0.0f);
svfloat32_t sum_1_1_hi = svdup_n_f32(0.0f);

for (int i = 0; i < np; i += ggml_f16_step) {
ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0); // 8 elements

ax1 = GGML_F16x_VEC_LOAD(x[0] + i + 0*ggml_f16_epr, 0); // 8 elements
sum_00 = GGML_F16x_VEC_FMA(sum_00, ax1, ay1); // sum_00 = sum_00+ax1*ay1
ax1 = GGML_F16x_VEC_LOAD(x[1] + i + 0*ggml_f16_epr, 0); // 8 elements
sum_10 = GGML_F16x_VEC_FMA(sum_10, ax1, ay1);

ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1); // next 8 elements
const svfloat16_t ay0 = GGML_F16x_VEC_LOAD(y + i, 0);
const svfloat16_t ax00 = GGML_F16x_VEC_LOAD(x[0] + i, 0);
const svfloat16_t ax01 = GGML_F16x_VEC_LOAD(x[1] + i, 0);

ax2 = GGML_F16x_VEC_LOAD(x[0] + i + 1*ggml_f16_epr, 1); // next 8 elements
sum_01 = GGML_F16x_VEC_FMA(sum_01, ax2, ay2);
ax2 = GGML_F16x_VEC_LOAD(x[1] + i + 1*ggml_f16_epr, 1);
sum_11 = GGML_F16x_VEC_FMA(sum_11, ax2, ay2);
ggml_sve_f16_fma_widened(&sum_0_0_lo, &sum_0_0_hi, ax00, ay0);
ggml_sve_f16_fma_widened(&sum_1_0_lo, &sum_1_0_hi, ax01, ay0);

ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2);
const svfloat16_t ay1 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 0);
const svfloat16_t ax10 = GGML_F16x_VEC_LOAD(x[0] + i + 1 * ggml_f16_epr, 0);
const svfloat16_t ax11 = GGML_F16x_VEC_LOAD(x[1] + i + 1 * ggml_f16_epr, 0);

ax3 = GGML_F16x_VEC_LOAD(x[0] + i + 2*ggml_f16_epr, 2);
sum_02 = GGML_F16x_VEC_FMA(sum_02, ax3, ay3);
ax3 = GGML_F16x_VEC_LOAD(x[1] + i + 2*ggml_f16_epr, 2);
sum_12 = GGML_F16x_VEC_FMA(sum_12, ax3, ay3);

ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);

ax4 = GGML_F16x_VEC_LOAD(x[0] + i + 3*ggml_f16_epr, 3);
sum_03 = GGML_F16x_VEC_FMA(sum_03, ax4, ay4);
ax4 = GGML_F16x_VEC_LOAD(x[1] + i + 3*ggml_f16_epr, 3);
sum_13 = GGML_F16x_VEC_FMA(sum_13, ax4, ay4);

ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4);

ax5 = GGML_F16x_VEC_LOAD(x[0] + i + 4*ggml_f16_epr, 4);

sum_00 = GGML_F16x_VEC_FMA(sum_00, ax5, ay5);
ax5 = GGML_F16x_VEC_LOAD(x[1] + i + 4*ggml_f16_epr, 4);
sum_10 = GGML_F16x_VEC_FMA(sum_10, ax5, ay5);

ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5);

ax6 = GGML_F16x_VEC_LOAD(x[0] + i + 5*ggml_f16_epr, 5);

sum_01 = GGML_F16x_VEC_FMA(sum_01, ax6, ay6);
ax6 = GGML_F16x_VEC_LOAD(x[1] + i + 5*ggml_f16_epr, 5);
sum_11 = GGML_F16x_VEC_FMA(sum_11, ax6, ay6);

ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6);

ax7 = GGML_F16x_VEC_LOAD(x[0] + i + 6*ggml_f16_epr, 6);

sum_02 = GGML_F16x_VEC_FMA(sum_02, ax7, ay7);
ax7 = GGML_F16x_VEC_LOAD(x[1] + i + 6*ggml_f16_epr, 6);
sum_12 = GGML_F16x_VEC_FMA(sum_12, ax7, ay7);

ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7);

ax8 = GGML_F16x_VEC_LOAD(x[0] + i + 7*ggml_f16_epr, 7);

sum_03 = GGML_F16x_VEC_FMA(sum_03, ax8, ay8);
ax8 = GGML_F16x_VEC_LOAD(x[1] + i + 7*ggml_f16_epr, 7);
sum_13 = GGML_F16x_VEC_FMA(sum_13, ax8, ay8);
ggml_sve_f16_fma_widened(&sum_0_1_lo, &sum_0_1_hi, ax10, ay1);
ggml_sve_f16_fma_widened(&sum_1_1_lo, &sum_1_1_hi, ax11, ay1);
}

const int np2 = (n & ~(ggml_f16_epr - 1));
for (int k = np; k < np2; k += ggml_f16_epr) {
svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0);
for (int i = np; i < np2; i += ggml_f16_epr) {
const svfloat16_t ry = GGML_F16x_VEC_LOAD(y + i, 0);
const svfloat16_t rx0 = GGML_F16x_VEC_LOAD(x[0] + i, 0);
const svfloat16_t rx1 = GGML_F16x_VEC_LOAD(x[1] + i, 0);

svfloat16_t rx = GGML_F16x_VEC_LOAD(x[0] + k, 0);
sum_00 = GGML_F16x_VEC_FMA(sum_00, rx, ry);
rx = GGML_F16x_VEC_LOAD(x[1] + k, 0);
sum_10 = GGML_F16x_VEC_FMA(sum_10, rx, ry);
ggml_sve_f16_fma_widened(&sum_0_0_lo, &sum_0_0_hi, rx0, ry);
ggml_sve_f16_fma_widened(&sum_1_0_lo, &sum_1_0_hi, rx1, ry);
}

if (np2 < n) {
svbool_t pg = svwhilelt_b16(np2, n);
svfloat16_t hx_0 = svld1_f16(pg, (const __fp16 *)(x[0] + np2));
svfloat16_t hx_1 = svld1_f16(pg, (const __fp16 *)(x[1] + np2));
svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));
const svbool_t pg = svwhilelt_b16(np2, n);
const svfloat16_t ay = svld1_f16(pg, (const __fp16 *)(y + np2));
const svfloat16_t ax0 = svld1_f16(pg, (const __fp16 *)(x[0] + np2));
const svfloat16_t ax1 = svld1_f16(pg, (const __fp16 *)(x[1] + np2));

sum_00 = svmad_f16_x(pg, hx_0, hy, sum_00);
sum_10 = svmad_f16_x(pg, hx_1, hy, sum_10);
ggml_sve_f16_fma_widened(&sum_0_0_lo, &sum_0_0_hi, ax0, ay);
ggml_sve_f16_fma_widened(&sum_1_0_lo, &sum_1_0_hi, ax1, ay);
}
GGML_F16x_VEC_REDUCE(sumf[0], sum_00, sum_01, sum_02, sum_03);
GGML_F16x_VEC_REDUCE(sumf[1], sum_10, sum_11, sum_12, sum_13);

svfloat32_t sum_0_lo = svadd_f32_x(DEFAULT_PG32, sum_0_0_lo, sum_0_1_lo);
svfloat32_t sum_0_hi = svadd_f32_x(DEFAULT_PG32, sum_0_0_hi, sum_0_1_hi);
svfloat32_t sum_1_lo = svadd_f32_x(DEFAULT_PG32, sum_1_0_lo, sum_1_1_lo);
svfloat32_t sum_1_hi = svadd_f32_x(DEFAULT_PG32, sum_1_0_hi, sum_1_1_hi);
sumf[0] = ggml_sve_sum_f32x2(sum_0_lo, sum_0_hi);
sumf[1] = ggml_sve_sum_f32x2(sum_1_lo, sum_1_hi);
np = n;
#elif defined(__riscv_v_intrinsic)
#if defined(__riscv_zvfh)
Expand Down
Loading