diff --git a/Makefile.system b/Makefile.system index 9c42dc115b..436245c356 100644 --- a/Makefile.system +++ b/Makefile.system @@ -274,9 +274,18 @@ endif ifeq ($(ARCH), loongarch64) SMALL_MATRIX_OPT = 1 endif +ifeq ($(ARCH), arm64) +GEMM_GEMV_FORWARD = 1 +endif + ifeq ($(SMALL_MATRIX_OPT), 1) CCOMMON_OPT += -DSMALL_MATRIX_OPT endif +ifeq ($(GEMM_GEMV_FORWARD), 1) +ifneq ($(ONLY_CBLAS), 1) +CCOMMON_OPT += -DGEMM_GEMV_FORWARD +endif +endif # This operation is expensive, so execution should be once. ifndef GOTOBLAS_MAKEFILE diff --git a/cmake/system.cmake b/cmake/system.cmake index b682c3af81..efb7aef94b 100644 --- a/cmake/system.cmake +++ b/cmake/system.cmake @@ -391,6 +391,13 @@ endif () if (X86_64 OR ${CORE} STREQUAL POWER10) set(SMALL_MATRIX_OPT TRUE) endif () +if (ARM64) + set(GEMM_GEMV_FORWARD TRUE) +endif () + +if (GEMM_GEMV_FORWARD AND NOT ONLY_CBLAS) + set(CCOMMON_OPT "${CCOMMON_OPT} -DGEMM_GEMV_FORWARD") +endif () if (SMALL_MATRIX_OPT) set(CCOMMON_OPT "${CCOMMON_OPT} -DSMALL_MATRIX_OPT") endif () diff --git a/interface/gemm.c b/interface/gemm.c index 4537b6a78f..ac58cf27fa 100644 --- a/interface/gemm.c +++ b/interface/gemm.c @@ -1,4 +1,5 @@ /*********************************************************************/ +/* Copyright 2024 The OpenBLAS Project */ /* Copyright 2009, 2010 The University of Texas at Austin. */ /* All rights reserved. */ /* */ @@ -47,12 +48,16 @@ #define SMP_THRESHOLD_MIN 65536.0 #ifdef XDOUBLE #define ERROR_NAME "QGEMM " +#define GEMV BLASFUNC(qgemv) #elif defined(DOUBLE) #define ERROR_NAME "DGEMM " +#define GEMV BLASFUNC(dgemv) #elif defined(BFLOAT16) #define ERROR_NAME "SBGEMM " +#define GEMV BLASFUNC(sbgemv) #else #define ERROR_NAME "SGEMM " +#define GEMV BLASFUNC(sgemv) #endif #else #define SMP_THRESHOLD_MIN 8192.0 @@ -493,6 +498,52 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS args.m, args.n, args.k, args.lda, args.ldb, args.ldc); #endif +#if defined(GEMM_GEMV_FORWARD) && !defined(GEMM3M) && !defined(COMPLEX) + // Check if we can convert GEMM -> GEMV + if (args.k != 0) { + if (args.n == 1) { + blasint inc_x = 1; + blasint inc_y = 1; + // These were passed in as blasint, but the struct translates them to blaslong + blasint m = args.m; + blasint n = args.k; + blasint lda = args.lda; + // Create new transpose parameters + char NT = 'N'; + if (transa & 1) { + NT = 'T'; + m = args.k; + n = args.m; + } + if (transb & 1) { + inc_x = args.ldb; + } + GEMV(&NT, &m, &n, args.alpha, args.a, &lda, args.b, &inc_x, args.beta, args.c, &inc_y); + return; + } + if (args.m == 1) { + blasint inc_x = args.lda; + blasint inc_y = args.ldc; + // These were passed in as blasint, but the struct translates them to blaslong + blasint m = args.k; + blasint n = args.n; + blasint ldb = args.ldb; + // Create new transpose parameters + char NT = 'T'; + if (transa & 1) { + inc_x = 1; + } + if (transb & 1) { + NT = 'N'; + m = args.n; + n = args.k; + } + GEMV(&NT, &m, &n, args.alpha, args.b, &ldb, args.a, &inc_x, args.beta, args.c, &inc_y); + return; + } + } +#endif + IDEBUG_START; FUNCTION_PROFILE_START(); diff --git a/kernel/arm64/KERNEL.NEOVERSEV1 b/kernel/arm64/KERNEL.NEOVERSEV1 index bc59990979..53d157a0aa 100644 --- a/kernel/arm64/KERNEL.NEOVERSEV1 +++ b/kernel/arm64/KERNEL.NEOVERSEV1 @@ -1 +1,4 @@ include $(KERNELDIR)/KERNEL.ARMV8SVE + +SGEMVTKERNEL = gemv_t_sve.c +DGEMVTKERNEL = gemv_t_sve.c diff --git a/kernel/arm64/gemv_t.S b/kernel/arm64/gemv_t.S index b04367ab37..a98eef49b0 100644 --- a/kernel/arm64/gemv_t.S +++ b/kernel/arm64/gemv_t.S @@ -1,5 +1,5 @@ /******************************************************************************* -Copyright (c) 2015, The OpenBLAS Project +Copyright (c) 2015, 2024 The OpenBLAS Project All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -170,39 +170,48 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. .macro KERNEL_F32_FINALIZE #if !defined(DOUBLE) - fadd v1.4s, v1.4s, v2.4s + // F8 only has 2 accumulators + // so add into those pairs fadd v1.4s, v1.4s, v3.4s - fadd v1.4s, v1.4s, v4.4s -#else - fadd v1.2d, v1.2d, v2.2d - fadd v1.2d, v1.2d, v3.2d - fadd v1.2d, v1.2d, v4.2d + fadd v2.4s, v2.4s, v4.4s #endif .endm -.macro KERNEL_F4 +.macro KERNEL_F8 #if !defined(DOUBLE) - ld1 {v2.4s}, [A_PTR], #16 - ld1 {v3.4s}, [X_PTR], #16 - fmla v1.4s, v2.4s, v3.4s -#else - ld1 {v2.2d}, [A_PTR], #16 - ld1 {v3.2d}, [X_PTR], #16 - fmla v1.2d, v2.2d, v3.2d - - ld1 {v4.2d}, [A_PTR], #16 - ld1 {v5.2d}, [X_PTR], #16 - fmla v1.2d, v4.2d, v5.2d + ld1 {v13.4s, v14.4s}, [A_PTR], #32 + ld1 {v17.4s, v18.4s}, [X_PTR], #32 + fmla v1.4s, v13.4s, v17.4s + fmla v2.4s, v14.4s, v18.4s +#else + ld1 {v13.2d, v14.2d, v15.2d, v16.2d}, [A_PTR], #64 + ld1 {v17.2d, v18.2d, v19.2d, v20.2d}, [X_PTR], #64 + fmla v1.2d, v13.2d, v17.2d + fmla v2.2d, v14.2d, v18.2d + fmla v3.2d, v15.2d, v19.2d + fmla v4.2d, v16.2d, v20.2d #endif .endm -.macro KERNEL_F4_FINALIZE +.macro KERNEL_F8_FINALIZE #if !defined(DOUBLE) - ext v2.16b, v1.16b, v1.16b, #8 + // Take the top two elements of v1 and + // put them into the first two lanes of v3 + ext v3.16b, v1.16b, v1.16b, #8 + fadd v1.2s, v1.2s, v3.2s + ext v4.16b, v2.16b, v2.16b, #8 + fadd v2.2s, v2.2s, v4.2s + // Final pair fadd v1.2s, v1.2s, v2.2s faddp TEMP, v1.2s #else faddp TEMP, v1.2d + faddp TEMP1, v2.2d + faddp TEMP2, v3.2d + faddp TEMP3, v4.2d + fadd TEMP, TEMP, TEMP1 + fadd TEMP2, TEMP2, TEMP3 + fadd TEMP, TEMP, TEMP2 #endif .endm @@ -258,7 +267,7 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. asr I, M, #5 cmp I, xzr - beq .Lgemv_t_kernel_F4 + beq .Lgemv_t_kernel_F8 .Lgemv_t_kernel_F320: @@ -269,24 +278,24 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. KERNEL_F32_FINALIZE -.Lgemv_t_kernel_F4: +.Lgemv_t_kernel_F8: ands I, M, #31 - asr I, I, #2 + asr I, I, #3 cmp I, xzr beq .Lgemv_t_kernel_F1 -.Lgemv_t_kernel_F40: +.Lgemv_t_kernel_F80: - KERNEL_F4 + KERNEL_F8 subs I, I, #1 - bne .Lgemv_t_kernel_F40 + bne .Lgemv_t_kernel_F80 .Lgemv_t_kernel_F1: - KERNEL_F4_FINALIZE + KERNEL_F8_FINALIZE - ands I, M, #3 + ands I, M, #7 ble .Lgemv_t_kernel_F_END .Lgemv_t_kernel_F10: diff --git a/kernel/arm64/gemv_t_sve.c b/kernel/arm64/gemv_t_sve.c index ab700a3748..183d9c3d16 100644 --- a/kernel/arm64/gemv_t_sve.c +++ b/kernel/arm64/gemv_t_sve.c @@ -59,20 +59,46 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a, BLASLO a_ptr = a; if (inc_x == 1) { + svbool_t pg_true = SV_TRUE(); uint64_t sve_size = SV_COUNT(); + uint64_t sve_size2 = sve_size * 2; + BLASLONG m1 = m & -sve_size; + BLASLONG m2 = m & -sve_size2; + for (j = 0; j < n; j++) { + BLASLONG i = 0; + + SV_TYPE temp_vec_v2_0 = SV_DUP(0.0); + SV_TYPE temp_vec_v2_1 = SV_DUP(0.0); + for (; i < m2; i += sve_size2) { + SV_TYPE a_vec0 = svld1(pg_true, a_ptr + i); + SV_TYPE x_vec0 = svld1(pg_true, x + i); + SV_TYPE a_vec1 = svld1(pg_true, a_ptr + i + sve_size); + SV_TYPE x_vec1 = svld1(pg_true, x + i + sve_size); + temp_vec_v2_0 = svmla_m(pg_true, temp_vec_v2_0, a_vec0, x_vec0); + temp_vec_v2_1 = svmla_m(pg_true, temp_vec_v2_1, a_vec1, x_vec1); + } + + SV_TYPE temp_vec_v1 = SV_DUP(0.0); + for (; i < m1; i += sve_size) { + SV_TYPE a_vec0 = svld1(pg_true, a_ptr + i); + SV_TYPE x_vec0 = svld1(pg_true, x + i); + temp_vec_v1 = svmla_m(pg_true, temp_vec_v1, a_vec0, x_vec0); + } + SV_TYPE temp_vec = SV_DUP(0.0); - i = 0; - svbool_t pg = SV_WHILE(i, m); - while (svptest_any(SV_TRUE(), pg)) { + for (; i < m; i += sve_size) { + svbool_t pg = SV_WHILE(i, m); SV_TYPE a_vec = svld1(pg, a_ptr + i); SV_TYPE x_vec = svld1(pg, x + i); temp_vec = svmla_m(pg, temp_vec, a_vec, x_vec); - i += sve_size; - pg = SV_WHILE(i, m); } - temp = svaddv(SV_TRUE(), temp_vec); - y[iy] += alpha * temp; + + y[iy] += alpha * ( + (svaddv(SV_TRUE(), temp_vec_v2_0) + svaddv(SV_TRUE(), temp_vec)) + + (svaddv(SV_TRUE(), temp_vec_v2_1) + svaddv(SV_TRUE(), temp_vec_v1)) + ); + iy += inc_y; a_ptr += lda; }