diff --git a/src/cpu/simple_layer_normalization.cpp b/src/cpu/simple_layer_normalization.cpp index a5f01350800..cc6446d247c 100644 --- a/src/cpu/simple_layer_normalization.cpp +++ b/src/cpu/simple_layer_normalization.cpp @@ -153,6 +153,7 @@ status_t simple_layer_normalization_bwd_t::pd_t::init( reorder_pd_, engine, stat_md(), &reordered_stat_md_)); } + nthr_ = dnnl_get_max_threads(); init_scratchpad(); return status::success; } @@ -213,7 +214,7 @@ status_t simple_layer_normalization_bwd_t::execute_backward( if (diff_shift == diff_scale) diff_shift = &diff_shift[diff_shift_off]; } - const int max_nthr = dnnl_get_max_threads(); + const int max_nthr = pd()->nthr_; parallel(max_nthr, [&](int ithr, int nthr) { dim_t N_start = 0, N_end = 0; diff --git a/src/cpu/simple_layer_normalization.hpp b/src/cpu/simple_layer_normalization.hpp index 4cceb255d3f..5ccc5445665 100644 --- a/src/cpu/simple_layer_normalization.hpp +++ b/src/cpu/simple_layer_normalization.hpp @@ -160,6 +160,7 @@ struct simple_layer_normalization_bwd_t : public primitive_t { std::shared_ptr reorder_pd_; memory_desc_t reordered_stat_md_; + int nthr_; // Set in init(); sizes scratchpad, caps threads in execute. private: void init_scratchpad() { @@ -171,8 +172,8 @@ struct simple_layer_normalization_bwd_t : public primitive_t { scratchpad.template book( key_lnorm_tmp_var, across_axis()); } - scratchpad.template book(key_lnorm_reduction, - 2 * norm_axis() * dnnl_get_max_threads()); + scratchpad.template book( + key_lnorm_reduction, 2 * norm_axis() * nthr_); scratchpad.template book( key_lnorm_tmp_diff_ss, 2 * norm_axis()); if (reordered_stat_md_ != *stat_md() && !stats_are_tmp()) {