diff --git a/src/cpu/gemm/gemm.cpp b/src/cpu/gemm/gemm.cpp
index f71eb451157..5ff06dbce8f 100644
--- a/src/cpu/gemm/gemm.cpp
+++ b/src/cpu/gemm/gemm.cpp
@@ -47,6 +47,11 @@ using namespace dnnl::impl::cpu::x64;
 using namespace dnnl::impl::cpu::ppc64;
 #elif DNNL_S390X
 #include "cpu/s390x/gemm.h"
+#elif DNNL_RV64
+#if DNNL_RISCV_USE_RVV_INTRINSICS
+#include "cpu/rv64/gemm/rvv_gemm_bf16.hpp"
+using namespace dnnl::impl::cpu::rv64;
+#endif
 #endif
 
 namespace dnnl {
@@ -305,6 +310,9 @@ dnnl_status_t gemm_bf16bf16f32(const char *transa, const char *transb,
                 *ldc);
         return dnnl_success;
 #endif
+#elif DNNL_RV64 && DNNL_RISCV_USE_RVV_INTRINSICS
+    return rvv_gemm_bf16bf16f32(
+            transa, transb, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
 #endif
 
     return ref_gemm_bf16bf16f32(
@@ -313,4 +321,4 @@ dnnl_status_t gemm_bf16bf16f32(const char *transa, const char *transb,
 
 } // namespace cpu
 } // namespace impl
-} // namespace dnnl
+} // namespace dnnl
\ No newline at end of file
diff --git a/src/cpu/platform.cpp b/src/cpu/platform.cpp
index e72c9a37642..1f4c99f2f69 100644
--- a/src/cpu/platform.cpp
+++ b/src/cpu/platform.cpp
@@ -119,6 +119,12 @@ bool has_data_type_support(data_type_t data_type) {
 #endif
 #elif DNNL_AARCH64
             return aarch64::mayiuse_bf16();
+#elif DNNL_RV64
+#if DNNL_RISCV_USE_RVV_INTRINSICS
+            return true;
+#else
+            return false;
+#endif
 #else
             return false;
 #endif
diff --git a/src/cpu/rv64/gemm/rvv_gemm_bf16.cpp b/src/cpu/rv64/gemm/rvv_gemm_bf16.cpp
new file mode 100644
index 00000000000..5fa846e5a2a
--- /dev/null
+++ b/src/cpu/rv64/gemm/rvv_gemm_bf16.cpp
@@ -0,0 +1,310 @@
+/*******************************************************************************
+* Copyright 2018-2025 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "oneapi/dnnl/dnnl_types.h"
+
+#include "common/dnnl_thread.hpp"
+#include "common/nstl.hpp"
+#include "common/utils.hpp"
+#include "common/bfloat16.hpp"
+
+#include "cpu/platform.hpp"
+
+#include "cpu/rv64/gemm/rvv_gemm_bf16.hpp"
+#include "cpu/rv64/gemm/rvv_gemm_utils_bf16.hpp"
+
+#include <riscv_vector.h>
+
+namespace dnnl {
+namespace impl {
+namespace cpu {
+namespace rv64 {
+
+using namespace dnnl::impl::utils;
+using namespace gemm_utils;
+
+namespace {
+
+// Widen a contiguous run of bf16 values to f32 (bf16 bits are the high 16
+// bits of the f32 pattern: zero-extend then shift left by 16).
+inline void rvv_cvt_bf16_to_f32_vector(float *out, const bfloat16_t *inp, size_t nelems) {
+    size_t i = 0;
+    while (i < nelems) {
+        size_t vl = __riscv_vsetvl_e16m1(nelems - i);
+
+        vuint16m1_t v_bf16 = __riscv_vle16_v_u16m1((const uint16_t*)(inp + i), vl);
+
+        vuint32m2_t v_f32_bits = __riscv_vzext_vf2_u32m2(v_bf16, vl);
+        v_f32_bits = __riscv_vsll_vx_u32m2(v_f32_bits, 16, vl);
+
+        vfloat32m2_t v_f32 = __riscv_vreinterpret_v_u32m2_f32m2(v_f32_bits);
+
+        __riscv_vse32_v_f32m2(out + i, v_f32, vl);
+
+        i += vl;
+    }
+}
+
+// Pack an m x K panel of A into the f32 workspace ws (column-major panels).
+void copy_A(bool isTransA, dim_t K, const bfloat16_t *A, const dim_t lda, float *ws) {
+    constexpr dim_t m = unroll_factor_bf16<bfloat16_t>::m;
+
+    for (dim_t k = 0; k < K; k++) {
+        if (isTransA) {
+            dim_t i = 0;
+            while (i < m) {
+                size_t vl = __riscv_vsetvl_e16m1(m - i);
+                ptrdiff_t stride = lda * sizeof(bfloat16_t);
+                const bfloat16_t *a_ptr = A + i * lda + k;
+
+                vuint16m1_t v_a_bf16 = __riscv_vlse16_v_u16m1(
+                        (const uint16_t*)a_ptr, stride, vl);
+
+                vuint32m2_t v_a_f32_bits = __riscv_vzext_vf2_u32m2(v_a_bf16, vl);
+                v_a_f32_bits = __riscv_vsll_vx_u32m2(v_a_f32_bits, 16, vl);
+                vfloat32m2_t v_a_f32 = __riscv_vreinterpret_v_u32m2_f32m2(v_a_f32_bits);
+
+                __riscv_vse32_v_f32m2(ws + i, v_a_f32, vl);
+                i += vl;
+            }
+        } else {
+            const bfloat16_t *a_ptr = A + k * lda;
+            rvv_cvt_bf16_to_f32_vector(ws, a_ptr, m);
+        }
+        ws += m;
+    }
+}
+
+template <bool isTransA, bool isTransB>
+void kernel_mxn(dim_t K, const bfloat16_t *A, const dim_t lda, const bfloat16_t *B,
+        const dim_t ldb, float *C, const dim_t ldc, const float alpha,
+        const float beta, int ithr = -1) {
+    constexpr dim_t m = unroll_factor_bf16<bfloat16_t>::m;
+    constexpr dim_t n = unroll_factor_bf16<bfloat16_t>::n;
+
+    float c[m * n] = {0.0f};
+
+    for (dim_t k = 0; k < K; k++) {
+        dim_t i = 0;
+        while (i < m) {
+            size_t vl = __riscv_vsetvl_e16m1(m - i);
+            vfloat32m2_t v_a;
+
+            if (isTransA) {
+                ptrdiff_t stride_a = lda * sizeof(bfloat16_t);
+                vuint16m1_t v_a_bf16 = __riscv_vlse16_v_u16m1(
+                        (const uint16_t*)(A + i * lda + k), stride_a, vl);
+
+                vuint32m2_t v_a_f32_bits = __riscv_vzext_vf2_u32m2(v_a_bf16, vl);
+                v_a_f32_bits = __riscv_vsll_vx_u32m2(v_a_f32_bits, 16, vl);
+                v_a = __riscv_vreinterpret_v_u32m2_f32m2(v_a_f32_bits);
+            } else {
+                vuint16m1_t v_a_bf16 = __riscv_vle16_v_u16m1(
+                        (const uint16_t*)(A + i + k * lda), vl);
+
+                vuint32m2_t v_a_f32_bits = __riscv_vzext_vf2_u32m2(v_a_bf16, vl);
+                v_a_f32_bits = __riscv_vsll_vx_u32m2(v_a_f32_bits, 16, vl);
+                v_a = __riscv_vreinterpret_v_u32m2_f32m2(v_a_f32_bits);
+            }
+
+            for (dim_t j = 0; j < n; j++) {
+                bfloat16_t b_bf16 = isTransB ? B[j + k * ldb] : B[k + j * ldb];
+                float b = static_cast<float>(b_bf16);
+
+                float *c_col_ptr = c + m * j + i;
+                vfloat32m2_t v_c = __riscv_vle32_v_f32m2(c_col_ptr, vl);
+
+                v_c = __riscv_vfmacc_vf_f32m2(v_c, b, v_a, vl);
+                __riscv_vse32_v_f32m2(c_col_ptr, v_c, vl);
+            }
+            i += vl;
+        }
+    }
+
+    for (dim_t j = 0; j < n; j++) {
+        dim_t i = 0;
+        while (i < m) {
+            size_t vl = __riscv_vsetvl_e32m2(m - i);
+
+            float *c_final_ptr = C + j * ldc + i;
+            float *c_acc_ptr = c + j * m + i;
+
+            vfloat32m2_t v_acc = __riscv_vle32_v_f32m2(c_acc_ptr, vl);
+            vfloat32m2_t v_res;
+
+            if (beta == 0.0f) {
+                v_res = __riscv_vfmul_vf_f32m2(v_acc, alpha, vl);
+            } else {
+                vfloat32m2_t v_c_old = __riscv_vle32_v_f32m2(c_final_ptr, vl);
+                v_res = __riscv_vfmul_vf_f32m2(v_c_old, beta, vl);
+                v_res = __riscv_vfmacc_vf_f32m2(v_res, alpha, v_acc, vl);
+            }
+
+            __riscv_vse32_v_f32m2(c_final_ptr, v_res, vl);
+            i += vl;
+        }
+    }
+}
+
+template <bool isTransA, bool isTransB>
+void block_ker(const dim_t M, const dim_t N, const dim_t K, const bfloat16_t *A,
+        const dim_t lda, const bfloat16_t *B, const dim_t ldb, float *C,
+        const dim_t ldc, const float alpha, const float beta, float *ws,
+        bool do_copy, int ithr = -1) {
+
+    constexpr dim_t m = unroll_factor_bf16<bfloat16_t>::m;
+    constexpr dim_t n = unroll_factor_bf16<bfloat16_t>::n;
+    dim_t Nu = (N / n) * n;
+    dim_t Mu = (M / m) * m;
+
+    for (dim_t i = 0; i < Mu; i += m) {
+        for (dim_t j = 0; j < Nu; j += n) {
+            const bfloat16_t *b = isTransB ? &B[j] : &B[j * ldb];
+            const bfloat16_t *a = isTransA ? &A[i * lda] : &A[i];
+            if (do_copy) {
+                if (j == 0) { copy_A(isTransA, K, a, lda, ws); }
+                kernel_mxn<isTransA, isTransB>(K, a, lda, b, ldb,
+                        &C[i + j * ldc], ldc, alpha, beta, ithr);
+            } else {
+                kernel_mxn<isTransA, isTransB>(K, a, lda, b, ldb,
+                        &C[i + j * ldc], ldc, alpha, beta, ithr);
+            }
+        }
+    }
+
+    for (dim_t i = 0; i < M; i++) {
+        for (dim_t j = Nu; j < N; j++) {
+            float c = beta == 0.f ? 0.f : beta * C[i + j * ldc];
+
+            for (dim_t p = 0; p < K; p++) {
+                float b = static_cast<float>(isTransB ? B[j + p * ldb] : B[p + j * ldb]);
+                float a = static_cast<float>(isTransA ? A[p + i * lda] : A[i + p * lda]);
+                c += alpha * a * b;
+            }
+            C[i + j * ldc] = c;
+        }
+    }
+    for (dim_t i = Mu; i < M; i++) {
+        for (dim_t j = 0; j < Nu; j++) {
+            float c = beta == 0.f ? 0.f : beta * C[i + j * ldc];
+
+            for (dim_t p = 0; p < K; p++) {
+                float b = static_cast<float>(isTransB ? B[j + p * ldb] : B[p + j * ldb]);
+                float a = static_cast<float>(isTransA ? A[p + i * lda] : A[i + p * lda]);
+                c += alpha * a * b;
+            }
+            C[i + j * ldc] = c;
+        }
+    }
+}
+
+template <bool isTransA, bool isTransB>
+void gemm_ithr(const dim_t M, const dim_t N, const dim_t K, const float alpha,
+        const bfloat16_t *A, const dim_t lda, const bfloat16_t *B, const dim_t ldb,
+        const float beta, float *C, const dim_t ldc, bool do_copy, float *ws,
+        int ithr = -1) {
+
+    constexpr dim_t BM = gemm_traits_t<bfloat16_t, isTransA, isTransB>::BM;
+    constexpr dim_t BN = gemm_traits_t<bfloat16_t, isTransA, isTransB>::BN;
+    constexpr dim_t BK = gemm_traits_t<bfloat16_t, isTransA, isTransB>::BK;
+
+    const bfloat16_t *curA;
+    const bfloat16_t *curB;
+    float *curC;
+
+    if ((M <= 0) || (N <= 0)) return;
+
+    if ((K <= 0) || (alpha == 0.f)) {
+        dim_t MN = N * M;
+        if (beta == 0.f) {
+            dim_t j = 0;
+            while (j < MN) {
+                size_t vl = __riscv_vsetvl_e32m1(MN - j);
+                vfloat32m1_t v_zero = __riscv_vfmv_v_f_f32m1(0.0f, vl);
+                __riscv_vse32_v_f32m1(C + j, v_zero, vl);
+                j += vl;
+            }
+        } else if (beta != 1.f) {
+            dim_t j = 0;
+            while (j < MN) {
+                size_t vl = __riscv_vsetvl_e32m1(MN - j);
+                vfloat32m1_t v_c = __riscv_vle32_v_f32m1(C + j, vl);
+                v_c = __riscv_vfmul_vf_f32m1(v_c, beta, vl);
+                __riscv_vse32_v_f32m1(C + j, v_c, vl);
+                j += vl;
+            }
+        }
+        return;
+    }
+
+    for (dim_t Bk = 0; Bk < K; Bk += BK) {
+        dim_t kb = nstl::min(K - Bk, BK);
+        for (dim_t Bm = 0; Bm < M; Bm += BM) {
+            dim_t mb = nstl::min(M - Bm, BM);
+            for (dim_t Bn = 0; Bn < N; Bn += BN) {
+                dim_t nb = nstl::min(N - Bn, BN);
+                curA = isTransA ? &A[Bk + Bm * lda] : &A[Bm + Bk * lda];
+                curB = isTransB ? &B[Bn + Bk * ldb] : &B[Bk + Bn * ldb];
+                curC = &C[Bm + Bn * ldc];
+
+                if (Bk == 0) {
+                    block_ker<isTransA, isTransB>(mb, nb, kb, curA, lda, curB,
+                            ldb, curC, ldc, alpha, beta, ws, do_copy, ithr);
+                } else {
+                    block_ker<isTransA, isTransB>(mb, nb, kb, curA, lda, curB,
+                            ldb, curC, ldc, alpha, 1.0f, ws, do_copy, ithr);
+                }
+            }
+        }
+    }
+}
+
+} // namespace
+
+dnnl_status_t rvv_gemm_bf16bf16f32(const char *transa_, const char *transb_,
+        const dim_t *M_, const dim_t *N_, const dim_t *K_, const float *alpha_,
+        const bfloat16_t *A, const dim_t *lda_, const bfloat16_t *B, const dim_t *ldb_,
+        const float *beta_, float *C, const dim_t *ldc_) {
+
+    if (!(utils::one_of(*transa_, 'n', 'N', 't', 'T')
+                && utils::one_of(*transb_, 'n', 'N', 't', 'T')))
+        return dnnl_unimplemented;
+
+    bool isTransA = (*transa_ == 'T' || *transa_ == 't');
+    bool isTransB = (*transb_ == 'T' || *transb_ == 't');
+    const dim_t M = *M_, N = *N_, K = *K_;
+    const dim_t lda = *lda_, ldb = *ldb_, ldc = *ldc_;
+    const float alpha = *alpha_, beta = *beta_;
+
+    if (utils::one_of(0, M, N)) return dnnl_success;
+
+    const bool do_copy = false;
+    float *ws = nullptr;
+
+    if (!isTransA && !isTransB) {
+        gemm_ithr<false, false>(M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, do_copy, ws);
+    } else if (!isTransA && isTransB) {
+        gemm_ithr<false, true>(M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, do_copy, ws);
+    } else if (isTransA && !isTransB) {
+        gemm_ithr<true, false>(M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, do_copy, ws);
+    } else {
+        gemm_ithr<true, true>(M, N, K, alpha, A, lda, B, ldb, beta, C, ldc, do_copy, ws);
+    }
+
+    return dnnl_success;
+}
+
+} // namespace rv64
+} // namespace cpu
+} // namespace impl
+} // namespace dnnl
\ No newline at end of file
diff --git a/src/cpu/rv64/gemm/rvv_gemm_bf16.hpp b/src/cpu/rv64/gemm/rvv_gemm_bf16.hpp
new file mode 100644
index 00000000000..be671958c43
--- /dev/null
+++ b/src/cpu/rv64/gemm/rvv_gemm_bf16.hpp
@@ -0,0 +1,39 @@
+/*******************************************************************************
+* Copyright 2018-2025 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_RV64_GEMM_RVV_GEMM_BF16_HPP
+#define CPU_RV64_GEMM_RVV_GEMM_BF16_HPP
+
+#include "oneapi/dnnl/dnnl_types.h"
+#include "common/c_types_map.hpp"
+#include "common/bfloat16.hpp"
+
+namespace dnnl {
+namespace impl {
+namespace cpu {
+namespace rv64 {
+
+dnnl_status_t rvv_gemm_bf16bf16f32(const char *transa, const char *transb,
+        const dim_t *M, const dim_t *N, const dim_t *K, const float *alpha,
+        const bfloat16_t *A, const dim_t *lda, const bfloat16_t *B, const dim_t *ldb,
+        const float *beta, float *C, const dim_t *ldc);
+
+} // namespace rv64
+} // namespace cpu
+} // namespace impl
+} // namespace dnnl
+
+#endif // CPU_RV64_GEMM_RVV_GEMM_BF16_HPP
\ No newline at end of file
diff --git a/src/cpu/rv64/gemm/rvv_gemm_utils_bf16.cpp b/src/cpu/rv64/gemm/rvv_gemm_utils_bf16.cpp
new file mode 100644
index 00000000000..ad63640e442
--- /dev/null
+++ b/src/cpu/rv64/gemm/rvv_gemm_utils_bf16.cpp
@@ -0,0 +1,97 @@
+/*******************************************************************************
+* Copyright 2018-2025 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include "cpu/rv64/gemm/rvv_gemm_utils_bf16.hpp"
+#include "common/nstl.hpp"
+
+namespace dnnl {
+namespace impl {
+namespace cpu {
+namespace rv64 {
+namespace gemm_utils {
+
+#define BM_NOCOPY_RVV_BF16 64
+#define BN_NOCOPY_RVV_BF16 48
+#define BK_NOCOPY_RVV_BF16 384
+#define BN_LARGE_NOCOPY_RVV_BF16 192
+#define BM_SMALL_NOCOPY_RVV_BF16 16
+#define BN_SMALL_NOCOPY_RVV_BF16 1
+#define BK_SMALL_NOCOPY_RVV_BF16 4
+
+// Threading calculation for bf16 GEMM, similar to f32 version
+void calc_nthr_nocopy_rvv_bf16(dim_t m, dim_t n, dim_t k, int nthrs, int *nthrs_m,
+        int *nthrs_n, int *nthrs_k, dim_t *BM, dim_t *BN, dim_t *BK) {
+
+    int nthr_m = 1, nthr_n = 1, nthr_k = 1;
+    dim_t bm = BM_NOCOPY_RVV_BF16, bn = BN_NOCOPY_RVV_BF16, bk = BK_NOCOPY_RVV_BF16;
+
+    if (nthrs <= 1) {
+        *nthrs_m = nthr_m;
+        *nthrs_n = nthr_n;
+        *nthrs_k = nthr_k;
+        *BM = bm;
+        *BN = bn;
+        *BK = bk;
+        return;
+    }
+
+    // For small problems, use smaller block sizes
+    if (m * n * k < 1000000) {
+        bm = BM_SMALL_NOCOPY_RVV_BF16;
+        bn = BN_SMALL_NOCOPY_RVV_BF16;
+        bk = BK_SMALL_NOCOPY_RVV_BF16;
+    }
+
+    // Simple heuristic for thread distribution
+    if (m >= n && m >= k) {
+        nthr_m = nthrs;
+    } else if (n >= k) {
+        nthr_n = nthrs;
+    } else {
+        nthr_k = nthrs;
+    }
+
+    // Adjust for large N dimension
+    if (n > 1000) {
+        bn = BN_LARGE_NOCOPY_RVV_BF16;
+    }
+
+    *nthrs_m = nthr_m;
+    *nthrs_n = nthr_n;
+    *nthrs_k = nthr_k;
+    *BM = bm;
+    *BN = bn;
+    *BK = bk;
+}
+
+void partition_unit_diff_bf16(
+        int ithr, int nthr, dim_t n, dim_t *t_offset, dim_t *t_block) {
+    dim_t band = n / nthr;
+    dim_t tail = n % nthr;
+    if (ithr < tail) {
+        band++;
+        *t_offset = band * ithr;
+    } else {
+        *t_offset = band * ithr + tail;
+    }
+    *t_block = band;
+}
+
+} // namespace gemm_utils
+} // namespace rv64
+} // namespace cpu
+} // namespace impl
+} // namespace dnnl
\ No newline at end of file
diff --git a/src/cpu/rv64/gemm/rvv_gemm_utils_bf16.hpp b/src/cpu/rv64/gemm/rvv_gemm_utils_bf16.hpp
new file mode 100644
index 00000000000..49a65d2bac1
--- /dev/null
+++ b/src/cpu/rv64/gemm/rvv_gemm_utils_bf16.hpp
@@ -0,0 +1,59 @@
+/*******************************************************************************
+* Copyright 2018-2025 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#ifndef CPU_RV64_GEMM_RVV_GEMM_UTILS_BF16_HPP
+#define CPU_RV64_GEMM_RVV_GEMM_UTILS_BF16_HPP
+
+#include <cstdint>
+#include "common/c_types_map.hpp"
+#include "common/bfloat16.hpp"
+
+namespace dnnl {
+namespace impl {
+namespace cpu {
+namespace rv64 {
+namespace gemm_utils {
+
+// GEMM traits for bfloat16, similar to f32 version
+template <typename T, bool isTransA, bool isTransB>
+struct gemm_traits_t {};
+
+template <bool isTransA, bool isTransB>
+struct gemm_traits_t<bfloat16_t, isTransA, isTransB> {
+    static constexpr dim_t m = 16;
+    static constexpr dim_t n = 4;
+    static constexpr dim_t BM = 4032;
+    static constexpr dim_t BN = isTransA ? 96 : 48;
+    static constexpr dim_t BK = isTransB ? 96 : 256;
+};
+
+template <typename T>
+using unroll_factor_bf16 = gemm_traits_t<T, false, false>;
+
+// Threading and blocking utilities for bf16
+void calc_nthr_nocopy_rvv_bf16(dim_t m, dim_t n, dim_t k, int nthrs, int *nthrs_m,
+        int *nthrs_n, int *nthrs_k, dim_t *BM, dim_t *BN, dim_t *BK);
+
+void partition_unit_diff_bf16(
+        int ithr, int nthr, dim_t n, dim_t *t_offset, dim_t *t_block);
+
+} // namespace gemm_utils
+} // namespace rv64
+} // namespace cpu
+} // namespace impl
+} // namespace dnnl
+
+#endif // CPU_RV64_GEMM_RVV_GEMM_UTILS_BF16_HPP
\ No newline at end of file