Skip to content

Commit

Permalink
exclude separate arithmetic_mode to fully support in future and fix sa…
Browse files Browse the repository at this point in the history
…turate_input function NaN/Inf handling issue
  • Loading branch information
liubo-intel committed Dec 6, 2024
1 parent 55bdf60 commit 72444fb
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 111 deletions.
157 changes: 56 additions & 101 deletions src/plugins/intel_cpu/src/emitters/plugin/x64/jit_bf16_emitters.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,14 @@ class jit_uni_vcvtneps2bf16 : public jit_emitter {
public:
jit_uni_vcvtneps2bf16(dnnl::impl::cpu::x64::jit_generator* host,
dnnl::impl::cpu::x64::cpu_isa_t host_isa,
ov::element::Type exec_prc = ov::element::bf16,
arithmetic_mode mode = arithmetic_mode::saturation)
ov::element::Type exec_prc = ov::element::bf16)
: jit_emitter(host, host_isa, exec_prc) {
mode_ = mode;
prepare_table();
}

size_t get_inputs_num() const override {
return 1;
}
size_t get_inputs_num() const override { return 1; }

private:
arithmetic_mode mode_ = arithmetic_mode::saturation;
void emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const override {
if (host_isa_ == dnnl::impl::cpu::x64::avx512_core) {
emit_isa<dnnl::impl::cpu::x64::avx512_core>(in_vec_idxs, out_vec_idxs);
Expand All @@ -43,104 +38,67 @@ class jit_uni_vcvtneps2bf16 : public jit_emitter {
const Vmm& in,
const std::string& bf16_min_key,
const std::string& bf16_max_key) const {
Vmm bf16_bound = Vmm(aux_vec_idxs[0]);
h->uni_vmovups(bf16_bound, table_val(bf16_min_key));
h->uni_vmaxps(clamped, in, bf16_bound);
h->uni_vmovups(bf16_bound, table_val(bf16_max_key));
h->uni_vminps(clamped, clamped, bf16_bound);
Vmm vmm_temp = Vmm(aux_vec_idxs[1]);
h->uni_vmovups(vmm_temp, table_val(bf16_min_key));
h->uni_vmaxps(clamped, in, vmm_temp);
h->uni_vmovups(vmm_temp, table_val(bf16_max_key));
h->uni_vminps(clamped, clamped, vmm_temp);

h->uni_vmovups(vmm_temp, table_val("selector"));
h->vfixupimmps(clamped, in, vmm_temp, 0);
}

template <dnnl::impl::cpu::x64::cpu_isa_t isa>
void emit_isa(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const {
using namespace Xbyak;
using Vmm = typename dnnl::impl::utils::
conditional3<isa == dnnl::impl::cpu::x64::sse41, Xmm, isa == dnnl::impl::cpu::x64::avx2, Ymm, Zmm>::type;
using Vmm = typename dnnl::impl::utils::conditional3<isa == dnnl::impl::cpu::x64::sse41, Xmm, isa == dnnl::impl::cpu::x64::avx2, Ymm, Zmm>::type;

Vmm in = Vmm(in_vec_idxs[0]);

if (mode_ == arithmetic_mode::saturation) {
Vmm clamped = Vmm(aux_vec_idxs[1]);
saturate_input(clamped, in, "bf16_min", "bf16_max");
if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16)) {
Ymm out = Ymm(out_vec_idxs[0]);
h->vcvtneps2bf16(out, clamped);
} else if (host_isa_ == dnnl::impl::cpu::x64::cpu_isa_t::avx512_core) {
Zmm aux = Zmm(aux_vec_idxs[0]);
Zmm aux1 = Zmm(aux_vec_idxs[2]);
Ymm out = Ymm(out_vec_idxs[0]);

h->uni_vpsrld(aux, clamped, 16);
h->vpandd(aux, aux, table_val("one"));
h->uni_vmovups(aux1, table_val("even"));
h->uni_vpaddd(aux, aux1, aux);
h->uni_vpaddd(aux, clamped, aux);
h->vfixupimmps(aux, clamped, table_val("selector"), 0);
h->vpsrad(aux, aux, 16);
h->vpmovdw(out, aux);
} else if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::cpu_isa_t::avx2_vnni_2)) {
Xmm out = Xmm(out_vec_idxs[0]);
h->vcvtneps2bf16(out, clamped, PreferredEncoding::VexEncoding);
Vmm clamped = Vmm(aux_vec_idxs[0]);
saturate_input(clamped, in, "bf16_min", "bf16_max");

if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16)) {
Ymm out = Ymm(out_vec_idxs[0]);
h->vcvtneps2bf16(out, clamped);
} else if (host_isa_ == dnnl::impl::cpu::x64::cpu_isa_t::avx512_core) {
Zmm aux = Zmm(aux_vec_idxs[1]);
Zmm aux1 = Zmm(aux_vec_idxs[2]);
Ymm out = Ymm(out_vec_idxs[0]);

h->uni_vpsrld(aux, clamped, 16);
h->vpandd(aux, aux, table_val("one"));
h->uni_vmovups(aux1, table_val("even"));
h->uni_vpaddd(aux, aux1, aux);
h->uni_vpaddd(aux, clamped, aux);
h->vpsrad(aux, aux, 16);
h->vpmovdw(out, aux);
} else if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::cpu_isa_t::avx2_vnni_2)) {
Xmm out = Xmm(out_vec_idxs[0]);
h->vcvtneps2bf16(out, clamped, PreferredEncoding::VexEncoding);
} else { // round_to_nearest_even emulation
Vmm aux = Vmm(aux_vec_idxs[1]);
Xmm out = Xmm(out_vec_idxs[0]);

if (host_isa_ == dnnl::impl::cpu::x64::cpu_isa_t::avx2) {
h->uni_vandps(aux, clamped, table_val("rounding"));
} else {
Xmm out = Xmm(out_vec_idxs[0]);
Vmm aux = Vmm(aux_vec_idxs[0]);

if (host_isa_ == dnnl::impl::cpu::x64::cpu_isa_t::avx2) {
h->uni_vandps(aux, clamped, table_val("rounding"));
} else {
h->uni_vmovups(aux, clamped);
h->uni_vandps(aux, aux, table_val("rounding"));
}

h->uni_vpsrld(clamped, clamped, 16);
h->uni_vandps(clamped, clamped, table_val("mask_truncation_word"));
h->uni_vpackusdw(clamped, clamped, clamped);

if (host_isa_ == dnnl::impl::cpu::x64::cpu_isa_t::avx2) {
h->vpermq(Ymm(clamped.getIdx()), Ymm(clamped.getIdx()), 0xD8); // 11 01 10 00
h->vextracti128(out, Ymm(clamped.getIdx()), 0);
} else {
h->uni_vmovups(out, clamped);
}
h->uni_vmovups(aux, clamped);
h->uni_vandps(aux, aux, table_val("rounding"));
}
} else {
if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core)) {
Zmm aux = Zmm(aux_vec_idxs[0]);
Zmm aux1 = Zmm(aux_vec_idxs[1]);
Ymm out = Ymm(out_vec_idxs[0]);

h->uni_vpsrld(aux, in, 16);
h->vpandd(aux, aux, table_val("one"));
h->uni_vmovups(aux1, table_val("even"));
h->uni_vpaddd(aux, aux1, aux);
h->uni_vpaddd(aux, in, aux);
h->vfixupimmps(aux, in, table_val("selector"), 0);
h->vpsrad(aux, aux, 16);
h->vpmovdw(out, aux);

h->uni_vpsrld(aux, aux, 1);
h->uni_vpaddd(aux, aux, clamped);
h->uni_vpsrld(aux, aux, 16);

// dword to word using truncation
h->uni_vandps(aux, aux, table_val("mask_truncation_word"));
h->uni_vpackusdw(aux, aux, aux);

if (host_isa_ == dnnl::impl::cpu::x64::cpu_isa_t::avx2) {
h->vpermq(Ymm(aux.getIdx()), Ymm(aux.getIdx()), 0xD8); //11 01 10 00
h->vextracti128(out, Ymm(aux.getIdx()), 0);
} else {
Vmm aux = Vmm(aux_vec_idxs[0]);
Xmm out = Xmm(out_vec_idxs[0]);

if (host_isa_ == dnnl::impl::cpu::x64::cpu_isa_t::avx2) {
h->uni_vandps(aux, in, table_val("rounding"));
} else {
h->uni_vmovups(aux, in);
h->uni_vandps(aux, aux, table_val("rounding"));
}

h->uni_vpsrld(aux, aux, 1);
h->uni_vpaddd(aux, aux, in);
h->uni_vpsrld(aux, aux, 16);

// dword to word using truncation
h->uni_vandps(aux, aux, table_val("mask_truncation_word"));
h->uni_vpackusdw(aux, aux, aux);

if (host_isa_ == dnnl::impl::cpu::x64::cpu_isa_t::avx2) {
h->vpermq(Ymm(aux.getIdx()), Ymm(aux.getIdx()), 0xD8); // 11 01 10 00
h->vextracti128(out, Ymm(aux.getIdx()), 0);
} else {
h->uni_vmovups(out, aux);
}
h->uni_vmovups(out, aux);
}
}
}
Expand Down Expand Up @@ -177,12 +135,9 @@ class jit_uni_vcvtneps2bf16 : public jit_emitter {
}

size_t aux_vecs_count() const override {
if (mode_ == arithmetic_mode::saturation)
return (host_isa_ == dnnl::impl::cpu::x64::cpu_isa_t::avx512_core) ? 3 : 2;
else
return dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core) ? 2 : 1;
return host_isa_ == dnnl::impl::cpu::x64::avx512_core ? 3 : 2;
}
};

} // namespace intel_cpu
} // namespace ov
} // namespace intel_cpu
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,6 @@ void jit_convert_emitter::float2bfloat(const std::vector<size_t> &in_vec_idxs, c
jit_convert_truncation_emitter::jit_convert_truncation_emitter(jit_generator *host, cpu_isa_t host_isa,
const std::shared_ptr<ov::Node>& node, ov::element::Type exec_prc)
: jit_convert_emitter(host, host_isa, node, exec_prc) {
if (uni_vcvtneps2bf16)
uni_vcvtneps2bf16.reset(new jit_uni_vcvtneps2bf16(host, host_isa, exec_prc, arithmetic_mode::truncation));

prepare_table();
}

Expand Down
6 changes: 0 additions & 6 deletions src/plugins/intel_cpu/src/emitters/plugin/x64/jit_emitter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,6 @@ enum emitter_in_out_map {
gpr_to_gpr,
};

// Arithmetic modes for data type conversion in store_emitter
enum arithmetic_mode {
saturation,
truncation
};

// structure for storage of emitter parameters to hash in map
struct emitter_params {
virtual size_t hash() const = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,7 @@ jit_store_emitter::jit_store_emitter(dnnl::impl::cpu::x64::jit_generator *host,
prepare_table();
v_len_elt_ = get_vec_length() / exec_prc.size();
store_size_ = store_num * dst_prc.size();
uni_vcvtneps2bf16_.reset(new jit_uni_vcvtneps2bf16(host, host_isa, exec_prc, mode));
uni_vcvtneps2bf16_.reset(new jit_uni_vcvtneps2bf16(host, host_isa));
}

inline bool jit_store_emitter::is_saturation() const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ struct store_emitter_params : public emitter_params {
int store_num_;
};

// Arithmetic modes for data type conversion in store_emitter.
// Selects how a value is treated when it is converted to the destination precision.
enum arithmetic_mode {
// NOTE(review): presumably out-of-range values are clamped to the destination
// type's limits before conversion — confirm against jit_store_emitter.
saturation,
// NOTE(review): presumably the value is converted without range clamping —
// confirm against jit_store_emitter.
truncation
};

class jit_load_emitter : public jit_emitter {
public:
jit_load_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa,
Expand Down

0 comments on commit 72444fb

Please sign in to comment.