Forward GEMM to GEMV when one argument is actually a vector #4814

Merged (5 commits) on Jul 31, 2024
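As background for the diff below, here is a minimal sketch of the equivalence the PR exploits; the sizes, values, and use of the CBLAS interface are illustrative and not taken from the PR. A GEMM call whose n is 1 multiplies a matrix by a single column, which is exactly what GEMV computes, so the GEMM front end can hand the work to the GEMV kernel.

```c
/* Illustration only: a GEMM call with n == 1 computes the same result as the
 * corresponding GEMV call, which is what the new path in interface/gemm.c
 * forwards to.  Link against OpenBLAS (or any CBLAS implementation). */
#include <stdio.h>
#include <cblas.h>

int main(void) {
    /* Column-major: A is 3x2, B is a 2x1 "matrix" (a vector), C is 3x1. */
    double A[6] = {1, 2, 3, 4, 5, 6};        /* lda = 3 */
    double B[2] = {10, 20};
    double C_gemm[3] = {0}, C_gemv[3] = {0};

    /* GEMM with n == 1 ... */
    cblas_dgemm(CblasColMajor, CblasNoTrans, CblasNoTrans,
                3 /*m*/, 1 /*n*/, 2 /*k*/,
                1.0, A, 3, B, 2, 0.0, C_gemm, 3);

    /* ... is equivalent to GEMV over the same data. */
    cblas_dgemv(CblasColMajor, CblasNoTrans,
                3 /*m*/, 2 /*k above*/,
                1.0, A, 3, B, 1, 0.0, C_gemv, 1);

    for (int i = 0; i < 3; i++)
        printf("%g %g\n", C_gemm[i], C_gemv[i]);  /* rows should match: 90 120 150 */
    return 0;
}
```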
9 changes: 9 additions & 0 deletions Makefile.system
@@ -274,9 +274,18 @@ endif
ifeq ($(ARCH), loongarch64)
SMALL_MATRIX_OPT = 1
endif
ifeq ($(ARCH), arm64)
GEMM_GEMV_FORWARD = 1
endif

ifeq ($(SMALL_MATRIX_OPT), 1)
CCOMMON_OPT += -DSMALL_MATRIX_OPT
endif
ifeq ($(GEMM_GEMV_FORWARD), 1)
ifneq ($(ONLY_CBLAS), 1)
CCOMMON_OPT += -DGEMM_GEMV_FORWARD
endif
endif

# This operation is expensive, so execution should be once.
ifndef GOTOBLAS_MAKEFILE
7 changes: 7 additions & 0 deletions cmake/system.cmake
@@ -391,6 +391,13 @@ endif ()
if (X86_64 OR ${CORE} STREQUAL POWER10)
set(SMALL_MATRIX_OPT TRUE)
endif ()
if (ARM64)
set(GEMM_GEMV_FORWARD TRUE)
endif ()

if (GEMM_GEMV_FORWARD AND NOT ONLY_CBLAS)
set(CCOMMON_OPT "${CCOMMON_OPT} -DGEMM_GEMV_FORWARD")
endif ()
if (SMALL_MATRIX_OPT)
set(CCOMMON_OPT "${CCOMMON_OPT} -DSMALL_MATRIX_OPT")
endif ()
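Net effect of the two build-system hunks: GEMM_GEMV_FORWARD defaults to on for arm64 targets, and the -DGEMM_GEMV_FORWARD compile definition is emitted only when ONLY_CBLAS is not set; that definition is what gates the new forwarding branch in interface/gemm.c below.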
51 changes: 51 additions & 0 deletions interface/gemm.c
@@ -1,4 +1,5 @@
/*********************************************************************/
/* Copyright 2024 The OpenBLAS Project */
/* Copyright 2009, 2010 The University of Texas at Austin. */
/* All rights reserved. */
/* */
@@ -47,12 +48,16 @@
#define SMP_THRESHOLD_MIN 65536.0
#ifdef XDOUBLE
#define ERROR_NAME "QGEMM "
#define GEMV BLASFUNC(qgemv)
#elif defined(DOUBLE)
#define ERROR_NAME "DGEMM "
#define GEMV BLASFUNC(dgemv)
#elif defined(BFLOAT16)
#define ERROR_NAME "SBGEMM "
#define GEMV BLASFUNC(sbgemv)
#else
#define ERROR_NAME "SGEMM "
#define GEMV BLASFUNC(sgemv)
#endif
#else
#define SMP_THRESHOLD_MIN 8192.0
@@ -493,6 +498,52 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
args.m, args.n, args.k, args.lda, args.ldb, args.ldc);
#endif

#if defined(GEMM_GEMV_FORWARD) && !defined(GEMM3M) && !defined(COMPLEX)
// Check if we can convert GEMM -> GEMV
if (args.k != 0) {
if (args.n == 1) {
blasint inc_x = 1;
blasint inc_y = 1;
// These were passed in as blasint, but the struct translates them to blaslong
blasint m = args.m;
blasint n = args.k;
blasint lda = args.lda;
// Create new transpose parameters
char NT = 'N';
if (transa & 1) {
NT = 'T';
m = args.k;
n = args.m;
}
if (transb & 1) {
inc_x = args.ldb;
}
GEMV(&NT, &m, &n, args.alpha, args.a, &lda, args.b, &inc_x, args.beta, args.c, &inc_y);
return;
}
if (args.m == 1) {
blasint inc_x = args.lda;
blasint inc_y = args.ldc;
// These were passed in as blasint, but the struct translates them to blaslong
blasint m = args.k;
blasint n = args.n;
blasint ldb = args.ldb;
// Create new transpose parameters
char NT = 'T';
if (transa & 1) {
inc_x = 1;
}
if (transb & 1) {
NT = 'N';
m = args.n;
n = args.k;
}
GEMV(&NT, &m, &n, args.alpha, args.b, &ldb, args.a, &inc_x, args.beta, args.c, &inc_y);
return;
}
}
#endif

IDEBUG_START;

FUNCTION_PROFILE_START();
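The m == 1 branch above is the less obvious mapping: in column-major storage C is a 1 x n row vector whose elements sit ldc apart, and (for non-transposed A) the 1 x k row vector A has its elements lda apart, so the operation becomes a transposed GEMV over B with inc_x = lda and inc_y = ldc. A hedged illustration through the CBLAS interface (the sizes, values, and use of CBLAS are mine, not from the PR):

```c
/* Illustration only: cblas_dgemm with m == 1 matches a transposed cblas_dgemv
 * over B, with A read at stride lda and C written at stride ldc -- the same
 * strides the new branch passes as inc_x / inc_y. */
#include <stdio.h>
#include <cblas.h>

int main(void) {
    /* Column-major, m = 1, k = 2, n = 3. */
    double A[4] = {7, -1, 9, -1};        /* 1x2 row vector, lda = 2 (padding = -1) */
    double B[6] = {1, 2, 3, 4, 5, 6};    /* 2x3 matrix, ldb = 2                    */
    double C1[6] = {0}, C2[6] = {0};     /* 1x3 row vectors, ldc = 2               */

    cblas_dgemm(CblasColMajor, CblasNoTrans, CblasNoTrans,
                1, 3, 2, 1.0, A, 2, B, 2, 0.0, C1, 2);

    /* Equivalent GEMV: y = B^T * x, with x = A at stride lda and y = C at stride ldc. */
    cblas_dgemv(CblasColMajor, CblasTrans, 2, 3,
                1.0, B, 2, A, 2, 0.0, C2, 2);

    for (int j = 0; j < 3; j++)
        printf("%g %g\n", C1[2 * j], C2[2 * j]);  /* both columns: 25, 57, 89 */
    return 0;
}
```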
3 changes: 3 additions & 0 deletions kernel/arm64/KERNEL.NEOVERSEV1
@@ -1 +1,4 @@
include $(KERNELDIR)/KERNEL.ARMV8SVE

SGEMVTKERNEL = gemv_t_sve.c
DGEMVTKERNEL = gemv_t_sve.c
67 changes: 38 additions & 29 deletions kernel/arm64/gemv_t.S
@@ -1,5 +1,5 @@
/*******************************************************************************
Copyright (c) 2015, The OpenBLAS Project
Copyright (c) 2015, 2024 The OpenBLAS Project
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
@@ -170,39 +170,48 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

.macro KERNEL_F32_FINALIZE
#if !defined(DOUBLE)
fadd v1.4s, v1.4s, v2.4s
// F8 only has 2 accumulators
// so add into those pairs
fadd v1.4s, v1.4s, v3.4s
fadd v1.4s, v1.4s, v4.4s
#else
fadd v1.2d, v1.2d, v2.2d
fadd v1.2d, v1.2d, v3.2d
fadd v1.2d, v1.2d, v4.2d
fadd v2.4s, v2.4s, v4.4s
#endif
.endm

.macro KERNEL_F4
.macro KERNEL_F8
#if !defined(DOUBLE)
ld1 {v2.4s}, [A_PTR], #16
ld1 {v3.4s}, [X_PTR], #16
fmla v1.4s, v2.4s, v3.4s
#else
ld1 {v2.2d}, [A_PTR], #16
ld1 {v3.2d}, [X_PTR], #16
fmla v1.2d, v2.2d, v3.2d

ld1 {v4.2d}, [A_PTR], #16
ld1 {v5.2d}, [X_PTR], #16
fmla v1.2d, v4.2d, v5.2d
ld1 {v13.4s, v14.4s}, [A_PTR], #32
ld1 {v17.4s, v18.4s}, [X_PTR], #32
fmla v1.4s, v13.4s, v17.4s
fmla v2.4s, v14.4s, v18.4s
#else
ld1 {v13.2d, v14.2d, v15.2d, v16.2d}, [A_PTR], #64
ld1 {v17.2d, v18.2d, v19.2d, v20.2d}, [X_PTR], #64
fmla v1.2d, v13.2d, v17.2d
fmla v2.2d, v14.2d, v18.2d
fmla v3.2d, v15.2d, v19.2d
fmla v4.2d, v16.2d, v20.2d
#endif
.endm

.macro KERNEL_F4_FINALIZE
.macro KERNEL_F8_FINALIZE
#if !defined(DOUBLE)
ext v2.16b, v1.16b, v1.16b, #8
// Take the top two elements of v1 and
// put them into the first two lanes of v3
ext v3.16b, v1.16b, v1.16b, #8
fadd v1.2s, v1.2s, v3.2s
ext v4.16b, v2.16b, v2.16b, #8
fadd v2.2s, v2.2s, v4.2s
// Final pair
fadd v1.2s, v1.2s, v2.2s
faddp TEMP, v1.2s
#else
faddp TEMP, v1.2d
faddp TEMP1, v2.2d
faddp TEMP2, v3.2d
faddp TEMP3, v4.2d
fadd TEMP, TEMP, TEMP1
fadd TEMP2, TEMP2, TEMP3
fadd TEMP, TEMP, TEMP2
#endif
.endm

@@ -258,7 +267,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

asr I, M, #5
cmp I, xzr
beq .Lgemv_t_kernel_F4
beq .Lgemv_t_kernel_F8

.Lgemv_t_kernel_F320:

@@ -269,24 +278,24 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

KERNEL_F32_FINALIZE

.Lgemv_t_kernel_F4:
.Lgemv_t_kernel_F8:
ands I, M, #31
asr I, I, #2
asr I, I, #3
cmp I, xzr
beq .Lgemv_t_kernel_F1

.Lgemv_t_kernel_F40:
.Lgemv_t_kernel_F80:

KERNEL_F4
KERNEL_F8

subs I, I, #1
bne .Lgemv_t_kernel_F40
bne .Lgemv_t_kernel_F80

.Lgemv_t_kernel_F1:

KERNEL_F4_FINALIZE
KERNEL_F8_FINALIZE

ands I, M, #3
ands I, M, #7
ble .Lgemv_t_kernel_F_END

.Lgemv_t_kernel_F10:
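The assembly change above widens the non-SVE fallback loop from 4 to 8 elements per iteration (KERNEL_F4 becomes KERNEL_F8, and the tail masks go from #3 to #7) and keeps several independent accumulators (v1/v2 for float, v1..v4 for double) that are only combined in the FINALIZE macros. A rough scalar C sketch of the same idea, written by me for illustration; the real kernel uses NEON vectors and fmla:

```c
/* Sketch of the accumulator pattern behind KERNEL_F8 / KERNEL_F8_FINALIZE:
 * independent partial sums break the serial dependency between successive
 * multiply-adds, and they are combined only once at the end. */
double dot_unrolled8(const double *a, const double *x, long n) {
    double s0 = 0.0, s1 = 0.0, s2 = 0.0, s3 = 0.0;
    long i = 0;
    for (; i + 8 <= n; i += 8) {          /* unroll by 8, like the F8 loop */
        s0 += a[i + 0] * x[i + 0];
        s0 += a[i + 1] * x[i + 1];
        s1 += a[i + 2] * x[i + 2];
        s1 += a[i + 3] * x[i + 3];
        s2 += a[i + 4] * x[i + 4];
        s2 += a[i + 5] * x[i + 5];
        s3 += a[i + 6] * x[i + 6];
        s3 += a[i + 7] * x[i + 7];
    }
    double s = (s0 + s1) + (s2 + s3);     /* the FINALIZE step             */
    for (; i < n; i++)                    /* scalar remainder (M & 7)      */
        s += a[i] * x[i];
    return s;
}
```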
40 changes: 33 additions & 7 deletions kernel/arm64/gemv_t_sve.c
@@ -59,20 +59,46 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLO
a_ptr = a;

if (inc_x == 1) {
svbool_t pg_true = SV_TRUE();
uint64_t sve_size = SV_COUNT();
uint64_t sve_size2 = sve_size * 2;
BLASLONG m1 = m & -sve_size;
BLASLONG m2 = m & -sve_size2;

for (j = 0; j < n; j++) {
BLASLONG i = 0;

SV_TYPE temp_vec_v2_0 = SV_DUP(0.0);
SV_TYPE temp_vec_v2_1 = SV_DUP(0.0);
for (; i < m2; i += sve_size2) {
SV_TYPE a_vec0 = svld1(pg_true, a_ptr + i);
SV_TYPE x_vec0 = svld1(pg_true, x + i);
SV_TYPE a_vec1 = svld1(pg_true, a_ptr + i + sve_size);
SV_TYPE x_vec1 = svld1(pg_true, x + i + sve_size);
temp_vec_v2_0 = svmla_m(pg_true, temp_vec_v2_0, a_vec0, x_vec0);
temp_vec_v2_1 = svmla_m(pg_true, temp_vec_v2_1, a_vec1, x_vec1);
}

SV_TYPE temp_vec_v1 = SV_DUP(0.0);
for (; i < m1; i += sve_size) {
SV_TYPE a_vec0 = svld1(pg_true, a_ptr + i);
SV_TYPE x_vec0 = svld1(pg_true, x + i);
temp_vec_v1 = svmla_m(pg_true, temp_vec_v1, a_vec0, x_vec0);
}

SV_TYPE temp_vec = SV_DUP(0.0);
i = 0;
svbool_t pg = SV_WHILE(i, m);
while (svptest_any(SV_TRUE(), pg)) {
for (; i < m; i += sve_size) {
svbool_t pg = SV_WHILE(i, m);
SV_TYPE a_vec = svld1(pg, a_ptr + i);
SV_TYPE x_vec = svld1(pg, x + i);
temp_vec = svmla_m(pg, temp_vec, a_vec, x_vec);
i += sve_size;
pg = SV_WHILE(i, m);
}
temp = svaddv(SV_TRUE(), temp_vec);
y[iy] += alpha * temp;

y[iy] += alpha * (
(svaddv(SV_TRUE(), temp_vec_v2_0) + svaddv(SV_TRUE(), temp_vec)) +
(svaddv(SV_TRUE(), temp_vec_v2_1) + svaddv(SV_TRUE(), temp_vec_v1))
);

iy += inc_y;
a_ptr += lda;
}
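Reading aid for the restructured SVE kernel: the hot loop now consumes two SVE vectors per iteration into two independent accumulators (temp_vec_v2_0/1), a second loop handles a single remaining full vector (temp_vec_v1), and a final predicated loop covers the leftover partial vector (temp_vec); the four horizontal reductions are combined only once when updating y[iy].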