diff --git a/src/cpu/x64/jit_avx2_1x1_convolution.cpp b/src/cpu/x64/jit_avx2_1x1_convolution.cpp
index 5489aaf41d8..88acfd978b3 100644
--- a/src/cpu/x64/jit_avx2_1x1_convolution.cpp
+++ b/src/cpu/x64/jit_avx2_1x1_convolution.cpp
@@ -69,7 +69,7 @@ void jit_avx2_1x1_convolution_fwd_t::execute_forward(
         bias = padded_bias;
     }
 
-    parallel(0, [&](const int ithr, const int nthr) {
+    parallel(jcp.nthr, [&](const int ithr, const int nthr) {
         execute_forward_thr(ithr, nthr, src, weights, bias, weights_dw, bias_dw,
                 dst, scratchpad, post_ops_binary_rhs_arg_vec.data(),
                 post_ops_binary_rhs_arg_vec_dw.data());
diff --git a/src/cpu/x64/jit_avx512_core_f32_wino_conv_4x3.cpp b/src/cpu/x64/jit_avx512_core_f32_wino_conv_4x3.cpp
index 3a94362015f..99de1ffb69e 100644
--- a/src/cpu/x64/jit_avx512_core_f32_wino_conv_4x3.cpp
+++ b/src/cpu/x64/jit_avx512_core_f32_wino_conv_4x3.cpp
@@ -847,7 +847,8 @@ void jit_avx512_core_f32_wino_conv_4x3_bwd_weights_t::
     float I[alpha][alpha][simd_w];
     float T[alpha][alpha][simd_w];
 
-PRAGMA_OMP(parallel firstprivate(first_tblk, trans_ker_p, I, T))
+PRAGMA_OMP(parallel num_threads(nthreads)
+        firstprivate(first_tblk, trans_ker_p, I, T))
     {
         if (jcp.with_bias) {
             parallel_nd_in_omp(nthreads, jcp.oc, [&](dim_t ithr, dim_t ofm) {
@@ -945,7 +946,7 @@ void jit_avx512_core_f32_wino_conv_4x3_bwd_weights_t::
     }
 
     trans_ker_p.G = G_O_3x3_4x4;
-PRAGMA_OMP(parallel firstprivate(trans_ker_p))
+PRAGMA_OMP(parallel num_threads(nthreads) firstprivate(trans_ker_p))
     {
         parallel_nd_in_omp(jcp.nb_ic, jcp.nb_oc, jcp.oc_block, jcp.ic_block,
                 jcp.oc_reg_block,
diff --git a/src/cpu/x64/jit_avx512_core_f32_wino_conv_4x3_kernel.cpp b/src/cpu/x64/jit_avx512_core_f32_wino_conv_4x3_kernel.cpp
index e96f5e6ffe1..021c8f213f6 100644
--- a/src/cpu/x64/jit_avx512_core_f32_wino_conv_4x3_kernel.cpp
+++ b/src/cpu/x64/jit_avx512_core_f32_wino_conv_4x3_kernel.cpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
-* Copyright 2017-2020 Intel Corporation
+* Copyright 2017-2021 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -165,7 +165,7 @@ bool check_L2_block_per_thread(jit_conv_winograd_conf_t &jcp, int dimN_block,
         float C2_min, float C2_max) {
     float block_size = alpha * alpha
             * (2 * (jcp.oc + jcp.ic) * dimN_block * jcp.dimN_reg_block
-                    + div_up(jcp.ic * jcp.oc, dnnl_get_max_threads()))
+                    + div_up(jcp.ic * jcp.oc, jcp.nthr))
             * (float)sizeof(float);
     float L2_lb = C2_min * L2_cache_size;
     float L2_ub = C2_max * L2_cache_size;
@@ -1232,7 +1232,7 @@ status_t set_wsched_DATA_W_SGD_avx512_core(jit_conv_winograd_conf_t &jcp) {
         return check_L2_block_per_thread(jcp, dimN_block, 0.1, 2.0)
                 && (dimN_block > current_best)
                 && ((jcp.dimN / dimN_block / jcp.dimN_reg_block)
-                        >= 1.5 * dnnl_get_max_threads());
+                        >= 1.5 * jcp.nthr);
     };
 
     jcp.dimN_block = get_divisor_satisfying_cond(
@@ -1240,7 +1240,7 @@ status_t set_wsched_DATA_W_SGD_avx512_core(jit_conv_winograd_conf_t &jcp) {
     jcp.dimN_nb_block = jcp.dimN / jcp.dimN_block / jcp.dimN_reg_block;
 
     if (check_L2_block_per_thread(jcp, jcp.dimN_block, 0.1, 3.2)
-            && (jcp.dimN_nb_block >= 1.5 * dnnl_get_max_threads())) {
+            && (jcp.dimN_nb_block >= 1.5 * jcp.nthr)) {
 
         /* ------------------- L1 blocking for GEMM --------------*/
         /* -------------------- Choose dimK block ----------------*/
@@ -2295,9 +2295,8 @@ status_t set_wsched_WEI_SDGtWo(jit_conv_winograd_conf_t &jcp) {
     auto test_MV_large_enough = [](jit_conv_winograd_conf_t &jcp) {
         size_t M_sz = alpha * alpha * jcp.dimM * jcp.dimK * sizeof(float);
         size_t V_sz = alpha * alpha * jcp.dimN * jcp.dimK * sizeof(float);
-        size_t nthreads = dnnl_get_max_threads();
-        return (((V_sz + M_sz) / nthreads) >= 2 * L2_cache_size)
-                && (jcp.dimK / nthreads >= 1.0);
+        return (((V_sz + M_sz) / jcp.nthr) >= 2 * L2_cache_size)
+                && (jcp.dimK / jcp.nthr >= 1.0);
     };
 
     auto test_min_dimK_L1 = [](jit_conv_winograd_conf_t &jcp, int dimK_block_ur,
@@ -2307,10 +2306,9 @@ status_t set_wsched_WEI_SDGtWo(jit_conv_winograd_conf_t &jcp) {
         size_t L1_block_N = jcp.dimN_reg_block * dimK_block_ur * sizeof(float);
         size_t M_L2_block
                 = alpha * alpha * jcp.dimM * dimK_block_ur * sizeof(float);
-        size_t nthreads = dnnl_get_max_threads();
         bool load_balance = true;
-        if (!(jcp.dimK % nthreads)) {
-            load_balance = ((jcp.dimK / dimK_block_ur) % nthreads == 0);
+        if (!(jcp.dimK % jcp.nthr)) {
+            load_balance = ((jcp.dimK / dimK_block_ur) % jcp.nthr == 0);
         }
         return (L1_block_M + L1_block_N >= 0.1 * L1_cache_size)
                 && (L1_block_M + L1_block_N <= 0.5 * L1_cache_size)
@@ -2367,8 +2365,7 @@ status_t set_wsched_WEI_SDGtWo(jit_conv_winograd_conf_t &jcp) {
                         jcp.dimM_block = M_blk;
                         jcp.sched_policy = WSCHED_WEI_SDGtWo;
                         set_jcp_WEI_params(jcp);
-                        jcp.nthr = nstl::min(
-                                dnnl_get_max_threads(), jcp.tile_block);
+                        jcp.nthr = nstl::min(jcp.nthr, jcp.tile_block);
                         return status::success;
                     }
                 }
@@ -2411,10 +2408,10 @@ status_t set_wsched_WEI_S_D_Giot_W(jit_conv_winograd_conf_t &jcp) {
         size_t nb_M_blk
                 = jcp.dimM / M_blk / jcp.dimM_reg_block / jcp.dimM_simd_block;
         size_t nb_K_blk = jcp.dimK / K_blk_ur;
-        size_t nthreads = dnnl_get_max_threads();
-        bool load_balance = (nb_K_blk * nb_N_blk * nb_M_blk) >= nthreads;
-        if (!(nb_K_blk % nthreads)) {
-            load_balance = load_balance && (nb_K_blk % nthreads == 0);
+        bool load_balance = (nb_K_blk * nb_N_blk * nb_M_blk)
+                >= static_cast<size_t>(jcp.nthr);
+        if (!(nb_K_blk % jcp.nthr)) {
+            load_balance = load_balance && (nb_K_blk % jcp.nthr == 0);
         }
 
         size_t V_L2_block = alpha * alpha * N_blk * jcp.dimN_reg_block
diff --git a/src/cpu/x64/jit_uni_x8s8s32x_1x1_convolution.cpp b/src/cpu/x64/jit_uni_x8s8s32x_1x1_convolution.cpp
index 9a5851ae093..2bfedc09921 100644
--- a/src/cpu/x64/jit_uni_x8s8s32x_1x1_convolution.cpp
+++ b/src/cpu/x64/jit_uni_x8s8s32x_1x1_convolution.cpp
@@ -97,7 +97,7 @@ status_t jit_uni_x8s8s32x_1x1_convolution_fwd_t::execute_forward(
             }
         }
     }
-    parallel(0, [&](const int ithr, const int nthr) {
+    parallel(pd()->jcp_.nthr, [&](const int ithr, const int nthr) {
         execute_forward_thr(ithr, nthr, src, weights, bias, weights_dw, bias_dw,
                 dst, src_zero_point, dst_zero_point, scratchpad,
                 post_ops_binary_rhs_arg_vec.data(),
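
Every hunk above follows the same pattern: the thread count fixed at primitive-creation time (jcp.nthr) replaces a runtime query of the global limit, whether that query was dnnl_get_max_threads() in the blocking heuristics or the nthr == 0 argument to the parallel() wrapper at execution time. Below is a minimal, self-contained sketch of the wrapper contract the execution-side hunks rely on; run_parallel and its std::thread team are illustrative stand-ins only, not oneDNN's actual parallel() implementation (which dispatches to the OpenMP/TBB/threadpool runtime).

// Illustrative sketch only: a parallel()-style wrapper where nthr == 0 means
// "use every available thread" and a positive nthr caps the team size. Passing
// jcp.nthr (as the patch does) keeps the team consistent with the thread count
// the blocking heuristics were tuned for at creation time.
#include <algorithm>
#include <functional>
#include <thread>
#include <vector>

static void run_parallel(int nthr, const std::function<void(int, int)> &f) {
    // Fall back to 1 if the runtime cannot report a concurrency level.
    const unsigned hw = std::max(1u, std::thread::hardware_concurrency());
    const int max_thr = static_cast<int>(hw);
    const int team = nthr == 0 ? max_thr : std::min(nthr, max_thr);
    std::vector<std::thread> workers;
    workers.reserve(team);
    // Each worker receives its index and the team size, mirroring the
    // (ithr, nthr) pair handed to execute_forward_thr in the hunks above.
    for (int ithr = 0; ithr < team; ++ithr)
        workers.emplace_back([&f, ithr, team]() { f(ithr, team); });
    for (auto &w : workers)
        w.join();
}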