diff --git a/src/cpu/rv64/gemm/rvv_gemm_f32.cpp b/src/cpu/rv64/gemm/rvv_gemm_f32.cpp
index 4c342588100..e51de0285bc 100644
--- a/src/cpu/rv64/gemm/rvv_gemm_f32.cpp
+++ b/src/cpu/rv64/gemm/rvv_gemm_f32.cpp
@@ -34,11 +34,29 @@ namespace rv64 {
 using namespace dnnl::impl::utils;
 using namespace gemm_utils;
 
+using gemm_f32_traits = gemm_utils::gemm_utils_traits<float>;
+
 namespace {
+
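+// Epilogue store for one strip of C: C = alpha * acc + beta * C. The
+// beta == 0.0f branch never loads the old C values, so uninitialized
+// (possibly NaN) C memory cannot leak into the result.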
+#define STORE_C(C_PTR, V_C_REG, ALPHA, BETA, VL) \
+    do { \
+        float *c_final_ptr = (C_PTR); \
+        if ((BETA) == 0.0f) { \
+            vfloat32m1_t v_res \
+                    = __riscv_vfmul_vf_f32m1((V_C_REG), (ALPHA), (VL)); \
+            __riscv_vse32_v_f32m1(c_final_ptr, v_res, (VL)); \
+        } else { \
+            vfloat32m1_t v_c_old = __riscv_vle32_v_f32m1(c_final_ptr, (VL)); \
+            vfloat32m1_t v_res \
+                    = __riscv_vfmul_vf_f32m1(v_c_old, (BETA), (VL)); \
+            v_res = __riscv_vfmacc_vf_f32m1(v_res, (ALPHA), (V_C_REG), (VL)); \
+            __riscv_vse32_v_f32m1(c_final_ptr, v_res, (VL)); \
+        } \
+    } while (0)
+
 void copy_A(
         bool isTransA, dim_t K, const float *A, const dim_t lda, float *ws) {
-    constexpr dim_t m = unroll_factor<float>::m;
+    constexpr dim_t m = gemm_f32_traits::get_m_unroll_factor();
 
     for (dim_t k = 0; k < K; k++) {
         dim_t i = 0;
@@ -64,65 +82,306 @@ void copy_A(
     }
 }
 
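+// The n unroll factor is a template parameter: each specialization below is
+// fully unrolled over its n columns of B at compile time, and kernel_mxn()
+// (defined after the specializations) maps the runtime value of
+// gemm_f32_traits::get_n_unroll_factor() onto one of them.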
+template <bool isTransA, bool isTransB, dim_t n_unroll>
+struct kernel_mxn_impl {
+    static void execute(dim_t K, const float *A, dim_t lda, const float *B,
+            dim_t ldb, float *C, dim_t ldc, float alpha, float beta,
+            int ithr = -1);
+};
+
 template <bool isTransA, bool isTransB>
-void kernel_mxn(dim_t K, const float *A, const dim_t lda, const float *B,
-        const dim_t ldb, float *C, const dim_t ldc, const float alpha,
-        const float beta, int ithr = -1) {
-    constexpr dim_t m = unroll_factor<float>::m;
-    constexpr dim_t n = unroll_factor<float>::n;
+struct kernel_mxn_impl<isTransA, isTransB, 2> {
+    static void execute(dim_t K, const float *A, dim_t lda, const float *B,
+            dim_t ldb, float *C, dim_t ldc, float alpha, float beta,
+            int ithr = -1) {
+        constexpr dim_t m = gemm_f32_traits::get_m_unroll_factor();
+        constexpr dim_t n = 2;
+        MAYBE_UNUSED(ithr);
+        MAYBE_UNUSED(n);
 
-    static_assert(n == 4, "This kernel is specialized for n=4");
+        dim_t i = 0;
+        while (i < m) {
+            size_t vl = __riscv_vsetvl_e32m1(m - i);
 
-    dim_t i = 0;
-    while (i < m) {
-        size_t vl = __riscv_vsetvl_e32m1(m - i);
+            vfloat32m1_t v_c0 = __riscv_vfmv_v_f_f32m1(0.0f, vl);
+            vfloat32m1_t v_c1 = __riscv_vfmv_v_f_f32m1(0.0f, vl);
 
-        vfloat32m1_t v_c0, v_c1, v_c2, v_c3;
+            for (dim_t k = 0; k < K; ++k) {
+                vfloat32m1_t v_a;
+                if (isTransA) {
+                    ptrdiff_t stride_a = lda * sizeof(float);
+                    v_a = __riscv_vlse32_v_f32m1(A + i * lda + k, stride_a, vl);
+                } else {
+                    v_a = __riscv_vle32_v_f32m1(A + i + k * lda, vl);
+                }
 
-        v_c0 = __riscv_vfmv_v_f_f32m1(0.0f, vl);
-        v_c1 = __riscv_vfmv_v_f_f32m1(0.0f, vl);
-        v_c2 = __riscv_vfmv_v_f_f32m1(0.0f, vl);
-        v_c3 = __riscv_vfmv_v_f_f32m1(0.0f, vl);
+                const float *b_ptr = isTransB ? &B[k * ldb] : &B[k];
+                const dim_t b_stride = isTransB ? 1 : ldb;
 
-        for (dim_t k = 0; k < K; ++k) {
-            vfloat32m1_t v_a;
-            if (isTransA) {
-                ptrdiff_t stride_a = lda * sizeof(float);
-                v_a = __riscv_vlse32_v_f32m1(A + i * lda + k, stride_a, vl);
-            } else {
-                v_a = __riscv_vle32_v_f32m1(A + i + k * lda, vl);
+                v_c0 = __riscv_vfmacc_vf_f32m1(
+                        v_c0, b_ptr[0 * b_stride], v_a, vl);
+                v_c1 = __riscv_vfmacc_vf_f32m1(
+                        v_c1, b_ptr[1 * b_stride], v_a, vl);
             }
 
-            const float *b_ptr = isTransB ? &B[k * ldb] : &B[k];
-            const dim_t b_stride = isTransB ? 1 : ldb;
+            STORE_C(C + 0 * ldc + i, v_c0, alpha, beta, vl);
+            STORE_C(C + 1 * ldc + i, v_c1, alpha, beta, vl);
 
-            v_c0 = __riscv_vfmacc_vf_f32m1(v_c0, b_ptr[0 * b_stride], v_a, vl);
-            v_c1 = __riscv_vfmacc_vf_f32m1(v_c1, b_ptr[1 * b_stride], v_a, vl);
-            v_c2 = __riscv_vfmacc_vf_f32m1(v_c2, b_ptr[2 * b_stride], v_a, vl);
-            v_c3 = __riscv_vfmacc_vf_f32m1(v_c3, b_ptr[3 * b_stride], v_a, vl);
+            i += vl;
         }
+    }
+};
 
-#define STORE_C(J, V_C) \
-    do { \
-        float *c_final_ptr = C + (J)*ldc + i; \
-        if (beta == 0.0f) { \
-            vfloat32m1_t v_res = __riscv_vfmul_vf_f32m1(V_C, alpha, vl); \
-            __riscv_vse32_v_f32m1(c_final_ptr, v_res, vl); \
-        } else { \
-            vfloat32m1_t v_c_old = __riscv_vle32_v_f32m1(c_final_ptr, vl); \
-            vfloat32m1_t v_res = __riscv_vfmul_vf_f32m1(v_c_old, beta, vl); \
-            v_res = __riscv_vfmacc_vf_f32m1(v_res, alpha, V_C, vl); \
-            __riscv_vse32_v_f32m1(c_final_ptr, v_res, vl); \
-        } \
-    } while (0)
+template <bool isTransA, bool isTransB>
+struct kernel_mxn_impl<isTransA, isTransB, 4> {
+    static void execute(dim_t K, const float *A, dim_t lda, const float *B,
+            dim_t ldb, float *C, dim_t ldc, float alpha, float beta,
+            int ithr = -1) {
+        constexpr dim_t m = gemm_f32_traits::get_m_unroll_factor();
+        constexpr dim_t n = 4;
+        MAYBE_UNUSED(ithr);
+        MAYBE_UNUSED(n);
 
-    STORE_C(0, v_c0);
-    STORE_C(1, v_c1);
-    STORE_C(2, v_c2);
-    STORE_C(3, v_c3);
+        dim_t i = 0;
+        while (i < m) {
+            size_t vl = __riscv_vsetvl_e32m1(m - i);
 
-#undef STORE_C
-        i += vl;
+            vfloat32m1_t v_c0, v_c1, v_c2, v_c3;
+
+            v_c0 = __riscv_vfmv_v_f_f32m1(0.0f, vl);
+            v_c1 = __riscv_vfmv_v_f_f32m1(0.0f, vl);
+            v_c2 = __riscv_vfmv_v_f_f32m1(0.0f, vl);
+            v_c3 = __riscv_vfmv_v_f_f32m1(0.0f, vl);
+
+            for (dim_t k = 0; k < K; ++k) {
+                vfloat32m1_t v_a;
+                if (isTransA) {
+                    ptrdiff_t stride_a = lda * sizeof(float);
+                    v_a = __riscv_vlse32_v_f32m1(A + i * lda + k, stride_a, vl);
+                } else {
+                    v_a = __riscv_vle32_v_f32m1(A + i + k * lda, vl);
+                }
+
+                const float *b_ptr = isTransB ? &B[k * ldb] : &B[k];
+                const dim_t b_stride = isTransB ? 1 : ldb;
+
+                v_c0 = __riscv_vfmacc_vf_f32m1(
+                        v_c0, b_ptr[0 * b_stride], v_a, vl);
+                v_c1 = __riscv_vfmacc_vf_f32m1(
+                        v_c1, b_ptr[1 * b_stride], v_a, vl);
+                v_c2 = __riscv_vfmacc_vf_f32m1(
+                        v_c2, b_ptr[2 * b_stride], v_a, vl);
+                v_c3 = __riscv_vfmacc_vf_f32m1(
+                        v_c3, b_ptr[3 * b_stride], v_a, vl);
+            }
+
+            STORE_C(C + 0 * ldc + i, v_c0, alpha, beta, vl);
+            STORE_C(C + 1 * ldc + i, v_c1, alpha, beta, vl);
+            STORE_C(C + 2 * ldc + i, v_c2, alpha, beta, vl);
+            STORE_C(C + 3 * ldc + i, v_c3, alpha, beta, vl);
+
+            i += vl;
+        }
+    }
+};
+
+template <bool isTransA, bool isTransB>
+struct kernel_mxn_impl<isTransA, isTransB, 8> {
+    static void execute(dim_t K, const float *A, dim_t lda, const float *B,
+            dim_t ldb, float *C, dim_t ldc, float alpha, float beta,
+            int ithr = -1) {
+        constexpr dim_t m = gemm_f32_traits::get_m_unroll_factor();
+        constexpr dim_t n = 8;
+        MAYBE_UNUSED(ithr);
+        MAYBE_UNUSED(n);
+
+        dim_t i = 0;
+        while (i < m) {
+            size_t vl = __riscv_vsetvl_e32m1(m - i);
+
+            vfloat32m1_t v_c0, v_c1, v_c2, v_c3, v_c4, v_c5, v_c6, v_c7;
+
+            v_c0 = __riscv_vfmv_v_f_f32m1(0.0f, vl);
+            v_c1 = __riscv_vfmv_v_f_f32m1(0.0f, vl);
+            v_c2 = __riscv_vfmv_v_f_f32m1(0.0f, vl);
+            v_c3 = __riscv_vfmv_v_f_f32m1(0.0f, vl);
+            v_c4 = __riscv_vfmv_v_f_f32m1(0.0f, vl);
+            v_c5 = __riscv_vfmv_v_f_f32m1(0.0f, vl);
+            v_c6 = __riscv_vfmv_v_f_f32m1(0.0f, vl);
+            v_c7 = __riscv_vfmv_v_f_f32m1(0.0f, vl);
+
+            for (dim_t k = 0; k < K; ++k) {
+                vfloat32m1_t v_a;
+                if (isTransA) {
+                    ptrdiff_t stride_a = lda * sizeof(float);
+                    v_a = __riscv_vlse32_v_f32m1(A + i * lda + k, stride_a, vl);
+                } else {
+                    v_a = __riscv_vle32_v_f32m1(A + i + k * lda, vl);
+                }
+
+                const float *b_ptr = isTransB ? &B[k * ldb] : &B[k];
+                const dim_t b_stride = isTransB ? 1 : ldb;
+
+                v_c0 = __riscv_vfmacc_vf_f32m1(
+                        v_c0, b_ptr[0 * b_stride], v_a, vl);
+                v_c1 = __riscv_vfmacc_vf_f32m1(
+                        v_c1, b_ptr[1 * b_stride], v_a, vl);
+                v_c2 = __riscv_vfmacc_vf_f32m1(
+                        v_c2, b_ptr[2 * b_stride], v_a, vl);
+                v_c3 = __riscv_vfmacc_vf_f32m1(
+                        v_c3, b_ptr[3 * b_stride], v_a, vl);
+                v_c4 = __riscv_vfmacc_vf_f32m1(
+                        v_c4, b_ptr[4 * b_stride], v_a, vl);
+                v_c5 = __riscv_vfmacc_vf_f32m1(
+                        v_c5, b_ptr[5 * b_stride], v_a, vl);
+                v_c6 = __riscv_vfmacc_vf_f32m1(
+                        v_c6, b_ptr[6 * b_stride], v_a, vl);
+                v_c7 = __riscv_vfmacc_vf_f32m1(
+                        v_c7, b_ptr[7 * b_stride], v_a, vl);
+            }
+
+            STORE_C(C + 0 * ldc + i, v_c0, alpha, beta, vl);
+            STORE_C(C + 1 * ldc + i, v_c1, alpha, beta, vl);
+            STORE_C(C + 2 * ldc + i, v_c2, alpha, beta, vl);
+            STORE_C(C + 3 * ldc + i, v_c3, alpha, beta, vl);
+            STORE_C(C + 4 * ldc + i, v_c4, alpha, beta, vl);
+            STORE_C(C + 5 * ldc + i, v_c5, alpha, beta, vl);
+            STORE_C(C + 6 * ldc + i, v_c6, alpha, beta, vl);
+            STORE_C(C + 7 * ldc + i, v_c7, alpha, beta, vl);
+
+            i += vl;
+        }
+    }
+};
+
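+// Widest variant: with LMUL = 1 the 16 accumulators plus the A strip v_a
+// alone occupy 17 of the 32 RVV vector registers.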
+template <bool isTransA, bool isTransB>
+struct kernel_mxn_impl<isTransA, isTransB, 16> {
+    static void execute(dim_t K, const float *A, dim_t lda, const float *B,
+            dim_t ldb, float *C, dim_t ldc, float alpha, float beta,
+            int ithr = -1) {
+        constexpr dim_t m = gemm_f32_traits::get_m_unroll_factor();
+        constexpr dim_t n = 16;
+        MAYBE_UNUSED(ithr);
+        MAYBE_UNUSED(n);
+
+        dim_t i = 0;
+        while (i < m) {
+            size_t vl = __riscv_vsetvl_e32m1(m - i);
+
+            vfloat32m1_t v_c0, v_c1, v_c2, v_c3, v_c4, v_c5, v_c6, v_c7;
+            vfloat32m1_t v_c8, v_c9, v_c10, v_c11, v_c12, v_c13, v_c14, v_c15;
+
+            v_c0 = __riscv_vfmv_v_f_f32m1(0.0f, vl);
+            v_c1 = __riscv_vfmv_v_f_f32m1(0.0f, vl);
+            v_c2 = __riscv_vfmv_v_f_f32m1(0.0f, vl);
+            v_c3 = __riscv_vfmv_v_f_f32m1(0.0f, vl);
+            v_c4 = __riscv_vfmv_v_f_f32m1(0.0f, vl);
+            v_c5 = __riscv_vfmv_v_f_f32m1(0.0f, vl);
+            v_c6 = __riscv_vfmv_v_f_f32m1(0.0f, vl);
+            v_c7 = __riscv_vfmv_v_f_f32m1(0.0f, vl);
+            v_c8 = __riscv_vfmv_v_f_f32m1(0.0f, vl);
+            v_c9 = __riscv_vfmv_v_f_f32m1(0.0f, vl);
+            v_c10 = __riscv_vfmv_v_f_f32m1(0.0f, vl);
+            v_c11 = __riscv_vfmv_v_f_f32m1(0.0f, vl);
+            v_c12 = __riscv_vfmv_v_f_f32m1(0.0f, vl);
+            v_c13 = __riscv_vfmv_v_f_f32m1(0.0f, vl);
+            v_c14 = __riscv_vfmv_v_f_f32m1(0.0f, vl);
+            v_c15 = __riscv_vfmv_v_f_f32m1(0.0f, vl);
+
+            for (dim_t k = 0; k < K; ++k) {
+                vfloat32m1_t v_a;
+                if (isTransA) {
+                    ptrdiff_t stride_a = lda * sizeof(float);
+                    v_a = __riscv_vlse32_v_f32m1(A + i * lda + k, stride_a, vl);
+                } else {
+                    v_a = __riscv_vle32_v_f32m1(A + i + k * lda, vl);
+                }
+
+                const float *b_ptr = isTransB ? &B[k * ldb] : &B[k];
+                const dim_t b_stride = isTransB ? 1 : ldb;
+
+                v_c0 = __riscv_vfmacc_vf_f32m1(
+                        v_c0, b_ptr[0 * b_stride], v_a, vl);
+                v_c1 = __riscv_vfmacc_vf_f32m1(
+                        v_c1, b_ptr[1 * b_stride], v_a, vl);
+                v_c2 = __riscv_vfmacc_vf_f32m1(
+                        v_c2, b_ptr[2 * b_stride], v_a, vl);
+                v_c3 = __riscv_vfmacc_vf_f32m1(
+                        v_c3, b_ptr[3 * b_stride], v_a, vl);
+                v_c4 = __riscv_vfmacc_vf_f32m1(
+                        v_c4, b_ptr[4 * b_stride], v_a, vl);
+                v_c5 = __riscv_vfmacc_vf_f32m1(
+                        v_c5, b_ptr[5 * b_stride], v_a, vl);
+                v_c6 = __riscv_vfmacc_vf_f32m1(
+                        v_c6, b_ptr[6 * b_stride], v_a, vl);
+                v_c7 = __riscv_vfmacc_vf_f32m1(
+                        v_c7, b_ptr[7 * b_stride], v_a, vl);
+                v_c8 = __riscv_vfmacc_vf_f32m1(
+                        v_c8, b_ptr[8 * b_stride], v_a, vl);
+                v_c9 = __riscv_vfmacc_vf_f32m1(
+                        v_c9, b_ptr[9 * b_stride], v_a, vl);
+                v_c10 = __riscv_vfmacc_vf_f32m1(
+                        v_c10, b_ptr[10 * b_stride], v_a, vl);
+                v_c11 = __riscv_vfmacc_vf_f32m1(
+                        v_c11, b_ptr[11 * b_stride], v_a, vl);
+                v_c12 = __riscv_vfmacc_vf_f32m1(
+                        v_c12, b_ptr[12 * b_stride], v_a, vl);
+                v_c13 = __riscv_vfmacc_vf_f32m1(
+                        v_c13, b_ptr[13 * b_stride], v_a, vl);
+                v_c14 = __riscv_vfmacc_vf_f32m1(
+                        v_c14, b_ptr[14 * b_stride], v_a, vl);
+                v_c15 = __riscv_vfmacc_vf_f32m1(
+                        v_c15, b_ptr[15 * b_stride], v_a, vl);
+            }
+
+            STORE_C(C + 0 * ldc + i, v_c0, alpha, beta, vl);
+            STORE_C(C + 1 * ldc + i, v_c1, alpha, beta, vl);
+            STORE_C(C + 2 * ldc + i, v_c2, alpha, beta, vl);
+            STORE_C(C + 3 * ldc + i, v_c3, alpha, beta, vl);
+            STORE_C(C + 4 * ldc + i, v_c4, alpha, beta, vl);
+            STORE_C(C + 5 * ldc + i, v_c5, alpha, beta, vl);
+            STORE_C(C + 6 * ldc + i, v_c6, alpha, beta, vl);
+            STORE_C(C + 7 * ldc + i, v_c7, alpha, beta, vl);
+            STORE_C(C + 8 * ldc + i, v_c8, alpha, beta, vl);
+            STORE_C(C + 9 * ldc + i, v_c9, alpha, beta, vl);
+            STORE_C(C + 10 * ldc + i, v_c10, alpha, beta, vl);
+            STORE_C(C + 11 * ldc + i, v_c11, alpha, beta, vl);
+            STORE_C(C + 12 * ldc + i, v_c12, alpha, beta, vl);
+            STORE_C(C + 13 * ldc + i, v_c13, alpha, beta, vl);
+            STORE_C(C + 14 * ldc + i, v_c14, alpha, beta, vl);
+            STORE_C(C + 15 * ldc + i, v_c15, alpha, beta, vl);
+
+            i += vl;
+        }
+    }
+};
+
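+// Runtime dispatch: probe the unroll factor once per call (the underlying
+// cache query is cached after the first use) and forward to the matching
+// compile-time specialization. get_n_unroll_factor() only returns 2, 4, 8,
+// or 16, so the default case is unreachable in practice.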
+template <bool isTransA, bool isTransB>
+void kernel_mxn(dim_t K, const float *A, const dim_t lda, const float *B,
+        const dim_t ldb, float *C, const dim_t ldc, const float alpha,
+        const float beta, int ithr = -1) {
+    dim_t n_unroll = gemm_f32_traits::get_n_unroll_factor();
+
+    switch (n_unroll) {
+        case 2:
+            kernel_mxn_impl<isTransA, isTransB, 2>::execute(
+                    K, A, lda, B, ldb, C, ldc, alpha, beta, ithr);
+            break;
+        case 4:
+            kernel_mxn_impl<isTransA, isTransB, 4>::execute(
+                    K, A, lda, B, ldb, C, ldc, alpha, beta, ithr);
+            break;
+        case 8:
+            kernel_mxn_impl<isTransA, isTransB, 8>::execute(
+                    K, A, lda, B, ldb, C, ldc, alpha, beta, ithr);
+            break;
+        case 16:
+            kernel_mxn_impl<isTransA, isTransB, 16>::execute(
+                    K, A, lda, B, ldb, C, ldc, alpha, beta, ithr);
+            break;
+        default:
+            kernel_mxn_impl<isTransA, isTransB, 4>::execute(
+                    K, A, lda, B, ldb, C, ldc, alpha, beta, ithr);
     }
 }
 
@@ -132,16 +391,20 @@ void block_ker(const dim_t M, const dim_t N, const dim_t K, const float *A,
         const dim_t ldc, const float alpha, const float beta, float *ws,
         bool do_copy, int ithr = -1) {
-    dim_t Nu = rnd_dn(N, unroll_factor<float>::n);
-    dim_t Mu = rnd_dn(M, unroll_factor<float>::m);
-    for (dim_t i = 0; i < Mu; i += unroll_factor<float>::m) {
-        for (dim_t j = 0; j < Nu; j += unroll_factor<float>::n) {
+    dim_t n_unroll = gemm_f32_traits::get_n_unroll_factor();
+    dim_t m_unroll = gemm_f32_traits::get_m_unroll_factor();
+
+    dim_t Nu = rnd_dn(N, n_unroll);
+    dim_t Mu = rnd_dn(M, m_unroll);
+
+    for (dim_t i = 0; i < Mu; i += m_unroll) {
+        for (dim_t j = 0; j < Nu; j += n_unroll) {
             const float *b = isTransB ? &B[j] : &B[j * ldb];
             const float *a = isTransA ? &A[i * lda] : &A[i];
             if (do_copy) {
                 if (j == 0) { copy_A(isTransA, K, a, lda, ws); }
-                kernel_mxn<false, isTransB>(K, ws, unroll_factor<float>::m, b,
-                        ldb, &C[i + j * ldc], ldc, alpha, beta, ithr);
+                kernel_mxn<false, isTransB>(K, ws, m_unroll, b, ldb,
+                        &C[i + j * ldc], ldc, alpha, beta, ithr);
             } else {
                 kernel_mxn<isTransA, isTransB>(K, a, lda, b, ldb,
                         &C[i + j * ldc], ldc, alpha, beta, ithr);
@@ -237,11 +500,6 @@ status_t rvv_gemm_f32(const char *transa_, const char *transb_, const dim_t *M_,
     bool isTransA = (*transa_ == 'T' || *transa_ == 't');
     bool isTransB = (*transb_ == 'T' || *transb_ == 't');
 
-    if (isTransA && !isTransB) {
-        return ref_gemm<float>(transa_, transb_, M_, N_, K_, alpha_, A, lda_,
-                B, ldb_, beta_, C, ldc_, bias);
-    }
-
     const dim_t M = *M_, N = *N_, K = *K_;
     const dim_t lda = *lda_, ldb = *ldb_, ldc = *ldc_;
     const float alpha = *alpha_, beta = *beta_;
@@ -269,10 +527,10 @@ status_t rvv_gemm_f32(const char *transa_, const char *transb_, const dim_t *M_,
         }
     }
 
-    bool do_copy = (NB / unroll_factor<float>::n > 3);
+    bool do_copy = (NB / gemm_f32_traits::get_n_unroll_factor() > 3);
     const int nthr_mn = nthr_m * nthr_n;
     const int nthr_to_use = nthr_mn * nthr_k;
-    const size_t ws_elems_per_thr = K * unroll_factor<float>::m;
+    const size_t ws_elems_per_thr = K * gemm_f32_traits::get_m_unroll_factor();
     const size_t ws_size_per_thr
             = rnd_up(ws_elems_per_thr * sizeof(float), PAGE_4K);
     if (do_copy) {
@@ -366,6 +624,7 @@ status_t rvv_gemm_f32(const char *transa_, const char *transb_, const dim_t *M_,
 
             get_thr_block(m_from, m_to, myM, MB, M, ithr_m);
             dim_t offset = 0, block = 0;
+
             gemm_utils::partition_unit_diff(
                     ithr_k, nthr_k, myN, &offset, &block);
             for (int ik = 1; ik < nthr_k; ++ik) {
@@ -388,7 +647,9 @@ status_t rvv_gemm_f32(const char *transa_, const char *transb_, const dim_t *M_,
     return status::success;
 }
 
+#undef STORE_C
+
 } // namespace rv64
 } // namespace cpu
 } // namespace impl
-} // namespace dnnl
\ No newline at end of file
+} // namespace dnnl
diff --git a/src/cpu/rv64/gemm/rvv_gemm_utils_f32.hpp b/src/cpu/rv64/gemm/rvv_gemm_utils_f32.hpp
index 1a5f7d051d2..7a7ee52798b 100644
--- a/src/cpu/rv64/gemm/rvv_gemm_utils_f32.hpp
+++ b/src/cpu/rv64/gemm/rvv_gemm_utils_f32.hpp
@@ -20,31 +20,52 @@
 #include "common/c_types_map.hpp"
 
 #include
+#include <unistd.h>
 
 namespace dnnl {
 namespace impl {
 namespace cpu {
 namespace rv64 {
 namespace gemm_utils {
+
 template <typename T, bool isTransA, bool isTransB>
 struct gemm_traits_t {};
 
 template <bool isTransA, bool isTransB>
 struct gemm_traits_t<float, isTransA, isTransB> {
     static constexpr dim_t m = 8;
-    static constexpr dim_t n = 4;
     static constexpr dim_t BM = 4032;
     static constexpr dim_t BN = isTransA ? 96 : 48;
     static constexpr dim_t BK = isTransB ? 96 : 256;
 };
 
 template <typename T>
-struct unroll_factor {};
+struct gemm_utils_traits;
 
 template <>
-struct unroll_factor<float> {
-    static constexpr dim_t m = gemm_traits_t<float, false, false>::m;
-    static constexpr dim_t n = gemm_traits_t<float, false, false>::n;
+struct gemm_utils_traits<float> {
+    static constexpr dim_t get_m_unroll_factor() {
+        return gemm_traits_t<float, false, false>::m;
+    }
+
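+    // Choose the column unroll from the L1 data-cache size reported by the
+    // OS: >= 128 KiB -> 16, >= 64 KiB -> 8, >= 32 KiB -> 4, otherwise 2.
+    // A wider unroll keeps more strips of B and C live per pass over A,
+    // which presumably only pays off while they still fit in L1.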
+    static dim_t get_n_unroll_factor() {
+        long l1d_size = get_l1d_cache_size();
+        if (l1d_size >= 128 * 1024)
+            return 16;
+        else if (l1d_size >= 64 * 1024)
+            return 8;
+        else if (l1d_size >= 32 * 1024)
+            return 4;
+        else
+            return 2;
+    }
+
+private:
+    static long get_l1d_cache_size() {
+        // Query sysconf once and fall back to 32 KiB when it cannot report
+        // a size. Folding the fallback into the initializer keeps the lazy
+        // init free of post-initialization writes (no data race).
+        static const long l1d_size = []() {
+            long sz = sysconf(_SC_LEVEL1_DCACHE_SIZE);
+            return sz > 0 ? sz : 32 * 1024;
+        }();
+        return l1d_size;
+    }
 };
 
 // Sum the m*n values from p_src into p_dst, assuming the two-dimensional