Skip to content

Commit

Permalink
cpu: update conv implementations with nthr_ member
Browse files Browse the repository at this point in the history
  • Loading branch information
dzarukin committed Dec 20, 2021
1 parent 4c3e771 commit 2458105
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 20 deletions.
2 changes: 1 addition & 1 deletion src/cpu/x64/jit_avx2_1x1_convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ void jit_avx2_1x1_convolution_fwd_t::execute_forward(
bias = padded_bias;
}

parallel(0, [&](const int ithr, const int nthr) {
parallel(jcp.nthr, [&](const int ithr, const int nthr) {
execute_forward_thr(ithr, nthr, src, weights, bias, weights_dw, bias_dw,
dst, scratchpad, post_ops_binary_rhs_arg_vec.data(),
post_ops_binary_rhs_arg_vec_dw.data());
Expand Down
5 changes: 3 additions & 2 deletions src/cpu/x64/jit_avx512_core_f32_wino_conv_4x3.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -847,7 +847,8 @@ void jit_avx512_core_f32_wino_conv_4x3_bwd_weights_t::
float I[alpha][alpha][simd_w];
float T[alpha][alpha][simd_w];

PRAGMA_OMP(parallel firstprivate(first_tblk, trans_ker_p, I, T))
PRAGMA_OMP(parallel num_threads(nthreads)
firstprivate(first_tblk, trans_ker_p, I, T))
{
if (jcp.with_bias) {
parallel_nd_in_omp(nthreads, jcp.oc, [&](dim_t ithr, dim_t ofm) {
Expand Down Expand Up @@ -945,7 +946,7 @@ void jit_avx512_core_f32_wino_conv_4x3_bwd_weights_t::
}

trans_ker_p.G = G_O_3x3_4x4;
PRAGMA_OMP(parallel firstprivate(trans_ker_p))
PRAGMA_OMP(parallel num_threads(nthreads) firstprivate(trans_ker_p))
{
parallel_nd_in_omp(jcp.nb_ic, jcp.nb_oc, jcp.oc_block, jcp.ic_block,
jcp.oc_reg_block,
Expand Down
29 changes: 13 additions & 16 deletions src/cpu/x64/jit_avx512_core_f32_wino_conv_4x3_kernel.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2017-2020 Intel Corporation
* Copyright 2017-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -165,7 +165,7 @@ bool check_L2_block_per_thread(jit_conv_winograd_conf_t &jcp, int dimN_block,
float C2_min, float C2_max) {
float block_size = alpha * alpha
* (2 * (jcp.oc + jcp.ic) * dimN_block * jcp.dimN_reg_block
+ div_up(jcp.ic * jcp.oc, dnnl_get_max_threads()))
+ div_up(jcp.ic * jcp.oc, jcp.nthr))
* (float)sizeof(float);
float L2_lb = C2_min * L2_cache_size;
float L2_ub = C2_max * L2_cache_size;
Expand Down Expand Up @@ -1232,15 +1232,15 @@ status_t set_wsched_DATA_W_SGD_avx512_core(jit_conv_winograd_conf_t &jcp) {
return check_L2_block_per_thread(jcp, dimN_block, 0.1, 2.0)
&& (dimN_block > current_best)
&& ((jcp.dimN / dimN_block / jcp.dimN_reg_block)
>= 1.5 * dnnl_get_max_threads());
>= 1.5 * jcp.nthr);
};

jcp.dimN_block = get_divisor_satisfying_cond(
jcp, jcp.dimN / jcp.dimN_reg_block, 1, test_cond_dimN_block);
jcp.dimN_nb_block = jcp.dimN / jcp.dimN_block / jcp.dimN_reg_block;

if (check_L2_block_per_thread(jcp, jcp.dimN_block, 0.1, 3.2)
&& (jcp.dimN_nb_block >= 1.5 * dnnl_get_max_threads())) {
&& (jcp.dimN_nb_block >= 1.5 * jcp.nthr)) {

/* ------------------- L1 blocking for GEMM --------------*/
/* -------------------- Choose dimK block ----------------*/
Expand Down Expand Up @@ -2295,9 +2295,8 @@ status_t set_wsched_WEI_SDGtWo(jit_conv_winograd_conf_t &jcp) {
auto test_MV_large_enough = [](jit_conv_winograd_conf_t &jcp) {
size_t M_sz = alpha * alpha * jcp.dimM * jcp.dimK * sizeof(float);
size_t V_sz = alpha * alpha * jcp.dimN * jcp.dimK * sizeof(float);
size_t nthreads = dnnl_get_max_threads();
return (((V_sz + M_sz) / nthreads) >= 2 * L2_cache_size)
&& (jcp.dimK / nthreads >= 1.0);
return (((V_sz + M_sz) / jcp.nthr) >= 2 * L2_cache_size)
&& (jcp.dimK / jcp.nthr >= 1.0);
};

auto test_min_dimK_L1 = [](jit_conv_winograd_conf_t &jcp, int dimK_block_ur,
Expand All @@ -2307,10 +2306,9 @@ status_t set_wsched_WEI_SDGtWo(jit_conv_winograd_conf_t &jcp) {
size_t L1_block_N = jcp.dimN_reg_block * dimK_block_ur * sizeof(float);
size_t M_L2_block
= alpha * alpha * jcp.dimM * dimK_block_ur * sizeof(float);
size_t nthreads = dnnl_get_max_threads();
bool load_balance = true;
if (!(jcp.dimK % nthreads)) {
load_balance = ((jcp.dimK / dimK_block_ur) % nthreads == 0);
if (!(jcp.dimK % jcp.nthr)) {
load_balance = ((jcp.dimK / dimK_block_ur) % jcp.nthr == 0);
}
return (L1_block_M + L1_block_N >= 0.1 * L1_cache_size)
&& (L1_block_M + L1_block_N <= 0.5 * L1_cache_size)
Expand Down Expand Up @@ -2367,8 +2365,7 @@ status_t set_wsched_WEI_SDGtWo(jit_conv_winograd_conf_t &jcp) {
jcp.dimM_block = M_blk;
jcp.sched_policy = WSCHED_WEI_SDGtWo;
set_jcp_WEI_params(jcp);
jcp.nthr = nstl::min(
dnnl_get_max_threads(), jcp.tile_block);
jcp.nthr = nstl::min(jcp.nthr, jcp.tile_block);
return status::success;
}
}
Expand Down Expand Up @@ -2411,10 +2408,10 @@ status_t set_wsched_WEI_S_D_Giot_W(jit_conv_winograd_conf_t &jcp) {
size_t nb_M_blk
= jcp.dimM / M_blk / jcp.dimM_reg_block / jcp.dimM_simd_block;
size_t nb_K_blk = jcp.dimK / K_blk_ur;
size_t nthreads = dnnl_get_max_threads();
bool load_balance = (nb_K_blk * nb_N_blk * nb_M_blk) >= nthreads;
if (!(nb_K_blk % nthreads)) {
load_balance = load_balance && (nb_K_blk % nthreads == 0);
bool load_balance = (nb_K_blk * nb_N_blk * nb_M_blk)
>= static_cast<size_t>(jcp.nthr);
if (!(nb_K_blk % jcp.nthr)) {
load_balance = load_balance && (nb_K_blk % jcp.nthr == 0);
}

size_t V_L2_block = alpha * alpha * N_blk * jcp.dimN_reg_block
Expand Down
2 changes: 1 addition & 1 deletion src/cpu/x64/jit_uni_x8s8s32x_1x1_convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ status_t jit_uni_x8s8s32x_1x1_convolution_fwd_t<isa>::execute_forward(
}
}
}
parallel(0, [&](const int ithr, const int nthr) {
parallel(pd()->jcp_.nthr, [&](const int ithr, const int nthr) {
execute_forward_thr(ithr, nthr, src, weights, bias, weights_dw, bias_dw,
dst, src_zero_point, dst_zero_point, scratchpad,
post_ops_binary_rhs_arg_vec.data(),
Expand Down

0 comments on commit 2458105

Please sign in to comment.