x64: brgconv: remove unnecessary batchsize loops
tczeszun authored and tprimak committed Mar 25, 2024
1 parent fbe5b97 commit 1eab005
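
The bookkeeping removed below always collapsed to a single batch size: batchsizes held exactly one non-negative entry (at index jcp_.max_batch), bs_c ended up as 1, and first_bs as jcp_.max_batch, so every loop guarded by batchsizes[bs] == -1 ran its body exactly once. A minimal standalone sketch of that invariant, with a hypothetical max_batch value standing in for jcp_.max_batch:

    #include <cassert>
    #include <vector>

    int main() {
        const int max_batch = 8; // hypothetical stand-in for jcp_.max_batch

        // Mirror of the removed initialization in pd_t::init().
        std::vector<int> batchsizes(max_batch + 1, -1);
        int first_bs = 0, bs_c = 0;
        batchsizes[max_batch] = bs_c; // the only entry ever set
        first_bs = max_batch;
        bs_c++;

        // Each removed `for (bs = 0; bs <= max_batch; bs++)` loop skipped
        // entries with batchsizes[bs] == -1, so it executed exactly once,
        // at bs == max_batch -- the value the new code now uses directly.
        int iterations = 0;
        for (int bs = 0; bs <= max_batch; bs++) {
            if (batchsizes[bs] == -1) continue;
            assert(bs == max_batch);
            iterations++;
        }
        assert(iterations == 1 && bs_c == 1 && first_bs == max_batch);
        return 0;
    }
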
Showing 2 changed files with 87 additions and 112 deletions.
193 changes: 84 additions & 109 deletions src/cpu/x64/jit_brgemm_conv_bwd_strided.cpp
@@ -130,18 +130,7 @@ status_t brgemm_convolution_bwd_strided_t<isa, is_deconv>::pd_t::init(

const auto adj_M = nstl::max(jcp_.M, jcp_.M_tail);

batchsizes.resize(jcp_.max_batch + 1);
for (int i = 0; i <= jcp_.max_batch; i++)
batchsizes[i] = -1;

first_bs = 0;
bs_c = 0;

batchsizes[jcp_.max_batch] = bs_c;
first_bs = jcp_.max_batch;
bs_c++;

brgs_sz_ = bs_c * adj_M * 2 * 2 * 2;
brgs_sz_ = adj_M * 2 * 2 * 2;
brgs_ = std::make_shared<brgemm_containers::brgemm_desc_container_t>();
brgs_->resize(brgs_sz_);

@@ -159,86 +148,79 @@ status_t brgemm_convolution_bwd_strided_t<isa, is_deconv>::pd_t::init(
if (one_of(jcp_.exec_type, exec_trans, exec_vpad) && vM != jcp_.M
&& vM != jcp_.M_tail)
continue;
for (int bs = 0; bs <= jcp_.max_batch; bs++) {
if (batchsizes[bs] == -1) continue;
for_(int i_init = 0; i_init < 2; i_init++)
for_(int i_N = 0; i_N < 2; i_N++)
for (int i_K = 0; i_K < 2; i_K++) {
auto vbeta = (i_init) ? 0 : beta;
auto vN = (i_N) ? jcp_.N_tail : jcp_.N;
auto vK = (i_K) ? jcp_.K_tail : jcp_.K;
auto vbrgM = jcp_.use_M_mask
? (vM == jcp_.M ? jcp_.brgM : jcp_.brgM_tail)
: vM;
auto brg_idx = get_brg_idx(bs, i, i_init, i_N, i_K);
// if brgemm_t already created then skip this iteration
if ((*brgs_)[brg_idx] != nullptr) continue;
brgemm_t brg;
if (vN == 0 || vK == 0) continue;
brgemm_strides_t brg_strides;
brg_strides.stride_a = jcp_.brg_stride_a;
brg_strides.stride_b = jcp_.brg_stride_b;
brg.req_cal_comp_pads = jcp_.req_brg_comp_pad;
brg.req_comp_pads_with_bcast
= jcp_.req_cal_comp_pad && jcp_.exec_type == exec_trans;
const auto strides_ptr = (jcp_.brg_type == brgemm_strd)
? &brg_strides
: nullptr;
CHECK(brgemm_desc_init(&brg, isa, jcp_.brg_type, diff_dst_type,
wei_type, false, false, brgemm_row_major, alpha, vbeta,
jcp_.LDA, jcp_.LDB, jcp_.LDC, vbrgM, vN, vK,
strides_ptr));

brgemm_attr_t brgattr;
brgattr.use_uker = jcp_.use_uker;
brgattr.use_interleave_stores = jcp_.use_interleave_stores;
brgattr.hint_prefetching = jcp_.hint_prefetching;
brgattr.max_bs = bs;
brgattr.hint_innermost_loop = jcp_.brgemm_bd_loop_innermost
? brgemm_bd_loop_innermost
: brgemm_ld_loop_innermost;
if (jcp_.amx_tile_load_xx) {
// assuming 2x2 decomposition in amx brgemm kernel
// and overlap of input by kw
const auto bd_blocking = 2 * jcp_.amx_h;
const auto ld_blocking = 2 * 16;
brgattr.hint_expected_A_size = bd_blocking * jcp_.K
* jcp_.kd_block * jcp_.kh_block;
brgattr.hint_expected_B_size = ld_blocking * jcp_.K
* jcp_.kd_block * jcp_.kh_block * jcp_.kw_block;
brgattr.hint_expected_C_size = bd_blocking * ld_blocking;
} else {
brgattr.hint_expected_A_size = 0;
brgattr.hint_expected_B_size = 0;
brgattr.hint_expected_C_size = 0;
}
for_(int i_init = 0; i_init < 2; i_init++)
for_(int i_N = 0; i_N < 2; i_N++)
for (int i_K = 0; i_K < 2; i_K++) {
auto vbeta = (i_init) ? 0 : beta;
auto vN = (i_N) ? jcp_.N_tail : jcp_.N;
auto vK = (i_K) ? jcp_.K_tail : jcp_.K;
auto vbrgM = jcp_.use_M_mask
? (vM == jcp_.M ? jcp_.brgM : jcp_.brgM_tail)
: vM;
auto brg_idx = get_brg_idx(jcp_.max_batch, i, i_init, i_N, i_K);
// if brgemm_t already created then skip this iteration
if ((*brgs_)[brg_idx] != nullptr) continue;
brgemm_t brg;
if (vN == 0 || vK == 0) continue;
brgemm_strides_t brg_strides;
brg_strides.stride_a = jcp_.brg_stride_a;
brg_strides.stride_b = jcp_.brg_stride_b;
brg.req_cal_comp_pads = jcp_.req_brg_comp_pad;
brg.req_comp_pads_with_bcast
= jcp_.req_cal_comp_pad && jcp_.exec_type == exec_trans;
const auto strides_ptr
= (jcp_.brg_type == brgemm_strd) ? &brg_strides : nullptr;
CHECK(brgemm_desc_init(&brg, isa, jcp_.brg_type, diff_dst_type,
wei_type, false, false, brgemm_row_major, alpha, vbeta,
jcp_.LDA, jcp_.LDB, jcp_.LDC, vbrgM, vN, vK, strides_ptr));

brgemm_attr_t brgattr;
brgattr.use_uker = jcp_.use_uker;
brgattr.use_interleave_stores = jcp_.use_interleave_stores;
brgattr.hint_prefetching = jcp_.hint_prefetching;
brgattr.max_bs = jcp_.max_batch;
brgattr.hint_innermost_loop = jcp_.brgemm_bd_loop_innermost
? brgemm_bd_loop_innermost
: brgemm_ld_loop_innermost;
if (jcp_.amx_tile_load_xx) {
// assuming 2x2 decomposition in amx brgemm kernel
// and overlap of input by kw
const auto bd_blocking = 2 * jcp_.amx_h;
const auto ld_blocking = 2 * 16;
brgattr.hint_expected_A_size
= bd_blocking * jcp_.K * jcp_.kd_block * jcp_.kh_block;
brgattr.hint_expected_B_size = ld_blocking * jcp_.K
* jcp_.kd_block * jcp_.kh_block * jcp_.kw_block;
brgattr.hint_expected_C_size = bd_blocking * ld_blocking;
} else {
brgattr.hint_expected_A_size = 0;
brgattr.hint_expected_B_size = 0;
brgattr.hint_expected_C_size = 0;
}

brgattr.wary_tail_read = false;
// use_M_mask is always 0 for brgemm_convolution_bwd_strided_t
brgattr.bd_mask = nullptr;
brgattr.bd_mask_level = jcp_.use_M_mask;

if (is_amx) {
brgattr.max_top_vpad = 0;
brgattr.max_bottom_vpad = 0;
} else {
brgattr.max_top_vpad = jcp_.max_vpad;
brgattr.max_bottom_vpad = jcp_.max_vpad;
}
brgattr.generate_skip_accumulation = true;
CHECK(brgemm_desc_set_attr(&brg, brgattr));

auto LDD = jcp_.stride_w * jcp_.ic_without_padding;
brg.with_sum = with_sum;
brg.with_weights_scale_adjust
= jcp_.scale_adjust_factor != 1.0f;
CHECK(brgemm_desc_set_postops(
&brg, attr(), &diff_src_md_, LDD, jcp_.bia_dt));
jcp_.amx_buf_size_per_thread
= nstl::max(brg.get_wsp_buffer_size(),
jcp_.amx_buf_size_per_thread);
brgs_->insert(brg_idx, brg);
brgattr.wary_tail_read = false;
// use_M_mask is always 0 for brgemm_convolution_bwd_strided_t
brgattr.bd_mask = nullptr;
brgattr.bd_mask_level = jcp_.use_M_mask;

if (is_amx) {
brgattr.max_top_vpad = 0;
brgattr.max_bottom_vpad = 0;
} else {
brgattr.max_top_vpad = jcp_.max_vpad;
brgattr.max_bottom_vpad = jcp_.max_vpad;
}
brgattr.generate_skip_accumulation = true;
CHECK(brgemm_desc_set_attr(&brg, brgattr));

auto LDD = jcp_.stride_w * jcp_.ic_without_padding;
brg.with_sum = with_sum;
brg.with_weights_scale_adjust = jcp_.scale_adjust_factor != 1.0f;
CHECK(brgemm_desc_set_postops(
&brg, attr(), &diff_src_md_, LDD, jcp_.bia_dt));
jcp_.amx_buf_size_per_thread = nstl::max(
brg.get_wsp_buffer_size(), jcp_.amx_buf_size_per_thread);
brgs_->insert(brg_idx, brg);
}
}

@@ -410,17 +392,13 @@ void brgemm_convolution_bwd_strided_t<isa, is_deconv>::create_kernels() {
: 0;
int i_init_end = 2;

for (int bs = 0; bs <= jcp.max_batch; bs++) {
if (_pd->batchsizes[bs] == -1) continue;

for_(int i_N = N_begin; i_N < N_end; i_N++)
for_(int i_M = M_begin; i_M < M_end; i_M++)
for_(int i_init = i_init_begin; i_init < i_init_end; i_init++)
for (int i_K = K_begin; i_K < K_end; i_K++) {
auto M = (i_M) ? jcp.M_tail : jcp.M;
if (M <= 0) continue;
add_brg_kernel(bs, M, i_N, i_K, i_init);
}
for_(int i_N = N_begin; i_N < N_end; i_N++)
for_(int i_M = M_begin; i_M < M_end; i_M++)
for_(int i_init = i_init_begin; i_init < i_init_end; i_init++)
for (int i_K = K_begin; i_K < K_end; i_K++) {
auto M = (i_M) ? jcp.M_tail : jcp.M;
if (M <= 0) continue;
add_brg_kernel(jcp.max_batch, M, i_N, i_K, i_init);
}

if (jcp.exec_type == exec_base) {
@@ -444,14 +422,11 @@
for (int kw = kw_s; kw < kw_f; kw++) {
get_iw_range(iw_str, iw, kw, iw_s, M_without_overflow);
if (M_without_overflow <= 0) continue;
for (int bs = 0; bs <= jcp.max_batch; bs++) {
if (_pd->batchsizes[bs] == -1) continue;
for_(int i_init = 0; i_init < 2; i_init++)
for_(int i_N = 0; i_N < 2; i_N++)
for (int i_K = 0; i_K < 2; i_K++) {
add_brg_kernel(
bs, M_without_overflow, i_N, i_K, i_init);
}
for_(int i_init = 0; i_init < 2; i_init++)
for_(int i_N = 0; i_N < 2; i_N++)
for (int i_K = 0; i_K < 2; i_K++) {
add_brg_kernel(jcp.max_batch, M_without_overflow, i_N, i_K,
i_init);
}

bool is_iw_tail = (jcp.iw - iw < jcp.iw_block);
6 changes: 3 additions & 3 deletions src/cpu/x64/jit_brgemm_conv_bwd_strided.hpp
@@ -59,11 +59,11 @@ struct brgemm_convolution_bwd_strided_t : public primitive_t {
std::shared_ptr<brgemm_containers::brgemm_desc_container_t> brgs_;

jit_brgemm_conv_conf_t jcp_;
// batch sizes info for unrolled kernels
int bs_c, first_bs;
std::vector<int> batchsizes;
// batch size info
const int first_bs = 0;
int get_brg_idx(int bs, int m, bool do_initialization, bool is_N_tail,
bool is_K_tail) const {
const int bs_c = 1;
auto bs_idx = 0;
return (((m * bs_c + bs_idx) * 2
+ static_cast<int>(do_initialization))
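
The return expression of get_brg_idx is cut off in the hunk above. Below is a minimal sketch of the simplified indexing, assuming the truncated tail keeps the usual nested * 2 pattern that makes the index space match brgs_sz_ = adj_M * 2 * 2 * 2 (the helper name and example value are made up for illustration):

    #include <cassert>
    #include <vector>

    // Sketch only: with bs_c == 1 and bs_idx == 0 the batch-size dimension
    // drops out, so the slot depends only on m, the init flag and the
    // N/K tail flags.
    static int get_brg_idx_sketch(
            int m, bool do_initialization, bool is_N_tail, bool is_K_tail) {
        return ((m * 2 + static_cast<int>(do_initialization)) * 2
                       + static_cast<int>(is_N_tail))
                * 2
                + static_cast<int>(is_K_tail);
    }

    int main() {
        const int adj_M = 3; // example value; brgs_sz_ = adj_M * 2 * 2 * 2
        // Every (m, init, N_tail, K_tail) combination maps to a distinct
        // slot inside [0, brgs_sz_).
        std::vector<bool> seen(adj_M * 2 * 2 * 2, false);
        for (int m = 0; m < adj_M; m++)
            for (int i = 0; i < 2; i++)
                for (int n = 0; n < 2; n++)
                    for (int k = 0; k < 2; k++) {
                        const int idx = get_brg_idx_sketch(m, i, n, k);
                        assert(idx >= 0 && idx < adj_M * 2 * 2 * 2);
                        assert(!seen[idx]);
                        seen[idx] = true;
                    }
        return 0;
    }
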
