Skip to content

Commit

Permalink
[CPU] [ARM64] jit abs (openvinotoolkit#23692)
Browse files Browse the repository at this point in the history
### Details:
 - *[CPU] [AARCH64] jit abs*

### Tickets:
 - *CVS-136833*
  • Loading branch information
eshoguli authored Apr 1, 2024
1 parent e22191c commit 3b845e6
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,43 @@ ov::element::Type get_arithmetic_binary_exec_precision(const std::shared_ptr<ov:
}
} // namespace

/// ABS ///
jit_abs_emitter::jit_abs_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const std::shared_ptr<ov::Node>& node)
: jit_emitter(host, host_isa, node, get_arithmetic_binary_exec_precision(node)) {
}

jit_abs_emitter::jit_abs_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const ov::element::Type exec_prc) : jit_emitter(host, host_isa, exec_prc) {
}

size_t jit_abs_emitter::get_inputs_count() const { return 1; }

void jit_abs_emitter::emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
if (host_isa_ == dnnl::impl::cpu::aarch64::asimd) {
emit_isa<dnnl::impl::cpu::aarch64::asimd>(in_vec_idxs, out_vec_idxs);
} else {
OV_CPU_JIT_EMITTER_THROW("Can't create jit eltwise kernel");
}
}

template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void jit_abs_emitter::emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const {
OV_CPU_JIT_EMITTER_ASSERT(exec_prc_ == ov::element::f32, "unsupported precision: " + exec_prc_.to_string());

using TReg = typename dnnl::impl::cpu::aarch64::cpu_isa_traits<isa>::TReg;
TReg src = TReg(in_vec_idxs[0]);
TReg dst = TReg(out_vec_idxs[0]);

h->fabs(dst.s, src.s);
}

std::set<std::vector<element::Type>> jit_abs_emitter::get_supported_precisions(const std::shared_ptr<ov::Node>& node) {
return {{element::f32}};
}

/// ADD ///
jit_add_emitter::jit_add_emitter(dnnl::impl::cpu::aarch64::jit_generator* host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,27 @@ namespace ov {
namespace intel_cpu {
namespace aarch64 {

class jit_abs_emitter : public jit_emitter {
public:
jit_abs_emitter(dnnl::impl::cpu::aarch64::jit_generator *host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const ov::element::Type exec_prc = ov::element::f32);

jit_abs_emitter(dnnl::impl::cpu::aarch64::jit_generator *host,
dnnl::impl::cpu::aarch64::cpu_isa_t host_isa,
const std::shared_ptr<ov::Node>& node);

size_t get_inputs_count() const override;

static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ov::Node>& node = nullptr);

private:
void emit_impl(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const override;

template <dnnl::impl::cpu::aarch64::cpu_isa_t isa>
void emit_isa(const std::vector<size_t> &in_vec_idxs, const std::vector<size_t> &out_vec_idxs) const;
};

class jit_add_emitter : public jit_emitter {
public:
jit_add_emitter(dnnl::impl::cpu::aarch64::jit_generator *host,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ bool JitEltwiseExecutor::isSupported(
const float beta,
const float gamma) {
const auto is_supported = one_of(algorithm,
Algorithm::EltwiseAbs,
Algorithm::EltwiseAdd,
Algorithm::EltwiseClamp,
Algorithm::EltwiseDivide,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -451,13 +451,13 @@ void jit_uni_eltwise_generic<isa>::store_vector(const XReg& ptr,
break;
}
case ov::element::i8: {
fcvtns(data.s, data.s);
fcvtms(data.s, data.s);
xtn(data.h4, data.s4);
xtn(data.b8, data.h8);
break;
}
case ov::element::u8: {
fcvtnu(data.s, data.s);
fcvtmu(data.s, data.s);
xtn(data.h4, data.s4);
xtn(data.b8, data.h8);
break;
Expand Down Expand Up @@ -515,14 +515,14 @@ void jit_uni_eltwise_generic<isa>::store_scalar(const XReg& ptr,
}
case ov::element::i8: {
TReg vec_data(data.getIdx());
fcvtns(vec_data.s, vec_data.s);
fcvtms(vec_data.s, vec_data.s);
xtn(vec_data.h4, vec_data.s4);
xtn(vec_data.b8, vec_data.h8);
break;
}
case ov::element::u8: {
TReg vec_data(data.getIdx());
fcvtnu(vec_data.s, vec_data.s);
fcvtmu(vec_data.s, vec_data.s);
xtn(vec_data.h4, vec_data.s4);
xtn(vec_data.b8, vec_data.h8);
break;
Expand Down Expand Up @@ -609,6 +609,7 @@ std::shared_ptr<jit_emitter> jit_uni_eltwise_generic<isa>::create_eltwise_emitte
};

OV_SWITCH(intel_cpu, EltwiseEmitter, ctx, data.algo,
OV_CASE(Algorithm::EltwiseAbs, ov::intel_cpu::aarch64::jit_abs_emitter),
OV_CASE(Algorithm::EltwiseAdd, ov::intel_cpu::aarch64::jit_add_emitter),
OV_CASE(Algorithm::EltwiseClamp, ov::intel_cpu::aarch64::jit_clamp_emitter),
OV_CASE(Algorithm::EltwiseDivide, ov::intel_cpu::aarch64::jit_divide_emitter),
Expand Down Expand Up @@ -767,6 +768,7 @@ std::set<std::vector<element::Type>> eltwise_precision_helper::get_supported_pre

OV_SWITCH(intel_cpu, SupportedPrecisions, precisions, algo,
OV_CASE(Algorithm::EltwiseRelu, jit_relu_emitter),
OV_CASE(Algorithm::EltwiseAbs, jit_abs_emitter),
OV_CASE(Algorithm::EltwiseAdd, jit_add_emitter),
OV_CASE(Algorithm::EltwiseClamp, jit_clamp_emitter),
OV_CASE(Algorithm::EltwiseDivide, jit_divide_emitter),
Expand Down

0 comments on commit 3b845e6

Please sign in to comment.