From 237068c4d25a5e1582df586c7826d7345415f9cf Mon Sep 17 00:00:00 2001 From: krishnasai-mcw Date: Tue, 19 Aug 2025 10:12:16 +0530 Subject: [PATCH 1/4] cpu: riscv: matmul: add RVV row/col kernels with bias, ReLU post-op --- src/cpu/matmul/cpu_matmul_list.cpp | 7 + src/cpu/rv64/rvv_matmul.cpp | 271 +++++++++++++++++++++++++++++ src/cpu/rv64/rvv_matmul.hpp | 148 ++++++++++++++++ src/cpu/rv64/rvv_postops.hpp | 44 +++++ 4 files changed, 470 insertions(+) create mode 100644 src/cpu/rv64/rvv_matmul.cpp create mode 100644 src/cpu/rv64/rvv_matmul.hpp create mode 100644 src/cpu/rv64/rvv_postops.hpp diff --git a/src/cpu/matmul/cpu_matmul_list.cpp b/src/cpu/matmul/cpu_matmul_list.cpp index 45ec811db0f..52fbc16a20b 100644 --- a/src/cpu/matmul/cpu_matmul_list.cpp +++ b/src/cpu/matmul/cpu_matmul_list.cpp @@ -41,6 +41,12 @@ using namespace dnnl::impl::cpu::x64; #endif using namespace dnnl::impl::cpu::aarch64::matmul; using namespace dnnl::impl::cpu::aarch64; +#elif DNNL_RV64 +#if DNNL_RISCV_USE_RVV_INTRINSICS +#include "cpu/rv64/rvv_matmul.hpp" +using namespace dnnl::impl::cpu::rv64::matmul; +using namespace dnnl::impl::cpu::rv64; +#endif #endif @@ -71,6 +77,7 @@ constexpr impl_list_item_t impl_list[] = REG_MATMUL_P({ CPU_INSTANCE_AVX512(brgemm_matmul_t) CPU_INSTANCE_AVX2(brgemm_matmul_t) CPU_INSTANCE_AVX2(brgemm_matmul_t) + CPU_INSTANCE_RV64GCV(rvv_matmul_t) CPU_INSTANCE(gemm_f32_matmul_t) CPU_INSTANCE(gemm_bf16_matmul_t) CPU_INSTANCE(gemm_bf16_matmul_t) diff --git a/src/cpu/rv64/rvv_matmul.cpp b/src/cpu/rv64/rvv_matmul.cpp new file mode 100644 index 00000000000..13ecb5505f1 --- /dev/null +++ b/src/cpu/rv64/rvv_matmul.cpp @@ -0,0 +1,271 @@ +#include "rvv_matmul.hpp" +#include "common/dnnl_thread.hpp" +#include "cpu/cpu_primitive.hpp" +#include "cpu/matmul/matmul_utils.hpp" +#include "cpu/rv64/rvv_postops.hpp" +#include + +namespace dnnl { +namespace impl { +namespace cpu { +namespace rv64 { +namespace matmul { + +void rvv_matmul_colmajor(const float *src, const float *weights, float *dst, + const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, const float *bias, + const memory_desc_wrapper &bias_d, + const rvv_postops_t &postops_handler) { + + const int ndims = src_d.ndims(); + const dim_t *src_dims = src_d.dims(); + const dim_t *wei_dims = weights_d.dims(); + const int weights_ndims = weights_d.ndims(); + + dim_t batch = 1; + for (int i = 0; i < ndims - 2; ++i) + batch *= src_dims[i]; + + const dim_t M = src_dims[ndims - 2]; + const dim_t K = src_dims[ndims - 1]; + const dim_t N = wei_dims[weights_ndims - 1]; + + dim_t weights_batch_size = 1; + for (int i = 0; i < weights_ndims - 2; ++i) + weights_batch_size *= wei_dims[i]; + const bool weights_are_broadcasted = (weights_batch_size == 1 && batch > 1); + + parallel_nd(batch, M, [&](dim_t b, dim_t m) { + std::vector dst_idx_prefix(ndims - 1); + if (ndims > 2) { + utils::l_dims_by_l_offset( + dst_idx_prefix.data(), b, src_dims, ndims - 2); + } + dst_idx_prefix[ndims - 2] = m; + + size_t weights_batch_offset = 0; + if (!weights_are_broadcasted) { + for (int i = 0; i < weights_ndims - 2; ++i) { + if (wei_dims[i] != 1) { + dim_t b_idx = dst_idx_prefix[i + (ndims - weights_ndims)]; + weights_batch_offset + += b_idx * weights_d.blocking_desc().strides[i]; + } + } + } + + const float *src_base_ptr = src + (size_t)b * M * K + (size_t)m * K; + float *dst_base_ptr = dst + (size_t)b * M * N + (size_t)m * N; + const float *weights_base_ptr = weights + weights_batch_offset; + + for (dim_t n0 = 0; n0 < N;) { + size_t vl = __riscv_vsetvl_e32m1(N - n0); + std::vector out_vals(vl, 0.0f); + + for (dim_t k0 = 0; k0 < K;) { + size_t k_vl = __riscv_vsetvl_e32m1(K - k0); + + vfloat32m1_t src_vec + = __riscv_vle32_v_f32m1(src_base_ptr + k0, k_vl); + + for (size_t ni = 0; ni < vl; ++ni) { + const float *weight_col_ptr + = weights_base_ptr + (size_t)(n0 + ni) * (size_t)K; + vfloat32m1_t wei_vec + = __riscv_vle32_v_f32m1(weight_col_ptr + k0, k_vl); + + vfloat32m1_t prod + = __riscv_vfmul_vv_f32m1(src_vec, wei_vec, k_vl); + vfloat32m1_t reduced = __riscv_vfredusum_vs_f32m1_f32m1( + prod, __riscv_vfmv_v_f_f32m1(0.0f, k_vl), k_vl); + float partial = __riscv_vfmv_f_s_f32m1_f32(reduced); + + out_vals[ni] += partial; + } + + k0 += k_vl; + } + + vfloat32m1_t acc = __riscv_vle32_v_f32m1(out_vals.data(), vl); + + if (bias) { + if (bias_d.nelems() == 1) { + acc = __riscv_vfadd_vf_f32m1(acc, bias[0], vl); + } else { + const int dst_ndims = dst_d.ndims(); + const int bias_ndims = bias_d.ndims(); + const dim_t *bias_dims = bias_d.dims(); + + std::vector bias_strides(bias_ndims); + bias_strides[bias_ndims - 1] = 1; + for (int d = bias_ndims - 2; d >= 0; --d) + bias_strides[d] = bias_strides[d + 1] + * (size_t)bias_dims[d + 1]; + + size_t base_bias_off = 0; + for (int d = 0; d < bias_ndims - 1; ++d) { + int dst_dim_idx = d + (dst_ndims - bias_ndims); + dim_t idx = (bias_dims[d] == 1) + ? 0 + : dst_idx_prefix[dst_dim_idx]; + base_bias_off += idx * bias_strides[d]; + } + + if (bias_dims[bias_ndims - 1] == 1) { + acc = __riscv_vfadd_vf_f32m1( + acc, bias[base_bias_off], vl); + } else { + const float *bias_ptr = bias + base_bias_off + n0; + vfloat32m1_t bias_vec + = __riscv_vle32_v_f32m1(bias_ptr, vl); + acc = __riscv_vfadd_vv_f32m1(acc, bias_vec, vl); + } + } + } + + acc = postops_handler.apply(acc, vl); + __riscv_vse32_v_f32m1(&dst_base_ptr[n0], acc, vl); + n0 += vl; + } + }); +} + +void rvv_matmul_rowmajor(const float *src, const float *weights, float *dst, + const memory_desc_wrapper &src_d, const memory_desc_wrapper &weights_d, + const memory_desc_wrapper &dst_d, const float *bias, + const memory_desc_wrapper &bias_d, + const rvv_postops_t &postops_handler) { + + const int ndims = src_d.ndims(); + const dim_t *src_dims = src_d.dims(); + const dim_t *wei_dims = weights_d.dims(); + const int weights_ndims = weights_d.ndims(); + + dim_t batch = 1; + for (int i = 0; i < ndims - 2; ++i) + batch *= src_dims[i]; + + const dim_t M = src_dims[ndims - 2]; + const dim_t K = src_dims[ndims - 1]; + const dim_t N = wei_dims[weights_ndims - 1]; + + dim_t weights_batch_size = 1; + for (int i = 0; i < weights_ndims - 2; ++i) + weights_batch_size *= wei_dims[i]; + const bool weights_are_broadcasted = (weights_batch_size == 1 && batch > 1); + + parallel_nd(batch, M, [&](dim_t b, dim_t m) { + std::vector dst_idx_prefix(ndims - 1); + if (ndims > 2) { + utils::l_dims_by_l_offset( + dst_idx_prefix.data(), b, src_dims, ndims - 2); + } + dst_idx_prefix[ndims - 2] = m; + + size_t weights_batch_offset = 0; + if (!weights_are_broadcasted) { + for (int i = 0; i < weights_ndims - 2; ++i) { + if (wei_dims[i] != 1) { + dim_t b_idx = dst_idx_prefix[i + (ndims - weights_ndims)]; + weights_batch_offset + += b_idx * weights_d.blocking_desc().strides[i]; + } + } + } + + const float *src_base_ptr = src + (size_t)b * M * K + (size_t)m * K; + float *dst_base_ptr = dst + (size_t)b * M * N + (size_t)m * N; + const float *weights_base_ptr = weights + weights_batch_offset; + + for (dim_t n0 = 0; n0 < N;) { + size_t vl = __riscv_vsetvl_e32m1(N - n0); + vfloat32m1_t acc = __riscv_vfmv_v_f_f32m1(0.0f, vl); + + for (dim_t k = 0; k < K; ++k) { + vfloat32m1_t a_vec + = __riscv_vfmv_v_f_f32m1(src_base_ptr[k], vl); + const float *b_ptr = weights_base_ptr + (size_t)k * N + n0; + vfloat32m1_t b_vec = __riscv_vle32_v_f32m1(b_ptr, vl); + acc = __riscv_vfmacc_vv_f32m1(acc, a_vec, b_vec, vl); + } + + if (bias) { + if (bias_d.nelems() == 1) { + acc = __riscv_vfadd_vf_f32m1(acc, bias[0], vl); + } else { + const int dst_ndims = dst_d.ndims(); + const int bias_ndims = bias_d.ndims(); + const dim_t *bias_dims = bias_d.dims(); + + std::vector bias_strides(bias_ndims); + bias_strides[bias_ndims - 1] = 1; + for (int d = bias_ndims - 2; d >= 0; --d) + bias_strides[d] = bias_strides[d + 1] + * (size_t)bias_dims[d + 1]; + + size_t base_bias_off = 0; + for (int d = 0; d < bias_ndims - 1; ++d) { + int dst_dim_idx = d + (dst_ndims - bias_ndims); + dim_t idx = (bias_dims[d] == 1) + ? 0 + : dst_idx_prefix[dst_dim_idx]; + base_bias_off += idx * bias_strides[d]; + } + + if (bias_dims[bias_ndims - 1] == 1) { + acc = __riscv_vfadd_vf_f32m1( + acc, bias[base_bias_off], vl); + } else { + const float *bias_ptr = bias + base_bias_off + n0; + vfloat32m1_t bias_vec + = __riscv_vle32_v_f32m1(bias_ptr, vl); + acc = __riscv_vfadd_vv_f32m1(acc, bias_vec, vl); + } + } + } + + acc = postops_handler.apply(acc, vl); + __riscv_vse32_v_f32m1(&dst_base_ptr[n0], acc, vl); + n0 += vl; + } + }); +} + +template +rvv_matmul_t::rvv_matmul_t(const pd_t *apd) : primitive_t(apd) {} + +template <> +status_t rvv_matmul_t::execute(const exec_ctx_t &ctx) const { + auto src = CTX_IN_MEM(const float *, DNNL_ARG_SRC); + auto weights = CTX_IN_MEM(const float *, DNNL_ARG_WEIGHTS); + auto dst = CTX_OUT_MEM(float *, DNNL_ARG_DST); + + const memory_desc_wrapper src_d(pd()->src_md()); + const memory_desc_wrapper weights_d(pd()->weights_md()); + const memory_desc_wrapper dst_d(pd()->dst_md()); + const memory_desc_wrapper bias_d(pd()->desc()->bias_desc); + + const post_ops_t &post_ops = pd()->attr()->post_ops_; + rvv_postops_t postops_handler(post_ops); + + const float *bias = nullptr; + if (!bias_d.is_zero()) bias = CTX_IN_MEM(const float *, DNNL_ARG_BIAS); + + if (pd()->is_col_major(weights_d)) { + rvv_matmul_colmajor(src, weights, dst, src_d, weights_d, dst_d, bias, + bias_d, postops_handler); + } else { + rvv_matmul_rowmajor(src, weights, dst, src_d, weights_d, dst_d, bias, + bias_d, postops_handler); + } + + return status::success; +} + +template struct rvv_matmul_t; + +} // namespace matmul +} // namespace rv64 +} // namespace cpu +} // namespace impl +} // namespace dnnl \ No newline at end of file diff --git a/src/cpu/rv64/rvv_matmul.hpp b/src/cpu/rv64/rvv_matmul.hpp new file mode 100644 index 00000000000..ec5e3cad096 --- /dev/null +++ b/src/cpu/rv64/rvv_matmul.hpp @@ -0,0 +1,148 @@ +#ifndef CPU_RV64_RVV_MATMUL_HPP +#define CPU_RV64_RVV_MATMUL_HPP + +#include "common/c_types_map.hpp" +#include "common/primitive.hpp" +#include "common/type_helpers.hpp" +#include "cpu/matmul/cpu_matmul_pd.hpp" +#include "cpu/rv64/rvv_postops.hpp" + +namespace dnnl { +namespace impl { +namespace cpu { +namespace rv64 { +namespace matmul { + +template +struct rvv_matmul_t : public primitive_t { + struct pd_t : public ::dnnl::impl::cpu::matmul::cpu_matmul_pd_t { + using ::dnnl::impl::cpu::matmul::cpu_matmul_pd_t::cpu_matmul_pd_t; + + DECLARE_COMMON_PD_T("RISCV64GCV", rvv_matmul_t) + + status_t init(engine_t *engine) { + UNUSED(engine); + + const memory_desc_wrapper src_mdw(src_md(0)); + const memory_desc_wrapper weights_mdw(weights_md(0)); + const memory_desc_wrapper dst_mdw(dst_md(0)); + const memory_desc_wrapper bias_mdw = bias_md_; + + if (has_zero_dim_memory() || src_mdw.has_runtime_dims_or_strides() + || weights_mdw.has_runtime_dims_or_strides() + || dst_mdw.has_runtime_dims_or_strides() + || bias_mdw.has_runtime_dims_or_strides()) + return status::unimplemented; + + const bool types_ok = src_mdw.data_type() == d_type + && weights_mdw.data_type() == d_type + && dst_mdw.data_type() == d_type + && desc()->accum_data_type == d_type; + if (!types_ok) return status::unimplemented; + + if (!attr()->scales_.has_default_values()) + return status::unimplemented; + + if (attr()->post_ops_.len() != 0) { + rvv_postops_t po_handler(attr()->post_ops_); + if (!po_handler.has_postops()) return status::unimplemented; + } + + if (!set_default_formats()) return status::unimplemented; + + if (!check_layouts(src_mdw, weights_mdw, dst_mdw)) + return status::unimplemented; + + const auto wei_ndims = weights_mdw.ndims(); + for (int i = 0; i < wei_ndims - 2; ++i) { + if (src_mdw.dims()[i] != weights_mdw.dims()[i] + && weights_mdw.dims()[i] != 1) + return status::unimplemented; + } + + if (!check_bias(dst_mdw, bias_mdw)) return status::unimplemented; + + if (!set_default_formats()) return status::unimplemented; + + return status::success; + } + + bool is_row_major(const memory_desc_wrapper &mdw) const { + const int ndims = mdw.ndims(); + if (ndims < 2) return false; + + const auto &strides = mdw.blocking_desc().strides; + if (strides[ndims - 1] != 1) return false; + + dim_t expected_stride = mdw.dims()[ndims - 1]; + for (int d = ndims - 2; d >= 0; --d) { + if (strides[d] != expected_stride) return false; + expected_stride *= mdw.dims()[d]; + } + return true; + } + + bool is_col_major(const memory_desc_wrapper &mdw) const { + const int ndims = mdw.ndims(); + if (ndims < 2) return false; + + const auto &strides = mdw.blocking_desc().strides; + const auto &dims = mdw.dims(); + + if (strides[ndims - 2] != 1) return false; + if (strides[ndims - 1] != dims[ndims - 2]) return false; + + dim_t expected_stride = dims[ndims - 2] * dims[ndims - 1]; + for (int d = ndims - 3; d >= 0; --d) { + if (strides[d] != expected_stride) return false; + expected_stride *= dims[d]; + } + return true; + } + + bool check_layouts(const memory_desc_wrapper &src_mdw, + const memory_desc_wrapper &wei_mdw, + const memory_desc_wrapper &dst_mdw) const { + if (!is_row_major(src_mdw) || !is_row_major(dst_mdw)) return false; + + if (!is_row_major(wei_mdw) && !is_col_major(wei_mdw)) return false; + + return true; + } + + bool check_bias(const memory_desc_wrapper &dst_mdw, + const memory_desc_wrapper &bias_mdw) const { + if (bias_mdw.is_zero()) return true; + + if (bias_mdw.data_type() != d_type) return false; + + const int dst_ndims = dst_mdw.ndims(); + const int bias_ndims = bias_mdw.ndims(); + if (bias_ndims > dst_ndims) return false; + + const auto *dst_dims = dst_mdw.dims(); + const auto *bias_dims = bias_mdw.dims(); + + for (int d = 1; d <= bias_ndims; ++d) { + const dim_t bias_dim = bias_dims[bias_ndims - d]; + const dim_t dst_dim = dst_dims[dst_ndims - d]; + if (bias_dim != 1 && bias_dim != dst_dim) return false; + } + return true; + } + }; + + rvv_matmul_t(const pd_t *apd); + status_t execute(const exec_ctx_t &ctx) const; + +private: + const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } +}; + +} // namespace matmul +} // namespace rv64 +} // namespace cpu +} // namespace impl +} // namespace dnnl + +#endif // CPU_RV64_RVV_MATMUL_HPP \ No newline at end of file diff --git a/src/cpu/rv64/rvv_postops.hpp b/src/cpu/rv64/rvv_postops.hpp new file mode 100644 index 00000000000..7cd386d4c59 --- /dev/null +++ b/src/cpu/rv64/rvv_postops.hpp @@ -0,0 +1,44 @@ +#ifndef CPU_RV64_RVV_POSTOPS_HPP +#define CPU_RV64_RVV_POSTOPS_HPP + +#include "common/primitive_attr.hpp" +#include "common/utils.hpp" +#include + +namespace dnnl { +namespace impl { +namespace cpu { +namespace rv64 { + +struct rvv_postops_t { + rvv_postops_t(const post_ops_t &po) { + is_supported_ = false; + + if (po.len() == 1 && po.entry_[0].is_eltwise()) { + const auto &e = po.entry_[0]; + if (e.eltwise.alg == alg_kind::eltwise_relu) { + is_supported_ = true; + } + } + } + + inline bool has_postops() const { return is_supported_; } + + inline vfloat32m1_t apply(vfloat32m1_t v, size_t vl) const { + if (is_supported_) { + vfloat32m1_t zero = __riscv_vfmv_v_f_f32m1(0.f, vl); + return __riscv_vfmax_vv_f32m1(v, zero, vl); + } + return v; + } + +private: + bool is_supported_; +}; + +} // namespace rv64 +} // namespace cpu +} // namespace impl +} // namespace dnnl + +#endif // CPU_RV64_RVV_POSTOPS_HPP From cd5c4329569322eb528b95aa338c53b6db12fa5f Mon Sep 17 00:00:00 2001 From: krishnasai-mcw Date: Tue, 19 Aug 2025 12:47:59 +0530 Subject: [PATCH 2/4] cpu: riscv: matmul: fix post-ops handler and rvv matmul fixes --- src/cpu/matmul/cpu_matmul_list.cpp | 4 +- src/cpu/rv64/rvv_matmul.cpp | 14 ++---- src/cpu/rv64/rvv_matmul.hpp | 68 ++++++++++++++++-------------- src/cpu/rv64/rvv_postops.hpp | 37 +++++++++------- 4 files changed, 65 insertions(+), 58 deletions(-) diff --git a/src/cpu/matmul/cpu_matmul_list.cpp b/src/cpu/matmul/cpu_matmul_list.cpp index 52fbc16a20b..e61b2cf84df 100644 --- a/src/cpu/matmul/cpu_matmul_list.cpp +++ b/src/cpu/matmul/cpu_matmul_list.cpp @@ -42,7 +42,7 @@ using namespace dnnl::impl::cpu::x64; using namespace dnnl::impl::cpu::aarch64::matmul; using namespace dnnl::impl::cpu::aarch64; #elif DNNL_RV64 -#if DNNL_RISCV_USE_RVV_INTRINSICS +#ifdef DNNL_RISCV_USE_RVV_INTRINSICS #include "cpu/rv64/rvv_matmul.hpp" using namespace dnnl::impl::cpu::rv64::matmul; using namespace dnnl::impl::cpu::rv64; @@ -77,7 +77,7 @@ constexpr impl_list_item_t impl_list[] = REG_MATMUL_P({ CPU_INSTANCE_AVX512(brgemm_matmul_t) CPU_INSTANCE_AVX2(brgemm_matmul_t) CPU_INSTANCE_AVX2(brgemm_matmul_t) - CPU_INSTANCE_RV64GCV(rvv_matmul_t) + CPU_INSTANCE_RV64GCV(rvv_matmul_t) CPU_INSTANCE(gemm_f32_matmul_t) CPU_INSTANCE(gemm_bf16_matmul_t) CPU_INSTANCE(gemm_bf16_matmul_t) diff --git a/src/cpu/rv64/rvv_matmul.cpp b/src/cpu/rv64/rvv_matmul.cpp index 13ecb5505f1..eba90906f35 100644 --- a/src/cpu/rv64/rvv_matmul.cpp +++ b/src/cpu/rv64/rvv_matmul.cpp @@ -1,4 +1,4 @@ -#include "rvv_matmul.hpp" +#include "cpu/rv64/rvv_matmul.hpp" #include "common/dnnl_thread.hpp" #include "cpu/cpu_primitive.hpp" #include "cpu/matmul/matmul_utils.hpp" @@ -231,11 +231,9 @@ void rvv_matmul_rowmajor(const float *src, const float *weights, float *dst, }); } -template -rvv_matmul_t::rvv_matmul_t(const pd_t *apd) : primitive_t(apd) {} +rvv_matmul_t::rvv_matmul_t(const pd_t *apd) : primitive_t(apd) {} -template <> -status_t rvv_matmul_t::execute(const exec_ctx_t &ctx) const { +status_t rvv_matmul_t::execute(const exec_ctx_t &ctx) const { auto src = CTX_IN_MEM(const float *, DNNL_ARG_SRC); auto weights = CTX_IN_MEM(const float *, DNNL_ARG_WEIGHTS); auto dst = CTX_OUT_MEM(float *, DNNL_ARG_DST); @@ -248,9 +246,7 @@ status_t rvv_matmul_t::execute(const exec_ctx_t &ctx) const { const post_ops_t &post_ops = pd()->attr()->post_ops_; rvv_postops_t postops_handler(post_ops); - const float *bias = nullptr; - if (!bias_d.is_zero()) bias = CTX_IN_MEM(const float *, DNNL_ARG_BIAS); - + const float *bias = CTX_IN_MEM(const float *, DNNL_ARG_BIAS); if (pd()->is_col_major(weights_d)) { rvv_matmul_colmajor(src, weights, dst, src_d, weights_d, dst_d, bias, bias_d, postops_handler); @@ -262,8 +258,6 @@ status_t rvv_matmul_t::execute(const exec_ctx_t &ctx) const { return status::success; } -template struct rvv_matmul_t; - } // namespace matmul } // namespace rv64 } // namespace cpu diff --git a/src/cpu/rv64/rvv_matmul.hpp b/src/cpu/rv64/rvv_matmul.hpp index ec5e3cad096..2ffd17d078f 100644 --- a/src/cpu/rv64/rvv_matmul.hpp +++ b/src/cpu/rv64/rvv_matmul.hpp @@ -13,13 +13,14 @@ namespace cpu { namespace rv64 { namespace matmul { -template struct rvv_matmul_t : public primitive_t { struct pd_t : public ::dnnl::impl::cpu::matmul::cpu_matmul_pd_t { using ::dnnl::impl::cpu::matmul::cpu_matmul_pd_t::cpu_matmul_pd_t; DECLARE_COMMON_PD_T("RISCV64GCV", rvv_matmul_t) + static constexpr data_type_t d_type = data_type::f32; + status_t init(engine_t *engine) { UNUSED(engine); @@ -28,41 +29,48 @@ struct rvv_matmul_t : public primitive_t { const memory_desc_wrapper dst_mdw(dst_md(0)); const memory_desc_wrapper bias_mdw = bias_md_; - if (has_zero_dim_memory() || src_mdw.has_runtime_dims_or_strides() - || weights_mdw.has_runtime_dims_or_strides() - || dst_mdw.has_runtime_dims_or_strides() - || bias_mdw.has_runtime_dims_or_strides()) - return status::unimplemented; + VDISPATCH_MATMUL(!has_zero_dim_memory(), VERBOSE_EMPTY_TENSOR, ""); + + VDISPATCH_MATMUL(!src_mdw.has_runtime_dims_or_strides() + && !weights_mdw.has_runtime_dims_or_strides() + && !dst_mdw.has_runtime_dims_or_strides() + && !bias_mdw.has_runtime_dims_or_strides(), + VERBOSE_UNSUPPORTED_TAG); const bool types_ok = src_mdw.data_type() == d_type && weights_mdw.data_type() == d_type && dst_mdw.data_type() == d_type && desc()->accum_data_type == d_type; - if (!types_ok) return status::unimplemented; - - if (!attr()->scales_.has_default_values()) - return status::unimplemented; - - if (attr()->post_ops_.len() != 0) { - rvv_postops_t po_handler(attr()->post_ops_); - if (!po_handler.has_postops()) return status::unimplemented; - } - - if (!set_default_formats()) return status::unimplemented; - - if (!check_layouts(src_mdw, weights_mdw, dst_mdw)) - return status::unimplemented; - - const auto wei_ndims = weights_mdw.ndims(); - for (int i = 0; i < wei_ndims - 2; ++i) { - if (src_mdw.dims()[i] != weights_mdw.dims()[i] - && weights_mdw.dims()[i] != 1) - return status::unimplemented; + VDISPATCH_MATMUL(types_ok, VERBOSE_UNSUPPORTED_DT); + + VDISPATCH_MATMUL(attr()->scales_.has_default_values(), + VERBOSE_UNSUPPORTED_SCALES_CFG); + + VDISPATCH_MATMUL(rvv_postops_t::post_ops_ok(attr()->post_ops_), + VERBOSE_UNSUPPORTED_POSTOP); + + VDISPATCH_MATMUL(set_default_formats(), VERBOSE_UNSUPPORTED_TAG); + + VDISPATCH_MATMUL(check_layouts(src_mdw, weights_mdw, dst_mdw), + VERBOSE_UNSUPPORTED_TAG); + + { + const auto wei_ndims = weights_mdw.ndims(); + bool bc_ok = true; + for (int i = 0; i < wei_ndims - 2; ++i) { + if (src_mdw.dims()[i] != weights_mdw.dims()[i] + && weights_mdw.dims()[i] != 1) { + bc_ok = false; + break; + } + } + VDISPATCH_MATMUL(bc_ok, VERBOSE_UNSUPPORTED_TAG); } - if (!check_bias(dst_mdw, bias_mdw)) return status::unimplemented; + VDISPATCH_MATMUL(check_bias(dst_mdw, bias_mdw), + VERBOSE_UNSUPPORTED_BIAS_CFG); - if (!set_default_formats()) return status::unimplemented; + VDISPATCH_MATMUL(set_default_formats(), VERBOSE_UNSUPPORTED_TAG); return status::success; } @@ -104,9 +112,7 @@ struct rvv_matmul_t : public primitive_t { const memory_desc_wrapper &wei_mdw, const memory_desc_wrapper &dst_mdw) const { if (!is_row_major(src_mdw) || !is_row_major(dst_mdw)) return false; - if (!is_row_major(wei_mdw) && !is_col_major(wei_mdw)) return false; - return true; } @@ -145,4 +151,4 @@ struct rvv_matmul_t : public primitive_t { } // namespace impl } // namespace dnnl -#endif // CPU_RV64_RVV_MATMUL_HPP \ No newline at end of file +#endif // CPU_RV64_RVV_MATMUL_HPP diff --git a/src/cpu/rv64/rvv_postops.hpp b/src/cpu/rv64/rvv_postops.hpp index 7cd386d4c59..8158341f3c8 100644 --- a/src/cpu/rv64/rvv_postops.hpp +++ b/src/cpu/rv64/rvv_postops.hpp @@ -3,6 +3,7 @@ #include "common/primitive_attr.hpp" #include "common/utils.hpp" +#include "oneapi/dnnl/dnnl_types.h" #include namespace dnnl { @@ -11,29 +12,35 @@ namespace cpu { namespace rv64 { struct rvv_postops_t { - rvv_postops_t(const post_ops_t &po) { - is_supported_ = false; + rvv_postops_t(const post_ops_t &po) : alg_(alg_kind::undef) { + if (po.len() > 0) { alg_ = po.entry_[0].eltwise.alg; } + } - if (po.len() == 1 && po.entry_[0].is_eltwise()) { - const auto &e = po.entry_[0]; - if (e.eltwise.alg == alg_kind::eltwise_relu) { - is_supported_ = true; - } + static bool post_ops_ok(const post_ops_t &po) { + if (po.len() == 0) return true; + if (po.len() > 1) return false; + + const auto &e = po.entry_[0]; + if (!e.is_eltwise()) return false; + + switch (e.eltwise.alg) { + case alg_kind::eltwise_relu: return true; + default: return false; } } - inline bool has_postops() const { return is_supported_; } - inline vfloat32m1_t apply(vfloat32m1_t v, size_t vl) const { - if (is_supported_) { - vfloat32m1_t zero = __riscv_vfmv_v_f_f32m1(0.f, vl); - return __riscv_vfmax_vv_f32m1(v, zero, vl); + switch (alg_) { + case alg_kind::eltwise_relu: { + vfloat32m1_t zero = __riscv_vfmv_v_f_f32m1(0.f, vl); + return __riscv_vfmax_vv_f32m1(v, zero, vl); + } + default: return v; } - return v; } private: - bool is_supported_; + dnnl_alg_kind_t alg_; }; } // namespace rv64 @@ -41,4 +48,4 @@ struct rvv_postops_t { } // namespace impl } // namespace dnnl -#endif // CPU_RV64_RVV_POSTOPS_HPP +#endif // CPU_RV64_RVV_POSTOPS_HPP \ No newline at end of file From bbaff737d3212767e0a429edada7797175180017 Mon Sep 17 00:00:00 2001 From: krishnasai-mcw Date: Wed, 20 Aug 2025 07:43:44 +0530 Subject: [PATCH 3/4] cpu: riscv: matmul: add copyright headers to newly added files --- src/cpu/rv64/rvv_matmul.cpp | 15 +++++++++++++++ src/cpu/rv64/rvv_matmul.hpp | 15 +++++++++++++++ src/cpu/rv64/rvv_postops.hpp | 15 +++++++++++++++ 3 files changed, 45 insertions(+) diff --git a/src/cpu/rv64/rvv_matmul.cpp b/src/cpu/rv64/rvv_matmul.cpp index eba90906f35..a2bc38d2a28 100644 --- a/src/cpu/rv64/rvv_matmul.cpp +++ b/src/cpu/rv64/rvv_matmul.cpp @@ -1,3 +1,18 @@ +/******************************************************************************* +* Copyright 2019-2025 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ #include "cpu/rv64/rvv_matmul.hpp" #include "common/dnnl_thread.hpp" #include "cpu/cpu_primitive.hpp" diff --git a/src/cpu/rv64/rvv_matmul.hpp b/src/cpu/rv64/rvv_matmul.hpp index 2ffd17d078f..3bacd7be504 100644 --- a/src/cpu/rv64/rvv_matmul.hpp +++ b/src/cpu/rv64/rvv_matmul.hpp @@ -1,3 +1,18 @@ +/******************************************************************************* +* Copyright 2019-2025 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ #ifndef CPU_RV64_RVV_MATMUL_HPP #define CPU_RV64_RVV_MATMUL_HPP diff --git a/src/cpu/rv64/rvv_postops.hpp b/src/cpu/rv64/rvv_postops.hpp index 8158341f3c8..09ea6ed4ee5 100644 --- a/src/cpu/rv64/rvv_postops.hpp +++ b/src/cpu/rv64/rvv_postops.hpp @@ -1,3 +1,18 @@ +/******************************************************************************* +* Copyright 2019-2025 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ #ifndef CPU_RV64_RVV_POSTOPS_HPP #define CPU_RV64_RVV_POSTOPS_HPP From 4a3a8687f8e7db82591faa7d20f54b959e6a37f3 Mon Sep 17 00:00:00 2001 From: krishnasai-mcw Date: Thu, 21 Aug 2025 11:33:31 +0530 Subject: [PATCH 4/4] cpu: riscv: matmul: fix type/init and drop unused includes --- src/cpu/rv64/rvv_matmul.cpp | 2 -- src/cpu/rv64/rvv_matmul.hpp | 2 -- src/cpu/rv64/rvv_postops.hpp | 10 ++++------ 3 files changed, 4 insertions(+), 10 deletions(-) diff --git a/src/cpu/rv64/rvv_matmul.cpp b/src/cpu/rv64/rvv_matmul.cpp index a2bc38d2a28..dcd932b3676 100644 --- a/src/cpu/rv64/rvv_matmul.cpp +++ b/src/cpu/rv64/rvv_matmul.cpp @@ -15,8 +15,6 @@ *******************************************************************************/ #include "cpu/rv64/rvv_matmul.hpp" #include "common/dnnl_thread.hpp" -#include "cpu/cpu_primitive.hpp" -#include "cpu/matmul/matmul_utils.hpp" #include "cpu/rv64/rvv_postops.hpp" #include diff --git a/src/cpu/rv64/rvv_matmul.hpp b/src/cpu/rv64/rvv_matmul.hpp index 3bacd7be504..0bf6fa06da5 100644 --- a/src/cpu/rv64/rvv_matmul.hpp +++ b/src/cpu/rv64/rvv_matmul.hpp @@ -16,9 +16,7 @@ #ifndef CPU_RV64_RVV_MATMUL_HPP #define CPU_RV64_RVV_MATMUL_HPP -#include "common/c_types_map.hpp" #include "common/primitive.hpp" -#include "common/type_helpers.hpp" #include "cpu/matmul/cpu_matmul_pd.hpp" #include "cpu/rv64/rvv_postops.hpp" diff --git a/src/cpu/rv64/rvv_postops.hpp b/src/cpu/rv64/rvv_postops.hpp index 09ea6ed4ee5..28c54f2e77e 100644 --- a/src/cpu/rv64/rvv_postops.hpp +++ b/src/cpu/rv64/rvv_postops.hpp @@ -16,9 +16,6 @@ #ifndef CPU_RV64_RVV_POSTOPS_HPP #define CPU_RV64_RVV_POSTOPS_HPP -#include "common/primitive_attr.hpp" -#include "common/utils.hpp" -#include "oneapi/dnnl/dnnl_types.h" #include namespace dnnl { @@ -27,8 +24,9 @@ namespace cpu { namespace rv64 { struct rvv_postops_t { - rvv_postops_t(const post_ops_t &po) : alg_(alg_kind::undef) { - if (po.len() > 0) { alg_ = po.entry_[0].eltwise.alg; } + rvv_postops_t(const post_ops_t &po) + : alg_(po.len() > 0 ? po.entry_[0].eltwise.alg : alg_kind::undef) { + assert(po.len() <= 1 && "rvv_postops_t supports at most one post-op"); } static bool post_ops_ok(const post_ops_t &po) { @@ -55,7 +53,7 @@ struct rvv_postops_t { } private: - dnnl_alg_kind_t alg_; + alg_kind_t alg_; }; } // namespace rv64