Skip to content

Commit

Permalink
cpu:aarch64: Extending support for BRGEMM General and Forward Convolution
Browse files Browse the repository at this point in the history

Co-authored-by: Deeksha Kasture/[email protected] <[email protected]>
Co-authored-by: Kasture Deeksha <[email protected]>
  • Loading branch information
3 people authored Jul 12, 2024
1 parent ef992b1 commit c5e4ce7
Show file tree
Hide file tree
Showing 8 changed files with 90 additions and 23 deletions.
16 changes: 15 additions & 1 deletion src/cpu/aarch64/brgemm/jit_brgemm_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1876,7 +1876,22 @@ void jit_brgemm_kernel_t::bdb_loop() {
}

void jit_brgemm_kernel_t::generate() {
size_t simd_w_;
switch (brg.isa_impl) {
case sve_512:
simd_w_ = cpu_isa_traits<sve_512>::vlen / sizeof(float);
break;
case sve_256:
simd_w_ = cpu_isa_traits<sve_256>::vlen / sizeof(float);
break;
default: assert(!"unsupported isa");
}
preamble();
if (simd_w_ != cpu_sveLen / sizeof(float)) {
set_preg(P_ALL_ONE.b, simd_w_ * 4, X_TMP_0, X_TMP_1);
set_preg(ld_full_mask.b, simd_w_ * 4, X_TMP_0, X_TMP_1);
} else
ptrue(ld_full_mask.b);

mov(x7, x0);
mov(x6, x1);
Expand All @@ -1896,7 +1911,6 @@ void jit_brgemm_kernel_t::generate() {
brg.req_s8s8_compensation)
&& IMPLICATION(!vpad_exist, brg.req_cal_comp_pads);

ptrue(ld_full_mask.b);
set_preg(ld_tail_mask.s, brg.ldb_tail, X_TMP_0, X_TMP_1);
if (brg.is_int8 && !brg.has_int8_vnni) { assert(!"unsupported\n"); }

Expand Down
1 change: 1 addition & 0 deletions src/cpu/aarch64/jit_brgemm_1x1_conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,7 @@ status_t brgemm_1x1_convolution_fwd_t<isa>::execute_forward_all(
}

template struct brgemm_1x1_convolution_fwd_t<sve_512>;
template struct brgemm_1x1_convolution_fwd_t<sve_256>;

} // namespace aarch64
} // namespace cpu
Expand Down
5 changes: 3 additions & 2 deletions src/cpu/aarch64/jit_brgemm_conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -634,7 +634,7 @@ status_t brgemm_convolution_fwd_t<isa, use_inversion>::add_po_kernel(
bcfg->alpha = !is_init && IMPLICATION(jcp.with_sum, jcp.use_buffer);
bcfg->beta = is_init ? 0 : 1;
CHECK(safe_ptr_assign(kernels_po_[ker_idx],
new jit_brgemm_kernel_post_ops(jcp, *bcfg, *_pd->attr())));
new jit_brgemm_kernel_post_ops<isa>(jcp, *bcfg, *_pd->attr())));
kernels_po_[ker_idx]->create_kernel();
return status::success;
}
Expand Down Expand Up @@ -810,7 +810,7 @@ status_t brgemm_convolution_fwd_t<isa, use_inversion>::init(engine_t *engine) {

if (jcp.req_cal_comp_pad) {
CHECK(safe_ptr_assign(comp_vpad_pbuffer_,
new jit_uni_brgemm_conv_comp_pad_kernel_t(jcp)));
new jit_uni_brgemm_conv_comp_pad_kernel_t<isa>(jcp)));
CHECK(comp_vpad_pbuffer_->create_kernel());
}

Expand Down Expand Up @@ -2025,6 +2025,7 @@ void brgemm_convolution_fwd_t<isa, use_inversion>::ker_vpad(

#undef BRGEMM_CONV_KER_HEADER
template struct brgemm_convolution_fwd_t<sve_512>;
template struct brgemm_convolution_fwd_t<sve_256>;

} // namespace aarch64

Expand Down
2 changes: 1 addition & 1 deletion src/cpu/aarch64/jit_brgemm_conv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ struct brgemm_convolution_fwd_t : public primitive_t {

brgemm_containers::brgemm_kernel_container_t brgemm_kernels_;

std::vector<std::unique_ptr<jit_brgemm_kernel_post_ops>> kernels_po_;
std::vector<std::unique_ptr<jit_brgemm_kernel_post_ops<isa>>> kernels_po_;
std::unique_ptr<jit_sve_core_brgemm_conv_trans_kernel::
jit_sve_core_brgemm_conv_trans_kernel_t>
copy_to_pbuffer_;
Expand Down
44 changes: 31 additions & 13 deletions src/cpu/aarch64/jit_brgemm_conv_comp_pad_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@ namespace jit_uni_brgemm_conv_comp_pad_kernel {

#define GET_OFF(field) offsetof(jit_brgemm_conv_comp_pad_call_s, field)

jit_uni_brgemm_conv_comp_pad_kernel_t::jit_uni_brgemm_conv_comp_pad_kernel_t(
const jit_brgemm_conv_conf_t &ajcp)
template <cpu_isa_t isa>
jit_uni_brgemm_conv_comp_pad_kernel_t<isa>::
jit_uni_brgemm_conv_comp_pad_kernel_t(
const jit_brgemm_conv_conf_t &ajcp)
: jcp_(ajcp)
, inp_dsz_(jcp_.wei_dsz)
, out_dsz_(jcp_.acc_dsz)
Expand All @@ -44,20 +46,27 @@ jit_uni_brgemm_conv_comp_pad_kernel_t::jit_uni_brgemm_conv_comp_pad_kernel_t(
, inp_kd_sz_(static_cast<size_t>(jcp_.kh) * inp_kh_sz_)
, isa_max_regs(isa_num_vregs(jcp_.isa)) {}

size_t jit_uni_brgemm_conv_comp_pad_kernel_t::out_oc_offset(const int n) const {
template <cpu_isa_t isa>
size_t jit_uni_brgemm_conv_comp_pad_kernel_t<isa>::out_oc_offset(
const int n) const {
return static_cast<size_t>(out_dsz_) * n * m_block2_;
}
size_t jit_uni_brgemm_conv_comp_pad_kernel_t::inp_ic_offset(

template <cpu_isa_t isa>
size_t jit_uni_brgemm_conv_comp_pad_kernel_t<isa>::inp_ic_offset(
const int m_block, const int icb, const int m, const int n) const {
return static_cast<size_t>(inp_dsz_) * n * m_block2_ * last_ic_block_
+ ((icb * m_block) + m) * inp_ic_sz_;
}
Xbyak_aarch64::ZReg jit_uni_brgemm_conv_comp_pad_kernel_t::accum(

template <cpu_isa_t isa>
Xbyak_aarch64::ZReg jit_uni_brgemm_conv_comp_pad_kernel_t<isa>::accum(
const int n_block, const int m, const int n) const {
return Xbyak_aarch64::ZReg(m * n_block + n);
}

void jit_uni_brgemm_conv_comp_pad_kernel_t::store_accumulators(
template <cpu_isa_t isa>
void jit_uni_brgemm_conv_comp_pad_kernel_t<isa>::store_accumulators(
const int m_block, const int n_block) {
if (jcp_.src_zero_point) {
for_(int m = 0; m < m_block; m++)
Expand Down Expand Up @@ -100,7 +109,8 @@ void jit_uni_brgemm_conv_comp_pad_kernel_t::store_accumulators(
}
}

void jit_uni_brgemm_conv_comp_pad_kernel_t::zero_accumulators(
template <cpu_isa_t isa>
void jit_uni_brgemm_conv_comp_pad_kernel_t<isa>::zero_accumulators(
const int m_block, const int n_block) {
for_(int m = 0; m < m_block; m++)
for (int n = 0; n < n_block; n++) {
Expand All @@ -109,7 +119,8 @@ void jit_uni_brgemm_conv_comp_pad_kernel_t::zero_accumulators(
}
}

void jit_uni_brgemm_conv_comp_pad_kernel_t::compute(const int ic_step,
template <cpu_isa_t isa>
void jit_uni_brgemm_conv_comp_pad_kernel_t<isa>::compute(const int ic_step,
const int m_block, const int n_block, const int m_tail,
const bool is_mb_tail) {

Expand All @@ -126,7 +137,8 @@ void jit_uni_brgemm_conv_comp_pad_kernel_t::compute(const int ic_step,
}
}

void jit_uni_brgemm_conv_comp_pad_kernel_t::icb_loop(const int icb,
template <cpu_isa_t isa>
void jit_uni_brgemm_conv_comp_pad_kernel_t<isa>::icb_loop(const int icb,
const int icb_tail, const int ic_step, const int m_block,
const int mb_tail, const int n_block) {
Xbyak_aarch64::Label label_icb_loop, label_loop_end;
Expand All @@ -149,7 +161,8 @@ void jit_uni_brgemm_conv_comp_pad_kernel_t::icb_loop(const int icb,
if (icb_tail) compute(ic_step, mb_tail, n_block, icb_tail, true);
}

void jit_uni_brgemm_conv_comp_pad_kernel_t::khw_loop(const int icb,
template <cpu_isa_t isa>
void jit_uni_brgemm_conv_comp_pad_kernel_t<isa>::khw_loop(const int icb,
const int icb_tail, const int ic_step, const int m_block,
const int mb_tail, const int n_block) {
Xbyak_aarch64::Label label_kw_loop, label_kw_end, label_kh_loop,
Expand Down Expand Up @@ -181,13 +194,15 @@ void jit_uni_brgemm_conv_comp_pad_kernel_t::khw_loop(const int icb,
L_aligned(label_kh_end);
}

void jit_uni_brgemm_conv_comp_pad_kernel_t::load_params() {
template <cpu_isa_t isa>
void jit_uni_brgemm_conv_comp_pad_kernel_t<isa>::load_params() {
add_imm(reg_in, param1, GET_OFF(ptr_in), X_TMP_0);
add_imm(reg_zp_comp_out, param1, GET_OFF(ptr_zp_out), X_TMP_1);
add_imm(reg_comp_out, param1, GET_OFF(ptr_cp_out), X_TMP_2);
}

int jit_uni_brgemm_conv_comp_pad_kernel_t::compute_ic_step(
template <cpu_isa_t isa>
int jit_uni_brgemm_conv_comp_pad_kernel_t<isa>::compute_ic_step(
const int m_max_regs, const int m_block, const int n_block) const {
int best_ic_step = 1;
float best_block_eff = 0.f;
Expand Down Expand Up @@ -217,7 +232,8 @@ int jit_uni_brgemm_conv_comp_pad_kernel_t::compute_ic_step(
return best_ic_step;
}

void jit_uni_brgemm_conv_comp_pad_kernel_t::generate() {
template <cpu_isa_t isa>
void jit_uni_brgemm_conv_comp_pad_kernel_t<isa>::generate() {
preamble();

load_params();
Expand Down Expand Up @@ -280,6 +296,8 @@ void jit_uni_brgemm_conv_comp_pad_kernel_t::generate() {
postamble();
}

template struct jit_uni_brgemm_conv_comp_pad_kernel_t<sve_512>;
template struct jit_uni_brgemm_conv_comp_pad_kernel_t<sve_256>;
} // namespace jit_uni_brgemm_conv_comp_pad_kernel

} // namespace aarch64
Expand Down
6 changes: 4 additions & 2 deletions src/cpu/aarch64/jit_brgemm_conv_comp_pad_kernel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,15 @@ struct jit_brgemm_conv_comp_pad_call_s {
size_t kd_l;
};

template <cpu_isa_t isa>
struct jit_uni_brgemm_conv_comp_pad_kernel_t : public jit_generator {

DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_brgemm_conv_comp_pad_kernel_t)

using XReg = const Xbyak_aarch64::XReg;

jit_uni_brgemm_conv_comp_pad_kernel_t(const jit_brgemm_conv_conf_t &ajcp);
jit_uni_brgemm_conv_comp_pad_kernel_t<isa>(
const jit_brgemm_conv_conf_t &ajcp);

~jit_uni_brgemm_conv_comp_pad_kernel_t() = default;

Expand Down Expand Up @@ -85,7 +87,7 @@ struct jit_uni_brgemm_conv_comp_pad_kernel_t : public jit_generator {

const int last_ic_block_ = 4;
const int n_block2_ = 4;
const int m_block2_ = cpu_isa_traits<sve_512>::vlen / sizeof(int32_t);
const int m_block2_ = cpu_isa_traits<isa>::vlen / sizeof(int32_t);
const int n_max_regs_ = 4;

const Xbyak_aarch64::ZReg &vmm_tmp_1() const noexcept { return vmm_tmp; }
Expand Down
37 changes: 33 additions & 4 deletions src/cpu/aarch64/jit_brgemm_post_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,22 @@ struct jit_brgemm_kernel_diff_bias_t : public jit_generator {
}

void generate() override {
size_t simd_w_;
switch (brg_.isa_impl) {
case sve_512:
simd_w_ = cpu_isa_traits<sve_512>::vlen / sizeof(float);
break;
case sve_256:
simd_w_ = cpu_isa_traits<sve_256>::vlen / sizeof(float);
break;
default: assert(!"unsupported isa");
}
preamble();
if (simd_w_ != cpu_sveLen / sizeof(float)) {
set_preg(P_ALL_ONE.b, simd_w_ * 4, X_TMP_0, X_TMP_1);
set_preg(k_full_mask.b, simd_w_ * 4, X_TMP_0, X_TMP_1);
} else
ptrue(k_full_mask.b);

int nb = utils::div_up(brg_.load_dim, brg_.ld_block);
int nb_tail = brg_.load_dim % brg_.ld_block;
Expand All @@ -208,7 +223,6 @@ struct jit_brgemm_kernel_diff_bias_t : public jit_generator {
n_loop_tail = n_max_regs_;
}

ptrue(k_full_mask.b);
set_preg(k_tail_mask.s, nb_tail, X_TMP_0, X_TMP_1);
pfalse(P_TMP_0.b);
zip1(k_tail_mask.b, k_tail_mask.b, P_TMP_0.b);
Expand Down Expand Up @@ -263,6 +277,7 @@ struct brgemm_kernel_post_ops_t {
void *ptr_dst_scales;
};

template <cpu_isa_t isa>
struct jit_brgemm_kernel_post_ops : public jit_generator {

jit_brgemm_kernel_post_ops(const jit_brgemm_conv_conf_t &ajcp,
Expand Down Expand Up @@ -299,7 +314,7 @@ struct jit_brgemm_kernel_post_ops : public jit_generator {
save_state, reserved_eltwise_gpr, reserved_eltwise_maskr};

postops_injector_ = utils::make_unique<
injector::jit_uni_postops_injector_t<sve_512>>(
injector::jit_uni_postops_injector_t<isa>>(
this, attr.post_ops_, bsp, esp);
}

Expand Down Expand Up @@ -332,7 +347,7 @@ struct jit_brgemm_kernel_post_ops : public jit_generator {
data_type_t inp_dt_;
data_type_t out_dt_;
data_type_t bia_dt_;
std::unique_ptr<injector::jit_uni_postops_injector_t<sve_512>>
std::unique_ptr<injector::jit_uni_postops_injector_t<isa>>
postops_injector_;

const bool with_binary_non_scalar_bcast_;
Expand Down Expand Up @@ -835,7 +850,22 @@ struct jit_brgemm_kernel_post_ops : public jit_generator {
}

void generate() override {
size_t simd_w_;
switch (brg.isa_impl) {
case sve_512:
simd_w_ = cpu_isa_traits<sve_512>::vlen / sizeof(float);
break;
case sve_256:
simd_w_ = cpu_isa_traits<sve_256>::vlen / sizeof(float);
break;
default: assert(!"unsupported isa");
}
preamble();
if (simd_w_ != cpu_sveLen / sizeof(float)) {
set_preg(P_ALL_ONE.b, simd_w_ * 4, X_TMP_0, X_TMP_1);
set_preg(k_full_mask.b, simd_w_ * 4, X_TMP_0, X_TMP_1);
} else
ptrue(k_full_mask.b);

mov(x7, x0);
mov(x6, x1);
Expand All @@ -858,7 +888,6 @@ struct jit_brgemm_kernel_post_ops : public jit_generator {
int mb = brg.bcast_dim / m_block;
int mb_tail = brg.bcast_dim % m_block;

ptrue(k_full_mask.b);
set_preg(k_tail_mask.s, nb_tail, X_TMP_0, X_TMP_1);

if (brg.alpha != 0) {
Expand Down
2 changes: 2 additions & 0 deletions src/cpu/cpu_convolution_list.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,8 @@ const std::map<pk_dt_impl_key_t, std::vector<impl_list_item_t>> &impl_list_map()
CPU_INSTANCE_AARCH64_ACL(acl_depthwise_convolution_fwd_t)
CPU_INSTANCE_AARCH64_ACL(acl_indirect_gemm_convolution_fwd_t)
CPU_INSTANCE_AARCH64_ACL(acl_gemm_convolution_fwd_t<f32>)
CPU_INSTANCE_AARCH64(brgemm_1x1_convolution_fwd_t<sve_256>)
CPU_INSTANCE_AARCH64(brgemm_convolution_fwd_t<sve_256>)
CPU_INSTANCE_AARCH64(jit_sve_convolution_fwd_t<f32,f32,f32,sve_256>)
CPU_INSTANCE(gemm_convolution_fwd_t)
CPU_INSTANCE(ref_convolution_fwd_t)
Expand Down

0 comments on commit c5e4ce7

Please sign in to comment.