From 966df08d2dea0e8eaab6e2d946e94327100e318d Mon Sep 17 00:00:00 2001 From: liubo-intel Date: Thu, 12 Dec 2024 14:50:09 +0800 Subject: [PATCH] limit bf16_emitters truncation impl to eltwise Constant inputs to keep acc fix and minimize the impact on performance --- .../emitters/plugin/x64/jit_bf16_emitters.hpp | 35 +++++++++++-------- .../src/emitters/plugin/x64/jit_emitter.hpp | 8 +++++ .../plugin/x64/jit_load_store_emitters.hpp | 6 ---- src/plugins/intel_cpu/src/nodes/eltwise.cpp | 27 +++++++++++--- src/plugins/intel_cpu/src/nodes/eltwise.h | 1 + 5 files changed, 52 insertions(+), 25 deletions(-) diff --git a/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_bf16_emitters.hpp b/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_bf16_emitters.hpp index 5bcbe25851a9f2..aeb22626307e68 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_bf16_emitters.hpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_bf16_emitters.hpp @@ -13,9 +13,11 @@ class jit_uni_vcvtneps2bf16 : public jit_emitter { public: jit_uni_vcvtneps2bf16(dnnl::impl::cpu::x64::jit_generator* host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, - ov::element::Type exec_prc = ov::element::bf16) + ov::element::Type exec_prc = ov::element::bf16, + arithmetic_mode mode = arithmetic_mode::none) : jit_emitter(host, host_isa, exec_prc) { prepare_table(); + mode_ = mode; } size_t get_inputs_num() const override { @@ -23,6 +25,7 @@ class jit_uni_vcvtneps2bf16 : public jit_emitter { } private: + arithmetic_mode mode_ = arithmetic_mode::none; void emit_impl(const std::vector& in_vec_idxs, const std::vector& out_vec_idxs) const override { if (host_isa_ == dnnl::impl::cpu::x64::avx512_core) { emit_isa(in_vec_idxs, out_vec_idxs); @@ -42,23 +45,25 @@ class jit_uni_vcvtneps2bf16 : public jit_emitter { conditional3::type; Vmm in = Vmm(in_vec_idxs[0]); - Vmm vmm_temp = Vmm(out_vec_idxs[0]); + if (mode_ == arithmetic_mode::constant_saturation) { + Vmm vmm_temp = Vmm(out_vec_idxs[0]); - h->uni_vmaxps(vmm_temp, in, table_val("bf16_min")); - h->uni_vminps(vmm_temp, vmm_temp, table_val("bf16_max")); + h->uni_vmaxps(vmm_temp, in, table_val("bf16_min")); + h->uni_vminps(vmm_temp, vmm_temp, table_val("bf16_max")); - if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core)) { - h->vfixupimmps(vmm_temp, in, table_val("selector"), 0); - } else { - Vmm mask = Vmm(aux_vec_idxs[0]); - h->uni_vcmpps(mask, in, in, 0x03); // _CMP_UNORD_Q - h->uni_vblendvps(vmm_temp, vmm_temp, table_val("nan"), mask); - h->uni_vcmpps(mask, in, table_val("inf"), 0x00); // _CMP_EQ_OQ - h->uni_vblendvps(vmm_temp, vmm_temp, table_val("inf"), mask); - h->uni_vcmpps(mask, in, table_val("neg_inf"), 0x00); // _CMP_EQ_OQ - h->uni_vblendvps(vmm_temp, vmm_temp, table_val("neg_inf"), mask); + if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core)) { + h->vfixupimmps(vmm_temp, in, table_val("selector"), 0); + } else { + Vmm mask = Vmm(aux_vec_idxs[0]); + h->uni_vcmpps(mask, in, in, 0x03); // _CMP_UNORD_Q + h->uni_vblendvps(vmm_temp, vmm_temp, table_val("nan"), mask); + h->uni_vcmpps(mask, in, table_val("inf"), 0x00); // _CMP_EQ_OQ + h->uni_vblendvps(vmm_temp, vmm_temp, table_val("inf"), mask); + h->uni_vcmpps(mask, in, table_val("neg_inf"), 0x00); // _CMP_EQ_OQ + h->uni_vblendvps(vmm_temp, vmm_temp, table_val("neg_inf"), mask); + } + h->uni_vmovups(in, vmm_temp); } - h->uni_vmovups(in, vmm_temp); if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16)) { Ymm out = Ymm(out_vec_idxs[0]); diff --git a/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_emitter.hpp b/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_emitter.hpp index c5729613f1bfe5..bb383cd0777809 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_emitter.hpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_emitter.hpp @@ -27,6 +27,14 @@ enum emitter_in_out_map { gpr_to_gpr, }; +// Arithmetic modes for data type conversion in store_emitter +enum arithmetic_mode { + none, + saturation, + truncation, + constant_saturation +}; + // structure for storage of emitter parameters to hash in map struct emitter_params { virtual size_t hash() const = 0; diff --git a/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_load_store_emitters.hpp b/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_load_store_emitters.hpp index 9570a836aa64ee..75a0e98b8ebe7e 100644 --- a/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_load_store_emitters.hpp +++ b/src/plugins/intel_cpu/src/emitters/plugin/x64/jit_load_store_emitters.hpp @@ -35,12 +35,6 @@ struct store_emitter_params : public emitter_params { int store_num_; }; -// Arithmetic modes for data type conversion in store_emitter -enum arithmetic_mode { - saturation, - truncation -}; - class jit_load_emitter : public jit_emitter { public: jit_load_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa, diff --git a/src/plugins/intel_cpu/src/nodes/eltwise.cpp b/src/plugins/intel_cpu/src/nodes/eltwise.cpp index 54cf435009059d..8fa4b8ae901271 100644 --- a/src/plugins/intel_cpu/src/nodes/eltwise.cpp +++ b/src/plugins/intel_cpu/src/nodes/eltwise.cpp @@ -326,8 +326,10 @@ struct jit_uni_eltwise_generic : public jit_uni_eltwise_kernel, public jit_gener this, p->entry_[i], vmm_d_weights, vmm_d_bias, reg_d_weights, reg_d_bias)); } - if (mayiuse(avx512_core) || mayiuse(avx2_vnni_2)) - uni_vcvtneps2bf16.reset(new jit_uni_vcvtneps2bf16(this, isa)); + if (mayiuse(avx512_core) || mayiuse(avx2_vnni_2)) { + auto const mode = jep_.do_constant_saturation ? arithmetic_mode::constant_saturation : arithmetic_mode::none; + uni_vcvtneps2bf16.reset(new jit_uni_vcvtneps2bf16(this, isa, element::bf16, mode)); + } const auto &jep = jep_; @@ -1310,6 +1312,7 @@ struct EltwiseKey { ov::element::Type outPrc; dnnl::post_ops postOps; EltwiseImplType implType; + bool doConstantSaturation; size_t hash() const { using namespace dnnl::impl; @@ -1345,6 +1348,10 @@ struct EltwiseKey { seed = hash_combine(seed, outPrc.hash()); seed = get_post_op_hash(seed, *postOps.get()); seed = hash_combine(seed, implType); + + if (outPrc == ov::element::bf16) { + seed = hash_combine(seed, doConstantSaturation); + } return seed; } @@ -1376,6 +1383,8 @@ struct EltwiseKey { result = result && (inpDims[i] == rhs.inpDims[i]); } } + if ((outPrc == ov::element::bf16) && (doConstantSaturation != rhs.doConstantSaturation)) + return false; } return result; @@ -1408,7 +1417,8 @@ class EltwiseJitExecutor : public Eltwise::IEltwiseExecutor { const std::vector& inpPrc, const ov::element::Type& outPrc, const dnnl::post_ops& post_ops, - bool useRuntimePtrs) { + bool useRuntimePtrs, + bool doConstantSaturation) { auto collapseLastDims = [](std::vector& dims, int dimsToCollapse) { for (size_t i = dims.size() - 2; i > dims.size() - dimsToCollapse - 2; i--) { dims[dims.size() - 1] *= dims[i]; @@ -1594,6 +1604,7 @@ class EltwiseJitExecutor : public Eltwise::IEltwiseExecutor { jep.dst_prc = outPrc; jep.work_amount = jep.dst_size = jep.dims.back(); jep.oc_size = oc_size; + jep.do_constant_saturation = doConstantSaturation; std::transform(jep.oc_offsets.begin(), jep.oc_offsets.end(), jep.oc_offsets.begin(), [](size_t& offset) { return offset * sizeof(float);}); @@ -2058,7 +2069,8 @@ static Eltwise::executorPtr buildExecutor(const EltwiseKey& key) { key.inpPrc, key.outPrc, key.postOps, - key.implType == EltwiseImplType::optimizedShapeAgnostic); + key.implType == EltwiseImplType::optimizedShapeAgnostic, + key.doConstantSaturation); } bool Eltwise::isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept { @@ -2737,6 +2749,13 @@ void Eltwise::prepareParams() { "'"); } } + key.doConstantSaturation = false; + for (size_t i = 0; i < getParentEdges().size(); i++) { + if (!getParentEdgeAt(i)->getParent()->isConstant()) { + key.doConstantSaturation = true; + break; + } + } auto cache = context->getParamsCache(); auto result = cache->getOrCreate(key, buildExecutor); diff --git a/src/plugins/intel_cpu/src/nodes/eltwise.h b/src/plugins/intel_cpu/src/nodes/eltwise.h index 6013ce732ee5fc..66a04c6a34fd7a 100644 --- a/src/plugins/intel_cpu/src/nodes/eltwise.h +++ b/src/plugins/intel_cpu/src/nodes/eltwise.h @@ -42,6 +42,7 @@ struct jit_eltwise_params { size_t work_amount; bool use_runtime_ptrs; + bool do_constant_saturation; }; struct jit_eltwise_call_args_indexes {