Skip to content

Commit

Permalink
Add accumulators to AArch64 GEMV Kernels
Browse files Browse the repository at this point in the history
This helps to reduce the floating-point precision lost when many products are summed into a single running accumulator.
  • Loading branch information
Mousius committed Jul 31, 2024
1 parent b26424c commit ba2e989
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 36 deletions.
3 changes: 3 additions & 0 deletions kernel/arm64/KERNEL.NEOVERSEV1
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
include $(KERNELDIR)/KERNEL.ARMV8SVE

SGEMVTKERNEL = gemv_t_sve.c
DGEMVTKERNEL = gemv_t_sve.c
67 changes: 38 additions & 29 deletions kernel/arm64/gemv_t.S
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
Copyright (c) 2015, The OpenBLAS Project
Copyright (c) 2015, 2024 The OpenBLAS Project
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
Expand Down Expand Up @@ -170,39 +170,48 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

.macro KERNEL_F32_FINALIZE
#if !defined(DOUBLE)
fadd v1.4s, v1.4s, v2.4s
// In single precision KERNEL_F8 only uses two accumulators (v1, v2),
// so fold v3 into v1 and v4 into v2
fadd v1.4s, v1.4s, v3.4s
fadd v1.4s, v1.4s, v4.4s
#else
fadd v1.2d, v1.2d, v2.2d
fadd v1.2d, v1.2d, v3.2d
fadd v1.2d, v1.2d, v4.2d
fadd v2.4s, v2.4s, v4.4s
#endif
.endm

.macro KERNEL_F4
.macro KERNEL_F8
#if !defined(DOUBLE)
ld1 {v2.4s}, [A_PTR], #16
ld1 {v3.4s}, [X_PTR], #16
fmla v1.4s, v2.4s, v3.4s
#else
ld1 {v2.2d}, [A_PTR], #16
ld1 {v3.2d}, [X_PTR], #16
fmla v1.2d, v2.2d, v3.2d

ld1 {v4.2d}, [A_PTR], #16
ld1 {v5.2d}, [X_PTR], #16
fmla v1.2d, v4.2d, v5.2d
ld1 {v13.4s, v14.4s}, [A_PTR], #32
ld1 {v17.4s, v18.4s}, [X_PTR], #32
fmla v1.4s, v13.4s, v17.4s
fmla v2.4s, v14.4s, v18.4s
#else
ld1 {v13.2d, v14.2d, v15.2d, v16.2d}, [A_PTR], #64
ld1 {v17.2d, v18.2d, v19.2d, v20.2d}, [X_PTR], #64
fmla v1.2d, v13.2d, v17.2d
fmla v2.2d, v14.2d, v18.2d
fmla v3.2d, v15.2d, v19.2d
fmla v4.2d, v16.2d, v20.2d
#endif
.endm

.macro KERNEL_F4_FINALIZE
.macro KERNEL_F8_FINALIZE
#if !defined(DOUBLE)
ext v2.16b, v1.16b, v1.16b, #8
// Take the top two elements of v1 and
// put them into the first two lanes of v3
ext v3.16b, v1.16b, v1.16b, #8
fadd v1.2s, v1.2s, v3.2s
ext v4.16b, v2.16b, v2.16b, #8
fadd v2.2s, v2.2s, v4.2s
// Final pair
fadd v1.2s, v1.2s, v2.2s
faddp TEMP, v1.2s
#else
faddp TEMP, v1.2d
faddp TEMP1, v2.2d
faddp TEMP2, v3.2d
faddp TEMP3, v4.2d
fadd TEMP, TEMP, TEMP1
fadd TEMP2, TEMP2, TEMP3
fadd TEMP, TEMP, TEMP2
#endif
.endm

Expand Down Expand Up @@ -258,7 +267,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

asr I, M, #5
cmp I, xzr
beq .Lgemv_t_kernel_F4
beq .Lgemv_t_kernel_F8

.Lgemv_t_kernel_F320:

Expand All @@ -269,24 +278,24 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

KERNEL_F32_FINALIZE

.Lgemv_t_kernel_F4:
.Lgemv_t_kernel_F8:
ands I, M, #31
asr I, I, #2
asr I, I, #3
cmp I, xzr
beq .Lgemv_t_kernel_F1

.Lgemv_t_kernel_F40:
.Lgemv_t_kernel_F80:

KERNEL_F4
KERNEL_F8

subs I, I, #1
bne .Lgemv_t_kernel_F40
bne .Lgemv_t_kernel_F80

.Lgemv_t_kernel_F1:

KERNEL_F4_FINALIZE
KERNEL_F8_FINALIZE

ands I, M, #3
ands I, M, #7
ble .Lgemv_t_kernel_F_END

.Lgemv_t_kernel_F10:
Expand Down
40 changes: 33 additions & 7 deletions kernel/arm64/gemv_t_sve.c
Original file line number Diff line number Diff line change
Expand Up @@ -59,20 +59,46 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLO
a_ptr = a;

if (inc_x == 1) {
svbool_t pg_true = SV_TRUE();
uint64_t sve_size = SV_COUNT();
uint64_t sve_size2 = sve_size * 2;
BLASLONG m1 = m & -sve_size;
BLASLONG m2 = m & -sve_size2;

for (j = 0; j < n; j++) {
BLASLONG i = 0;

SV_TYPE temp_vec_v2_0 = SV_DUP(0.0);
SV_TYPE temp_vec_v2_1 = SV_DUP(0.0);
for (; i < m2; i += sve_size2) {
SV_TYPE a_vec0 = svld1(pg_true, a_ptr + i);
SV_TYPE x_vec0 = svld1(pg_true, x + i);
SV_TYPE a_vec1 = svld1(pg_true, a_ptr + i + sve_size);
SV_TYPE x_vec1 = svld1(pg_true, x + i + sve_size);
temp_vec_v2_0 = svmla_m(pg_true, temp_vec_v2_0, a_vec0, x_vec0);
temp_vec_v2_1 = svmla_m(pg_true, temp_vec_v2_1, a_vec1, x_vec1);
}

SV_TYPE temp_vec_v1 = SV_DUP(0.0);
for (; i < m1; i += sve_size) {
SV_TYPE a_vec0 = svld1(pg_true, a_ptr + i);
SV_TYPE x_vec0 = svld1(pg_true, x + i);
temp_vec_v1 = svmla_m(pg_true, temp_vec_v1, a_vec0, x_vec0);
}

SV_TYPE temp_vec = SV_DUP(0.0);
i = 0;
svbool_t pg = SV_WHILE(i, m);
while (svptest_any(SV_TRUE(), pg)) {
for (; i < m; i += sve_size) {
svbool_t pg = SV_WHILE(i, m);
SV_TYPE a_vec = svld1(pg, a_ptr + i);
SV_TYPE x_vec = svld1(pg, x + i);
temp_vec = svmla_m(pg, temp_vec, a_vec, x_vec);
i += sve_size;
pg = SV_WHILE(i, m);
}
temp = svaddv(SV_TRUE(), temp_vec);
y[iy] += alpha * temp;

y[iy] += alpha * (
(svaddv(SV_TRUE(), temp_vec_v2_0) + svaddv(SV_TRUE(), temp_vec)) +
(svaddv(SV_TRUE(), temp_vec_v2_1) + svaddv(SV_TRUE(), temp_vec_v1))
);

iy += inc_y;
a_ptr += lda;
}
Expand Down

0 comments on commit ba2e989

Please sign in to comment.