cpu:aarch64: Extending support for BRGEMM General and 1x1 Forward Convolution #1983

Merged 5 commits on Jul 12, 2024
16 changes: 15 additions & 1 deletion src/cpu/aarch64/brgemm/jit_brgemm_kernel.cpp
@@ -1876,7 +1876,22 @@ void jit_brgemm_kernel_t::bdb_loop() {
}

void jit_brgemm_kernel_t::generate() {
size_t simd_w_;
switch (brg.isa_impl) {
case sve_512:
simd_w_ = cpu_isa_traits<sve_512>::vlen / sizeof(float);
break;
case sve_256:
simd_w_ = cpu_isa_traits<sve_256>::vlen / sizeof(float);
break;
default: assert(!"unsupported isa");
}
preamble();
if (simd_w_ != cpu_sveLen / sizeof(float)) {
set_preg(P_ALL_ONE.b, simd_w_ * 4, X_TMP_0, X_TMP_1);
set_preg(ld_full_mask.b, simd_w_ * 4, X_TMP_0, X_TMP_1);
} else
ptrue(ld_full_mask.b);

mov(x7, x0);
mov(x6, x1);
@@ -1896,7 +1911,6 @@ void jit_brgemm_kernel_t::generate() {
brg.req_s8s8_compensation)
&& IMPLICATION(!vpad_exist, brg.req_cal_comp_pads);

ptrue(ld_full_mask.b);
set_preg(ld_tail_mask.s, brg.ldb_tail, X_TMP_0, X_TMP_1);
if (brg.is_int8 && !brg.has_int8_vnni) { assert(!"unsupported\n"); }

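Note on the two hunks above: generate() now derives the kernel's SIMD width in floats from brg.isa_impl, and when that width is narrower than the vector length the hardware reports (cpu_sveLen), it constrains P_ALL_ONE and ld_full_mask to simd_w_ * 4 bytes with set_preg instead of the unconditional ptrue(ld_full_mask.b) that the second hunk removes. A minimal standalone sketch of that decision, using illustrative names only (Isa, isa_vlen_bytes, hw_sve_len_bytes are assumptions, not oneDNN symbols):

// Standalone sketch, not oneDNN code: decide whether a kernel generated for
// a given SVE width needs partial predicates on the current hardware.
#include <cassert>
#include <cstddef>

enum class Isa { sve_256, sve_512 };

// Bytes per vector for the ISA the kernel targets.
constexpr std::size_t isa_vlen_bytes(Isa isa) {
    return isa == Isa::sve_512 ? 64 : 32;
}

// True when the kernel targets fewer float lanes than the hardware vector
// holds, so its "full" predicates must cover only simd_w * 4 bytes.
bool needs_partial_predicate(Isa isa, std::size_t hw_sve_len_bytes) {
    const std::size_t simd_w = isa_vlen_bytes(isa) / sizeof(float);
    assert(hw_sve_len_bytes >= isa_vlen_bytes(isa));
    return simd_w != hw_sve_len_bytes / sizeof(float);
}

For example, a sve_256 kernel on 512-bit hardware sets an 8-lane predicate (8 floats * 4 bytes = 32 bytes), which is what the set_preg(..., simd_w_ * 4, ...) calls express; on 256-bit hardware the widths match and ptrue is still used.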
1 change: 1 addition & 0 deletions src/cpu/aarch64/jit_brgemm_1x1_conv.cpp
@@ -551,6 +551,7 @@ status_t brgemm_1x1_convolution_fwd_t<isa>::execute_forward_all(
}

template struct brgemm_1x1_convolution_fwd_t<sve_512>;
template struct brgemm_1x1_convolution_fwd_t<sve_256>;

} // namespace aarch64
} // namespace cpu
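This file defines the class template's members in a .cpp, so every ISA the primitive should exist for has to be instantiated explicitly at the bottom; the hunk adds the sve_256 instantiation beside the existing sve_512 one. A minimal sketch of the pattern, with illustrative names (conv_fwd_t is not the real class):

// Sketch only: explicit instantiation of a class template defined in a .cpp.
template <int vlen_bits>
struct conv_fwd_t {
    int simd_floats() const;
};

template <int vlen_bits>
int conv_fwd_t<vlen_bits>::simd_floats() const {
    return vlen_bits / (8 * static_cast<int>(sizeof(float))); // 512 -> 16, 256 -> 8
}

// Without these lines the compiler never emits code for the two variants and
// other translation units would fail to link against them.
template struct conv_fwd_t<512>;
template struct conv_fwd_t<256>;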
5 changes: 3 additions & 2 deletions src/cpu/aarch64/jit_brgemm_conv.cpp
@@ -634,7 +634,7 @@ status_t brgemm_convolution_fwd_t<isa, use_inversion>::add_po_kernel(
bcfg->alpha = !is_init && IMPLICATION(jcp.with_sum, jcp.use_buffer);
bcfg->beta = is_init ? 0 : 1;
CHECK(safe_ptr_assign(kernels_po_[ker_idx],
new jit_brgemm_kernel_post_ops(jcp, *bcfg, *_pd->attr())));
new jit_brgemm_kernel_post_ops<isa>(jcp, *bcfg, *_pd->attr())));
kernels_po_[ker_idx]->create_kernel();
return status::success;
}
@@ -810,7 +810,7 @@ status_t brgemm_convolution_fwd_t<isa, use_inversion>::init(engine_t *engine) {

if (jcp.req_cal_comp_pad) {
CHECK(safe_ptr_assign(comp_vpad_pbuffer_,
new jit_uni_brgemm_conv_comp_pad_kernel_t(jcp)));
new jit_uni_brgemm_conv_comp_pad_kernel_t<isa>(jcp)));
CHECK(comp_vpad_pbuffer_->create_kernel());
}

@@ -2025,6 +2025,7 @@ void brgemm_convolution_fwd_t<isa, use_inversion>::ker_vpad(

#undef BRGEMM_CONV_KER_HEADER
template struct brgemm_convolution_fwd_t<sve_512>;
template struct brgemm_convolution_fwd_t<sve_256>;

} // namespace aarch64

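The changes in this file thread the primitive's isa template parameter into the helper kernels it builds: jit_brgemm_kernel_post_ops and jit_uni_brgemm_conv_comp_pad_kernel_t are now constructed as <isa> specializations, so the sve_256 primitive automatically gets sve_256 helpers. A rough sketch of that propagation, assuming hypothetical names (post_ops_kernel_t, conv_primitive_t):

// Sketch only: a templated primitive instantiates its helper kernels with the
// same ISA parameter it was instantiated with.
#include <memory>

template <int isa>
struct post_ops_kernel_t {
    int vlen_bytes() const { return isa / 8; } // illustrative
};

template <int isa>
struct conv_primitive_t {
    std::unique_ptr<post_ops_kernel_t<isa>> po_kernel_;

    void init() {
        // The <isa> follows the enclosing template, so conv_primitive_t<256>
        // builds post_ops_kernel_t<256> without any extra dispatch logic.
        po_kernel_ = std::make_unique<post_ops_kernel_t<isa>>();
    }
};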
2 changes: 1 addition & 1 deletion src/cpu/aarch64/jit_brgemm_conv.hpp
@@ -252,7 +252,7 @@ struct brgemm_convolution_fwd_t : public primitive_t {

brgemm_containers::brgemm_kernel_container_t brgemm_kernels_;

std::vector<std::unique_ptr<jit_brgemm_kernel_post_ops>> kernels_po_;
std::vector<std::unique_ptr<jit_brgemm_kernel_post_ops<isa>>> kernels_po_;
std::unique_ptr<jit_sve_core_brgemm_conv_trans_kernel::
jit_sve_core_brgemm_conv_trans_kernel_t>
copy_to_pbuffer_;
44 changes: 31 additions & 13 deletions src/cpu/aarch64/jit_brgemm_conv_comp_pad_kernel.cpp
@@ -32,8 +32,10 @@ namespace jit_uni_brgemm_conv_comp_pad_kernel {

#define GET_OFF(field) offsetof(jit_brgemm_conv_comp_pad_call_s, field)

jit_uni_brgemm_conv_comp_pad_kernel_t::jit_uni_brgemm_conv_comp_pad_kernel_t(
const jit_brgemm_conv_conf_t &ajcp)
template <cpu_isa_t isa>
jit_uni_brgemm_conv_comp_pad_kernel_t<isa>::
jit_uni_brgemm_conv_comp_pad_kernel_t(
const jit_brgemm_conv_conf_t &ajcp)
: jcp_(ajcp)
, inp_dsz_(jcp_.wei_dsz)
, out_dsz_(jcp_.acc_dsz)
@@ -44,20 +46,27 @@ jit_uni_brgemm_conv_comp_pad_kernel_t::jit_uni_brgemm_conv_comp_pad_kernel_t(
, inp_kd_sz_(static_cast<size_t>(jcp_.kh) * inp_kh_sz_)
, isa_max_regs(isa_num_vregs(jcp_.isa)) {}

size_t jit_uni_brgemm_conv_comp_pad_kernel_t::out_oc_offset(const int n) const {
template <cpu_isa_t isa>
size_t jit_uni_brgemm_conv_comp_pad_kernel_t<isa>::out_oc_offset(
const int n) const {
return static_cast<size_t>(out_dsz_) * n * m_block2_;
}
size_t jit_uni_brgemm_conv_comp_pad_kernel_t::inp_ic_offset(

template <cpu_isa_t isa>
size_t jit_uni_brgemm_conv_comp_pad_kernel_t<isa>::inp_ic_offset(
const int m_block, const int icb, const int m, const int n) const {
return static_cast<size_t>(inp_dsz_) * n * m_block2_ * last_ic_block_
+ ((icb * m_block) + m) * inp_ic_sz_;
}
Xbyak_aarch64::ZReg jit_uni_brgemm_conv_comp_pad_kernel_t::accum(

template <cpu_isa_t isa>
Xbyak_aarch64::ZReg jit_uni_brgemm_conv_comp_pad_kernel_t<isa>::accum(
const int n_block, const int m, const int n) const {
return Xbyak_aarch64::ZReg(m * n_block + n);
}

void jit_uni_brgemm_conv_comp_pad_kernel_t::store_accumulators(
template <cpu_isa_t isa>
void jit_uni_brgemm_conv_comp_pad_kernel_t<isa>::store_accumulators(
const int m_block, const int n_block) {
if (jcp_.src_zero_point) {
for_(int m = 0; m < m_block; m++)
@@ -100,7 +109,8 @@ void jit_uni_brgemm_conv_comp_pad_kernel_t::store_accumulators(
}
}

void jit_uni_brgemm_conv_comp_pad_kernel_t::zero_accumulators(
template <cpu_isa_t isa>
void jit_uni_brgemm_conv_comp_pad_kernel_t<isa>::zero_accumulators(
const int m_block, const int n_block) {
for_(int m = 0; m < m_block; m++)
for (int n = 0; n < n_block; n++) {
@@ -109,7 +119,8 @@ void jit_uni_brgemm_conv_comp_pad_kernel_t::zero_accumulators(
}
}

void jit_uni_brgemm_conv_comp_pad_kernel_t::compute(const int ic_step,
template <cpu_isa_t isa>
void jit_uni_brgemm_conv_comp_pad_kernel_t<isa>::compute(const int ic_step,
const int m_block, const int n_block, const int m_tail,
const bool is_mb_tail) {

@@ -126,7 +137,8 @@ void jit_uni_brgemm_conv_comp_pad_kernel_t::compute(const int ic_step,
}
}

void jit_uni_brgemm_conv_comp_pad_kernel_t::icb_loop(const int icb,
template <cpu_isa_t isa>
void jit_uni_brgemm_conv_comp_pad_kernel_t<isa>::icb_loop(const int icb,
const int icb_tail, const int ic_step, const int m_block,
const int mb_tail, const int n_block) {
Xbyak_aarch64::Label label_icb_loop, label_loop_end;
@@ -149,7 +161,8 @@ void jit_uni_brgemm_conv_comp_pad_kernel_t::icb_loop(const int icb,
if (icb_tail) compute(ic_step, mb_tail, n_block, icb_tail, true);
}

void jit_uni_brgemm_conv_comp_pad_kernel_t::khw_loop(const int icb,
template <cpu_isa_t isa>
void jit_uni_brgemm_conv_comp_pad_kernel_t<isa>::khw_loop(const int icb,
const int icb_tail, const int ic_step, const int m_block,
const int mb_tail, const int n_block) {
Xbyak_aarch64::Label label_kw_loop, label_kw_end, label_kh_loop,
@@ -181,13 +194,15 @@ void jit_uni_brgemm_conv_comp_pad_kernel_t::khw_loop(const int icb,
L_aligned(label_kh_end);
}

void jit_uni_brgemm_conv_comp_pad_kernel_t::load_params() {
template <cpu_isa_t isa>
void jit_uni_brgemm_conv_comp_pad_kernel_t<isa>::load_params() {
add_imm(reg_in, param1, GET_OFF(ptr_in), X_TMP_0);
add_imm(reg_zp_comp_out, param1, GET_OFF(ptr_zp_out), X_TMP_1);
add_imm(reg_comp_out, param1, GET_OFF(ptr_cp_out), X_TMP_2);
}

int jit_uni_brgemm_conv_comp_pad_kernel_t::compute_ic_step(
template <cpu_isa_t isa>
int jit_uni_brgemm_conv_comp_pad_kernel_t<isa>::compute_ic_step(
const int m_max_regs, const int m_block, const int n_block) const {
int best_ic_step = 1;
float best_block_eff = 0.f;
@@ -217,7 +232,8 @@ int jit_uni_brgemm_conv_comp_pad_kernel_t::compute_ic_step(
return best_ic_step;
}

void jit_uni_brgemm_conv_comp_pad_kernel_t::generate() {
template <cpu_isa_t isa>
void jit_uni_brgemm_conv_comp_pad_kernel_t<isa>::generate() {
preamble();

load_params();
@@ -280,6 +296,8 @@ void jit_uni_brgemm_conv_comp_pad_kernel_t::generate() {
postamble();
}

template struct jit_uni_brgemm_conv_comp_pad_kernel_t<sve_512>;
template struct jit_uni_brgemm_conv_comp_pad_kernel_t<sve_256>;
} // namespace jit_uni_brgemm_conv_comp_pad_kernel

} // namespace aarch64
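Most of the churn above is mechanical: once jit_uni_brgemm_conv_comp_pad_kernel_t becomes a class template, every out-of-line member definition needs its own template <cpu_isa_t isa> header and the <isa> argument on the class name, and the two explicit instantiations at the end of the file produce the sve_512 and sve_256 variants. A compact sketch of that syntax, with illustrative names only:

// Sketch only: out-of-line member definitions of a class template.
template <int isa>
struct comp_pad_kernel_t {
    explicit comp_pad_kernel_t(int block);
    int out_oc_offset(int n) const;
private:
    int block_;
};

// Each member defined outside the class repeats the template header and
// qualifies the class name with <isa>.
template <int isa>
comp_pad_kernel_t<isa>::comp_pad_kernel_t(int block) : block_(block) {}

template <int isa>
int comp_pad_kernel_t<isa>::out_oc_offset(int n) const {
    return n * block_; // analogous to the real offset helpers above
}

template struct comp_pad_kernel_t<512>;
template struct comp_pad_kernel_t<256>;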
6 changes: 4 additions & 2 deletions src/cpu/aarch64/jit_brgemm_conv_comp_pad_kernel.hpp
@@ -36,13 +36,15 @@ struct jit_brgemm_conv_comp_pad_call_s {
size_t kd_l;
};

template <cpu_isa_t isa>
struct jit_uni_brgemm_conv_comp_pad_kernel_t : public jit_generator {

DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_brgemm_conv_comp_pad_kernel_t)

using XReg = const Xbyak_aarch64::XReg;

jit_uni_brgemm_conv_comp_pad_kernel_t(const jit_brgemm_conv_conf_t &ajcp);
jit_uni_brgemm_conv_comp_pad_kernel_t<isa>(
const jit_brgemm_conv_conf_t &ajcp);

~jit_uni_brgemm_conv_comp_pad_kernel_t() = default;

@@ -85,7 +87,7 @@ struct jit_uni_brgemm_conv_comp_pad_kernel_t : public jit_generator {

const int last_ic_block_ = 4;
const int n_block2_ = 4;
const int m_block2_ = cpu_isa_traits<sve_512>::vlen / sizeof(int32_t);
const int m_block2_ = cpu_isa_traits<isa>::vlen / sizeof(int32_t);
const int n_max_regs_ = 4;

const Xbyak_aarch64::ZReg &vmm_tmp_1() const noexcept { return vmm_tmp; }
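Apart from templating the struct on isa, the functional change in this header is m_block2_: it is now derived from cpu_isa_traits<isa>::vlen instead of being hard-coded for sve_512, so the int32 block width follows the instantiated ISA. A small sketch of that arithmetic, assuming vlen is 64 bytes for sve_512 and 32 bytes for sve_256:

// Sketch only: block width in int32 lanes derived from the ISA's vector length.
#include <cstdint>

constexpr int m_block2_for(int vlen_bytes) {
    return vlen_bytes / static_cast<int>(sizeof(std::int32_t));
}

static_assert(m_block2_for(64) == 16, "sve_512: 16 int32 lanes per vector");
static_assert(m_block2_for(32) == 8, "sve_256: 8 int32 lanes per vector");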
37 changes: 33 additions & 4 deletions src/cpu/aarch64/jit_brgemm_post_ops.hpp
@@ -196,7 +196,22 @@ struct jit_brgemm_kernel_diff_bias_t : public jit_generator {
}

void generate() override {
size_t simd_w_;
switch (brg_.isa_impl) {
case sve_512:
simd_w_ = cpu_isa_traits<sve_512>::vlen / sizeof(float);
break;
case sve_256:
simd_w_ = cpu_isa_traits<sve_256>::vlen / sizeof(float);
break;
default: assert(!"unsupported isa");
}
preamble();
if (simd_w_ != cpu_sveLen / sizeof(float)) {
set_preg(P_ALL_ONE.b, simd_w_ * 4, X_TMP_0, X_TMP_1);
set_preg(k_full_mask.b, simd_w_ * 4, X_TMP_0, X_TMP_1);
} else
ptrue(k_full_mask.b);

int nb = utils::div_up(brg_.load_dim, brg_.ld_block);
int nb_tail = brg_.load_dim % brg_.ld_block;
Expand All @@ -208,7 +223,6 @@ struct jit_brgemm_kernel_diff_bias_t : public jit_generator {
n_loop_tail = n_max_regs_;
}

ptrue(k_full_mask.b);
set_preg(k_tail_mask.s, nb_tail, X_TMP_0, X_TMP_1);
pfalse(P_TMP_0.b);
zip1(k_tail_mask.b, k_tail_mask.b, P_TMP_0.b);
@@ -263,6 +277,7 @@ struct brgemm_kernel_post_ops_t {
void *ptr_dst_scales;
};

template <cpu_isa_t isa>
struct jit_brgemm_kernel_post_ops : public jit_generator {

jit_brgemm_kernel_post_ops(const jit_brgemm_conv_conf_t &ajcp,
@@ -299,7 +314,7 @@ struct jit_brgemm_kernel_post_ops : public jit_generator {
save_state, reserved_eltwise_gpr, reserved_eltwise_maskr};

postops_injector_ = utils::make_unique<
injector::jit_uni_postops_injector_t<sve_512>>(
injector::jit_uni_postops_injector_t<isa>>(
this, attr.post_ops_, bsp, esp);
}

@@ -332,7 +347,7 @@ struct jit_brgemm_kernel_post_ops : public jit_generator {
data_type_t inp_dt_;
data_type_t out_dt_;
data_type_t bia_dt_;
std::unique_ptr<injector::jit_uni_postops_injector_t<sve_512>>
std::unique_ptr<injector::jit_uni_postops_injector_t<isa>>
postops_injector_;

const bool with_binary_non_scalar_bcast_;
@@ -835,7 +850,22 @@ struct jit_brgemm_kernel_post_ops : public jit_generator {
}

void generate() override {
size_t simd_w_;
switch (brg.isa_impl) {
case sve_512:
simd_w_ = cpu_isa_traits<sve_512>::vlen / sizeof(float);
break;
case sve_256:
simd_w_ = cpu_isa_traits<sve_256>::vlen / sizeof(float);
break;
default: assert(!"unsupported isa");
}
preamble();
if (simd_w_ != cpu_sveLen / sizeof(float)) {
set_preg(P_ALL_ONE.b, simd_w_ * 4, X_TMP_0, X_TMP_1);
set_preg(k_full_mask.b, simd_w_ * 4, X_TMP_0, X_TMP_1);
} else
ptrue(k_full_mask.b);

mov(x7, x0);
mov(x6, x1);
@@ -858,7 +888,6 @@ struct jit_brgemm_kernel_post_ops : public jit_generator {
int mb = brg.bcast_dim / m_block;
int mb_tail = brg.bcast_dim % m_block;

ptrue(k_full_mask.b);
set_preg(k_tail_mask.s, nb_tail, X_TMP_0, X_TMP_1);

if (brg.alpha != 0) {
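Besides templating the post-ops injector on isa, the generate() hunks in this header repeat the same predicate setup as in jit_brgemm_kernel.cpp: the full mask covers one vector of simd_w_ lanes, while the tail mask covers the remaining load_dim % ld_block lanes. A small sketch of that full/tail split, assuming ld_block equals the per-ISA SIMD width and all names are illustrative:

// Sketch only: how the full and tail predicate lane counts relate to the
// load dimension and the per-ISA block size.
#include <cassert>

struct MaskLanes {
    int full; // lanes covered by the "full" predicate (one whole vector)
    int tail; // lanes covered by the tail predicate (last partial block)
};

MaskLanes mask_lanes(int load_dim, int ld_block) {
    assert(ld_block > 0);
    return {ld_block, load_dim % ld_block};
}

// e.g. load_dim = 60 with ld_block = 16 (sve_512) -> 3 full blocks + a 12-lane
// tail, while ld_block = 8 (sve_256) -> 7 full blocks + a 4-lane tail.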
2 changes: 2 additions & 0 deletions src/cpu/cpu_convolution_list.cpp
@@ -144,6 +144,8 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
CPU_INSTANCE_AARCH64_ACL(acl_depthwise_convolution_fwd_t)
CPU_INSTANCE_AARCH64_ACL(acl_indirect_gemm_convolution_fwd_t)
CPU_INSTANCE_AARCH64_ACL(acl_gemm_convolution_fwd_t<f32>)
CPU_INSTANCE_AARCH64(brgemm_1x1_convolution_fwd_t<sve_256>)
CPU_INSTANCE_AARCH64(brgemm_convolution_fwd_t<sve_256>)
CPU_INSTANCE_AARCH64(jit_sve_convolution_fwd_t<f32,f32,f32,sve_256>)
CPU_INSTANCE(gemm_convolution_fwd_t)
CPU_INSTANCE(ref_convolution_fwd_t)
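The new brgemm_1x1_convolution_fwd_t<sve_256> and brgemm_convolution_fwd_t<sve_256> entries are registered after the ACL implementations and ahead of the existing sve_256 jit convolution and the generic gemm/reference fallbacks; as far as the implementation list goes, entries are tried in order, so placement here effectively sets priority. A toy sketch of first-match dispatch over such an ordered list (not the actual oneDNN dispatcher, all names illustrative):

// Sketch only: try registered implementations in order, first success wins.
#include <functional>
#include <optional>
#include <string>
#include <vector>

struct ProblemDesc { int vec_len_bytes; };

using TryCreate = std::function<std::optional<std::string>(const ProblemDesc &)>;

std::string dispatch(const std::vector<TryCreate> &impls, const ProblemDesc &pd) {
    for (const auto &try_create : impls) {
        if (auto impl = try_create(pd)) return *impl; // first applicable entry
    }
    return "ref"; // a reference fallback is assumed to always apply
}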