cpu: aarch64: batch_normalization: expand ARM SVE support in jit_uni_batch_normalization (#1918)
nikhilfujitsu committed Jun 3, 2024
1 parent 4aba780 commit 0a1d0fb
Showing 2 changed files with 82 additions and 49 deletions.
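The core of the change, visible throughout the diff below, is that the former sve_512-only helpers (prepare_tail_mask_sve_512, fwd_process_relu_sve_512, ...) drop the suffix, the fixed-width ldr/str/stnt1w accesses become predicated ld1w/st1w, and the isa checks are widened to accept sve_256 as well, so one JIT kernel serves both vector lengths. A minimal sketch of that predicated load/store-with-tail-mask pattern, written with ACLE intrinsics rather than the Xbyak_aarch64 calls the kernel actually emits (the function, its name, and the alpha scaling are illustrative only, not part of the commit):

// Hedged sketch: the ld1w/st1w + governing-predicate idea, not the patch's JIT code.
#include <arm_sve.h>

void scale_channels(float *dst, const float *src, float alpha, int C) {
    for (int c = 0; c < C; c += (int)svcntw()) {   // svcntw() = FP32 lanes per SVE vector
        svbool_t pg = svwhilelt_b32(c, C);         // tail-safe predicate, analogous to set_preg(ktail_mask.s, tail, ...)
        svfloat32_t v = svld1_f32(pg, src + c);    // predicated load (ld1w)
        v = svmul_n_f32_x(pg, v, alpha);           // scale the active lanes; inactive lanes are masked out by the store
        svst1_f32(pg, dst + c, v);                 // predicated store (st1w)
    }
}

Because every access is governed by a predicate, the same loop is correct at any hardware vector length, which is what lets the kernel below share one code path for sve_256 and sve_512.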
127 changes: 79 additions & 48 deletions src/cpu/aarch64/jit_uni_batch_normalization.cpp
@@ -1,6 +1,6 @@
/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
* Copyright 2020-2022 FUJITSU LIMITED
* Copyright 2020-2024 FUJITSU LIMITED
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -430,7 +430,7 @@ struct jit_bnorm_t : public jit_generator {
#undef STR_PARAM
}

void prepare_tail_mask_sve_512() {
void prepare_tail_mask() {
if (!is_c_padded()) return;
const int tail = pd_->C() % (int)(vlen / sizeof(float));
set_preg(ktail_mask.s, tail, X_TMP_0, X_TMP_1);
@@ -447,18 +447,18 @@ struct jit_bnorm_t : public jit_generator {
if (with_relu) uni_clear(vzero);
}

void fwd_process_relu_sve_512(ZRegS vdst, int offt = 0) {
void fwd_process_relu(ZRegS vdst, int offt = 0) {
const int bits = bit_shift();
const int offset = offt / (1 << bits);
XReg r = jbp_->is_nspc_ ? reg_soff_nspc : reg_soff;
ZRegS zzero = ZRegS(vzero.getIdx());

assert(isa == sve_512);
assert(isa == sve_256 || isa == sve_512);

assert(bits < 64);
lsr(r, r, bits);
fcmlt(kstore_mask.s, P_ALL_ONE / T_z, zzero, vdst);
sub(X_DEFAULT_ADDR, sp, 8); // sve_512
sub(X_DEFAULT_ADDR, sp, 8);
uzp1(p_tmp0.b, kstore_mask.b, kstore_mask.b);
uzp1(p_tmp0.b, p_tmp0.b, p_tmp0.b);
str(p_tmp0, ptr(X_DEFAULT_ADDR));
@@ -472,11 +472,11 @@ struct jit_bnorm_t : public jit_generator {
lsl(r, r, bit_shift());
}

void fwd_process_relu_alpha_sve_512(TRegS vmm_dst) {
void fwd_process_relu_alpha(TRegS vmm_dst) {
ZRegS dst = ZRegS(vmm_dst.getIdx());
ZRegS z_tmp0 = ZRegS(t_tmp0.getIdx());

assert(isa == sve_512);
assert(isa == sve_256 || isa == sve_512);

add_imm(X_DEFAULT_ADDR, sp, (int)stack_off_relu_alpha, X_TMP_0);
ld1rw(ZRegS(t_tmp0.getIdx()), P_ALL_ONE / T_z, ptr(X_DEFAULT_ADDR));
@@ -486,7 +486,7 @@ struct jit_bnorm_t : public jit_generator {
sel(dst, kstore_mask, dst, z_tmp0);
}

void bwd_process_relu_sve_512(ZRegS vdiff_dst, int offt = 0) {
void bwd_process_relu(ZRegS vdiff_dst, int offt = 0) {
const int bits = bit_shift();
const int offset = offt / (1 << bits);
XReg r = jbp_->is_nspc_ ? reg_soff_nspc : reg_soff;
@@ -498,7 +498,7 @@ struct jit_bnorm_t : public jit_generator {
if (offset) add_imm(X_DEFAULT_ADDR, X_DEFAULT_ADDR, offset, X_TMP_0);

ldrh(W_TMP_0, ptr(X_DEFAULT_ADDR));
sub(X_DEFAULT_ADDR, sp, 8); // sve_512
sub(X_DEFAULT_ADDR, sp, 8);
str(X_TMP_0, ptr(X_DEFAULT_ADDR));
ldr(kstore_mask, ptr(X_DEFAULT_ADDR));
zip1(kstore_mask.b, kstore_mask.b, kstore_mask.b);
@@ -512,7 +512,9 @@ struct jit_bnorm_t : public jit_generator {
ldr(QReg(IDX(v)), ptr(x));
}

void uni_load_spat_data(const ZReg &z, const XReg &x) { ldr(z, ptr(x)); }
void uni_load_spat_data(const ZReg &z, const XReg &x) {
ld1w(z.s, P_ALL_ONE / T_z, ptr(x));
}

void uni_store_spat_data(
const XReg &x, const VReg &v, bool is_nt_store = false) {
@@ -522,7 +524,7 @@ struct jit_bnorm_t : public jit_generator {

void uni_store_spat_data(
const XReg &x, const ZReg &z, bool is_nt_store = false) {
stnt1w(z.s, P_ALL_ONE, ptr(x));
stnt1w(z.s, P_ALL_ONE / T_z, ptr(x));
}

void jump_check(const Label &l_no_mask) {
@@ -541,7 +543,8 @@ struct jit_bnorm_t : public jit_generator {

if (is_c_padded()) {
jump_check(l_no_mask);
if (isa == sve_512) ld1w(ZRegS(IDX(t)), ktail_mask / T_z, ptr(x));
assert(isa == sve_256 || isa == sve_512);
ld1w(ZRegS(IDX(t)), ktail_mask / T_z, ptr(x));
b(l_ret);
}
L(l_no_mask);
@@ -554,7 +557,8 @@ struct jit_bnorm_t : public jit_generator {

if (is_c_padded()) {
jump_check(l_no_mask);
if (isa == sve_512) st1w(ZRegS(IDX(t)), ktail_mask / T_z, ptr(x));
assert(isa == sve_256 || isa == sve_512);
st1w(ZRegS(IDX(t)), ktail_mask / T_z, ptr(x));
b(l_ret);
}
L(l_no_mask);
@@ -589,7 +593,9 @@ struct jit_bnorm_t : public jit_generator {

void uni_ldr(const VReg &v, const XReg &x) { ldr(QReg(IDX(v)), ptr(x)); }

void uni_ldr(const ZReg &z, const XReg &x) { ldr(z, ptr(x)); }
void uni_ldr(const ZReg &z, const XReg &x) {
ld1w(z.s, P_ALL_ONE / T_z, ptr(x));
}

void uni_str(const VReg &v, const XReg &base,
const XReg &off = XReg(DUMMY_IDX), const int disp = 0) {
@@ -615,7 +621,7 @@ struct jit_bnorm_t : public jit_generator {

void uni_str(const ZReg &z, const XReg &base,
const XReg &off = XReg(DUMMY_IDX), const int disp = 0) {
str(z, ptr(xreg_addr(base, off, disp)));
st1w(z.s, P_ALL_ONE / T_z, ptr(xreg_addr(base, off, disp)));
}

void uni_stnt1w(const ZReg &z, const XReg &base,
@@ -885,12 +891,12 @@ struct jit_bnorm_t : public jit_generator {

if (with_relu_inf_only) { // --attr=post_ops='relu'
if (pd_->alpha() != 0.f)
fwd_process_relu_alpha_sve_512(vdata);
fwd_process_relu_alpha(vdata);
else
uni_fmaxnm(vdata, vdata, vzero.s);
} else if (with_relu) { // --flags=R
assert(isa == sve_512);
fwd_process_relu_sve_512(
assert(isa == sve_256 || isa == sve_512);
fwd_process_relu(
ZRegS(vdata.getIdx()), idx * vlen_spat_data_);
}
add(X_DEFAULT_ADDR, reg_dst, reg_soff_nspc);
@@ -1004,7 +1010,8 @@ struct jit_bnorm_t : public jit_generator {
L(zero_rbuf);
{
uni_str(TReg(0), reg_rbuf1, reg_coff);
add_imm(reg_coff, reg_coff, isa == sve_512 ? vlen : vlen / 2,
add_imm(reg_coff, reg_coff,
(isa == sve_256 || isa == sve_512) ? vlen : vlen / 2,
X_TMP_0);
cmp(reg_coff, reg_coff_max);
b(NE, zero_rbuf);
@@ -1080,13 +1087,13 @@ struct jit_bnorm_t : public jit_generator {
subs_imm(reg_ctr, reg_ctr, 1, X_TMP_0);
b(NE, mean_reduction_thrs);
}
if (isa == sve_512)
if (isa == sve_256 || isa == sve_512)
fdiv(ZRegS(1), P_ALL_ONE / T_m, ZRegS(vchan_size.getIdx()));
else
fdiv(VReg4S(1), VReg4S(1), VReg4S(vchan_size.getIdx()));
uni_store_maybe_tail(mean_ptr(), TReg(1));

if (isa == sve_512)
if (isa == sve_256 || isa == sve_512)
add_imm(reg_coff, reg_coff, vlen, X_TMP_0);
else
add_imm(reg_coff, reg_coff, vlen / 2, X_TMP_0);
@@ -1163,13 +1170,13 @@ struct jit_bnorm_t : public jit_generator {
subs(reg_ctr, reg_ctr, 1);
b(NE, var_reduction_thrs);
}
if (isa == sve_512)
if (isa == sve_256 || isa == sve_512)
fdiv(ZRegS(1), P_ALL_ONE / T_m, ZRegS(vchan_size.getIdx()));
else {
fdiv(VReg4S(1), VReg4S(1), VReg4S(IDX(vchan_size)));
}
uni_store_maybe_tail(var_ptr(), TReg(1));
if (isa == sve_512)
if (isa == sve_256 || isa == sve_512)
add_imm(reg_coff, reg_coff, vlen, X_TMP_0);
else
add_imm(reg_coff, reg_coff, vlen / 2, X_TMP_0);
@@ -1224,12 +1231,12 @@ struct jit_bnorm_t : public jit_generator {
}
if (with_relu_inf_only) { // --attr=post_ops='relu'
if (pd_->alpha() != 0.f) {
fwd_process_relu_alpha_sve_512(v);
fwd_process_relu_alpha(v);
} else
uni_fmaxnm(v, v, vzero.s);
} else if (with_relu) { // --flags=R
assert(isa == sve_512);
fwd_process_relu_sve_512(ZRegS(v.getIdx()), offt);
assert(isa == sve_256 || isa == sve_512);
fwd_process_relu(ZRegS(v.getIdx()), offt);
}
add(X_DEFAULT_ADDR, reg_dst, reg_soff);
if (offt)
@@ -1405,8 +1412,8 @@ struct jit_bnorm_t : public jit_generator {
if (offt) add_imm(X_TMP_0, X_TMP_0, offt, X_TMP_1);
uni_load_spat_data(t2, X_TMP_0);
if (with_relu) {
assert(isa == sve_512);
bwd_process_relu_sve_512(ZRegS(t2.getIdx()), offt);
assert(isa == sve_256 || isa == sve_512);
bwd_process_relu(ZRegS(t2.getIdx()), offt);
}
fsub(t3.s, vmean.s, t1.s);
if (isa == asimd) {
@@ -1490,8 +1497,8 @@ struct jit_bnorm_t : public jit_generator {
uni_load_spat_data(vdiff_dst, X_TMP_3);

if (with_relu) {
assert(isa == sve_512);
bwd_process_relu_sve_512(ZRegS(vdiff_dst.getIdx()), offt);
assert(isa == sve_256 || isa == sve_512);
bwd_process_relu(ZRegS(vdiff_dst.getIdx()), offt);
}

fsub(vsrc.s, vsrc.s, vmean.s);
@@ -1603,8 +1610,8 @@ struct jit_bnorm_t : public jit_generator {
add_imm(X_DEFAULT_ADDR, X_DEFAULT_ADDR, offt, X_TMP_0);
uni_load_spat_data(TReg(v.getIdx()), X_DEFAULT_ADDR);
if (with_relu) {
assert(isa == sve_512);
bwd_process_relu_sve_512(ZRegS(v.getIdx()), offt);
assert(isa == sve_256 || isa == sve_512);
bwd_process_relu(ZRegS(v.getIdx()), offt);
}
if (!pd_->use_global_stats()) {
fsub(v, v, vdiff_beta.s);
@@ -1723,9 +1730,8 @@ struct jit_bnorm_t : public jit_generator {
TReg(vdiff_data.getIdx()), X_DEFAULT_ADDR);

if (with_relu) {
assert(isa == sve_512);
bwd_process_relu_sve_512(
ZRegS(vdiff_data.getIdx()), offt);
assert(isa == sve_256 || isa == sve_512);
bwd_process_relu(ZRegS(vdiff_data.getIdx()), offt);
}

if (!pd_->use_global_stats()) {
@@ -1841,7 +1847,7 @@ struct jit_bnorm_t : public jit_generator {
uni_str(TReg(0), X_TMP_0);
add(X_TMP_0, reg_rbuf2, reg_coff);
uni_str(TReg(0), X_TMP_0);
if (isa == sve_512)
if (isa == sve_256 || isa == sve_512)
add_imm(reg_coff, reg_coff, vlen, X_TMP_0);
else
add_imm(reg_coff, reg_coff, vlen / 2, X_TMP_0);
@@ -1852,7 +1858,7 @@ struct jit_bnorm_t : public jit_generator {
LDR_ASSERT(reg_src, sp, (int)stack_off_src);
LDR_ASSERT(reg_diff_dst, sp, (int)stack_off_diff_dst);
if (with_relu) {
assert(isa == sve_512);
assert(isa == sve_256 || isa == sve_512);
LDR_ASSERT(reg_ws, sp, (int)stack_off_ws);
}

@@ -1935,7 +1941,8 @@ struct jit_bnorm_t : public jit_generator {
fmul(TRegS(0), TRegS(0), vsqrtvar.s);
uni_store_maybe_tail(diff_gamma_ptr(), TReg(0));
uni_store_maybe_tail(diff_beta_ptr(), TReg(1));
add_imm(reg_coff, reg_coff, isa == sve_512 ? vlen : vlen / 2,
add_imm(reg_coff, reg_coff,
isa == sve_256 || isa == sve_512 ? vlen : vlen / 2,
X_TMP_0);
cmp(reg_coff, reg_coff_max);
b(NE, sh_reduction_channels);
@@ -1946,7 +1953,7 @@ struct jit_bnorm_t : public jit_generator {

LDR_ASSERT(reg_diff_src, sp, (int)stack_off_diff_src);
if (with_relu) {
assert(isa == sve_512);
assert(isa == sve_256 || isa == sve_512);
LDR_ASSERT(reg_ws, sp, (int)stack_off_ws);
}

@@ -2003,20 +2010,31 @@ struct jit_bnorm_t : public jit_generator {

jit_bnorm_t(const batch_normalization_pd_t *pd, const jit_bnorm_conf_t *jbp)
: pd_(pd), jbp_(jbp) {
static_assert(isa == asimd || isa == sve_512, "unsupported isa");
static_assert(isa == asimd || isa == sve_256 || isa == sve_512,
"unsupported isa");

is_bf16_ = pd_->src_md()->data_type == data_type::bf16;
is_f16_ = pd_->src_md()->data_type == data_type::f16;
vlen_spat_data_ = vlen / (1 + is_xf16()); // 32B of xF16 -> 64B of FP32

unroll_blocks = isa == sve_512 && !jbp_->is_spatial_thr_ ? 4 : 1;
unroll_regs = isa == sve_512 && !jbp_->is_spatial_thr_ ? 4 : 1;
unroll_blocks
= (isa == sve_256 || isa == sve_512) && !jbp_->is_spatial_thr_
? 4
: 1;
unroll_regs
= (isa == sve_256 || isa == sve_512) && !jbp_->is_spatial_thr_
? 4
: 1;
}

void generate() override {
preamble();

if (isa == sve_512) { prepare_tail_mask_sve_512(); }
size_t simd_w_ = cpu_isa_traits<isa>::vlen / sizeof(float);
if (simd_w_ != cpu_sveLen / sizeof(float))
set_preg(P_ALL_ONE.s, simd_w_, X_TMP_0, X_TMP_1);

if (isa == sve_256 || isa == sve_512) { prepare_tail_mask(); }

compute_static_strides();

@@ -2281,21 +2299,26 @@ status_t jit_uni_batch_normalization_fwd_t<isa>::pd_t::init(engine_t *engine) {
if (!src_d.matches_one_of_tag(
nCw16c, nChw16c, nCdhw16c, nc, nwc, nhwc, ndhwc))
return status::unimplemented;
} else if (isa == sve_256) {
if (!src_d.matches_one_of_tag(
nCw8c, nChw8c, nCdhw8c, nc, nwc, nhwc, ndhwc))
return status::unimplemented;
} else {
if (!src_d.matches_one_of_tag(nCw8c, nChw8c, nCdhw8c))
return status::unimplemented;
}

if (is_fwd() ? with_relu_post_op(is_training()) || fuse_norm_relu()
: fuse_norm_relu())
if (isa != sve_512) return status::unimplemented;
if (isa != sve_512) return status::unimplemented; // TODO

if (is_training() && fuse_norm_relu()) {
if (isa < sve_512) return status::unimplemented;
if (isa != sve_256 && isa != sve_512) return status::unimplemented;
init_default_ws(1);
}

if (memory_desc_wrapper(src_md()).padded_dims()[1] != C() && isa < sve_512)
if (memory_desc_wrapper(src_md()).padded_dims()[1] != C() && isa != sve_256
&& isa != sve_512)
return status::unimplemented;

// Only IC % 16 == 0 is supported for now
@@ -2386,6 +2409,11 @@ status_t jit_uni_batch_normalization_bwd_t<isa>::pd_t::init(engine_t *engine) {
nc, nwc, nCw16c, nhwc, nChw16c, ndhwc, nCdhw16c);
diff_src_tag = diff_src_d.matches_one_of_tag(
nc, nwc, nCw16c, nhwc, nChw16c, ndhwc, nCdhw16c);
} else if (isa == sve_256) {
src_tag = src_d.matches_one_of_tag(
nc, nwc, nCw8c, nhwc, nChw8c, ndhwc, nCdhw8c);
diff_src_tag = diff_src_d.matches_one_of_tag(
nc, nwc, nCw8c, nhwc, nChw8c, ndhwc, nCdhw8c);
} else {
src_tag = src_d.matches_one_of_tag(nCw8c, nChw8c, nCdhw8c);
diff_src_tag = diff_src_d.matches_one_of_tag(nCw8c, nChw8c, nCdhw8c);
@@ -2394,7 +2422,8 @@ status_t jit_uni_batch_normalization_bwd_t<isa>::pd_t::init(engine_t *engine) {
&& src_tag == diff_src_tag);
if (!ok) return status::unimplemented;

if (memory_desc_wrapper(src_md()).padded_dims()[1] != C() && isa < sve_512)
if (memory_desc_wrapper(src_md()).padded_dims()[1] != C() && isa != sve_256
&& isa != sve_512)
return status::unimplemented;

// Only IC % 16 == 0 is supported for now
Expand All @@ -2404,7 +2433,7 @@ status_t jit_uni_batch_normalization_bwd_t<isa>::pd_t::init(engine_t *engine) {
}

if (fuse_norm_relu()) {
if (isa < sve_512) return status::unimplemented;
if (isa != sve_256 && isa != sve_512) return status::unimplemented;
init_default_ws(1);
if (!compare_ws(hint_fwd_pd_)) return status::unimplemented;
}
Expand Down Expand Up @@ -2465,6 +2494,8 @@ jit_uni_batch_normalization_bwd_t<isa>::~jit_uni_batch_normalization_bwd_t() {
/* struct instantiation */
template struct jit_uni_batch_normalization_fwd_t<asimd>;
template struct jit_uni_batch_normalization_bwd_t<asimd>;
template struct jit_uni_batch_normalization_fwd_t<sve_256>;
template struct jit_uni_batch_normalization_bwd_t<sve_256>;
template struct jit_uni_batch_normalization_fwd_t<sve_512>;
template struct jit_uni_batch_normalization_bwd_t<sve_512>;
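For the new sve_256 instantiations, the blocked formats accepted above switch from 16-channel to 8-channel blocking (nCw8c, nChw8c, nCdhw8c), matching the 8 FP32 lanes of a 256-bit vector. A hedged sketch of how such an nChw8c offset is conventionally computed (the helper name and signature are mine, not oneDNN API):

#include <cstddef>

// 8-channel block = 256-bit SVE vector / sizeof(float) = 8 FP32 lanes.
constexpr std::size_t blk = 8;

// Element offset of (n, c, h, w) in an nChw8c-blocked tensor with padded channels.
std::size_t nChw8c_off(std::size_t n, std::size_t c, std::size_t h, std::size_t w,
                       std::size_t C, std::size_t H, std::size_t W) {
    const std::size_t Cb = (C + blk - 1) / blk; // number of channel blocks
    return (((n * Cb + c / blk) * H + h) * W + w) * blk + c % blk;
}

Each SVE vector in the sve_256 path then naturally covers one channel block of 8 contiguous floats, the sve_512 path keeps its 16-channel nChw16c layouts, and the plain nc/nwc/nhwc/ndhwc tags cover the non-blocked case.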
