cpu: update lnorm implementations with nthr_ member
dzarukin committed Dec 20, 2021
1 parent 8863e34 commit 57b1e7a
Showing 2 changed files with 5 additions and 3 deletions.
3 changes: 2 additions & 1 deletion src/cpu/simple_layer_normalization.cpp
@@ -153,6 +153,7 @@ status_t simple_layer_normalization_bwd_t<data_type>::pd_t::init(
                 reorder_pd_, engine, stat_md(), &reordered_stat_md_));
     }
 
+    nthr_ = dnnl_get_max_threads();
     init_scratchpad();
     return status::success;
 }
@@ -213,7 +214,7 @@ status_t simple_layer_normalization_bwd_t<data_type>::execute_backward(
         if (diff_shift == diff_scale) diff_shift = &diff_shift[diff_shift_off];
     }
 
-    const int max_nthr = dnnl_get_max_threads();
+    const int max_nthr = pd()->nthr_;
 
     parallel(max_nthr, [&](int ithr, int nthr) {
         dim_t N_start = 0, N_end = 0;
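The pattern in this hunk: the primitive descriptor captures the thread count once at creation time, and execute() reads the stored nthr_ instead of querying the runtime again, so the parallel region can never use more threads than the scratchpad was sized for. Below is a minimal, self-contained sketch of that pattern; pd_t, init(), execute(), and the reduction buffer are simplified stand-ins rather than the actual oneDNN classes, and omp_get_max_threads() substitutes for dnnl_get_max_threads().

    // Sketch of the commit's pattern (stand-ins, not oneDNN's real classes).
    // Compile with: g++ -fopenmp sketch.cpp
    #include <cstddef>
    #include <omp.h>
    #include <vector>

    // Stand-in primitive descriptor: captures the thread count once at init().
    struct pd_t {
        int nthr_ = 1;
        std::vector<float> reduction_; // stand-in for the booked scratchpad

        void init(std::size_t norm_axis) {
            nthr_ = omp_get_max_threads(); // queried once, here, at pd creation
            // Book per-thread reduction space sized by the captured nthr_.
            reduction_.resize(2 * norm_axis * static_cast<std::size_t>(nthr_));
        }
    };

    // execute() reuses pd.nthr_, so it cannot outgrow the booked buffer.
    void execute(pd_t &pd, std::size_t norm_axis) {
        const int max_nthr = pd.nthr_; // was a fresh runtime query before this commit
    #pragma omp parallel num_threads(max_nthr)
        {
            const int ithr = omp_get_thread_num();
            // Each thread writes only into its own slice of the booked buffer.
            float *slice = pd.reduction_.data() + 2 * norm_axis * ithr;
            slice[0] = 0.f;
        }
    }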
5 changes: 3 additions & 2 deletions src/cpu/simple_layer_normalization.hpp
@@ -160,6 +160,7 @@ struct simple_layer_normalization_bwd_t : public primitive_t {
 
         std::shared_ptr<primitive_desc_t> reorder_pd_;
         memory_desc_t reordered_stat_md_;
+        int nthr_; // Stored at pd creation so execute() does not exceed the thread count used for setup.
 
     private:
         void init_scratchpad() {
@@ -171,8 +172,8 @@ struct simple_layer_normalization_bwd_t : public primitive_t {
                 scratchpad.template book<float>(
                         key_lnorm_tmp_var, across_axis());
             }
-            scratchpad.template book<float>(key_lnorm_reduction,
-                    2 * norm_axis() * dnnl_get_max_threads());
+            scratchpad.template book<float>(
+                    key_lnorm_reduction, 2 * norm_axis() * nthr_);
             scratchpad.template book<float>(
                     key_lnorm_tmp_diff_ss, 2 * norm_axis());
             if (reordered_stat_md_ != *stat_md() && !stats_are_tmp()) {
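This header hunk shows why the stored nthr_ matters: the key_lnorm_reduction buffer is booked per thread (2 * norm_axis() elements each) at pd creation, but before this commit execute() re-queried dnnl_get_max_threads() at run time. The sketch below is a hypothetical demonstration of the mismatch that could produce, using OpenMP calls as stand-ins for oneDNN's threading runtime; the variable names and the scenario (an application raising the thread limit between primitive creation and execution) are illustrative assumptions, not taken from the commit.

    // Hypothetical repro of the hazard the commit removes.
    // Compile with: g++ -fopenmp mismatch.cpp
    #include <cstddef>
    #include <omp.h>

    int main() {
        const std::size_t norm_axis = 128;

        // Thread count captured when the scratchpad is booked (pd init).
        const int booked_nthr = omp_get_max_threads();
        const std::size_t booked
                = 2 * norm_axis * static_cast<std::size_t>(booked_nthr);

        // Application raises the thread limit before executing the primitive.
        omp_set_num_threads(booked_nthr * 2);

        // Old code re-queried the limit here; the answer can now differ.
        const int exec_nthr = omp_get_max_threads();
        const std::size_t needed
                = 2 * norm_axis * static_cast<std::size_t>(exec_nthr);

        // With the old code, needed > booked: per-thread slices would overflow
        // the booked buffer. With this commit, execute() uses the stored
        // booked_nthr, so needed always equals booked.
        return needed > booked ? 1 : 0;
    }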
