Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix lint errors
Browse files Browse the repository at this point in the history
  • Loading branch information
zixuanweeei committed Jul 25, 2019
1 parent 1e1f799 commit 49ebe01
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
12 changes: 8 additions & 4 deletions src/operator/nn/mkldnn/mkldnn_rnn_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,8 @@ static void MKLDNNRNNForwardSingleLayerBi(bool state_outputs,
MKLDNNStream::Get()->Submit();

if (state_outputs) {
DType* dst_hcy = reinterpret_cast<DType *>(mkldnn_mems->hcy_memory[layer_index].get_data_handle());
DType* dst_hcy = reinterpret_cast<DType *>(
mkldnn_mems->hcy_memory[layer_index].get_data_handle());
if (mode == rnn_enum::kLstm) {
offset1 = nstates * single_cell_size;
offset2 = (nstates + 1) * single_cell_size;
Expand Down Expand Up @@ -542,7 +543,8 @@ static void MKLDNNRNNForwardUnidi(bool state_outputs,
MKLDNNStream::Get()->RegisterPrim(reorder(src_wx_f, mkldnn_mems->wx_memory[layer_index]));
MKLDNNStream::Get()->RegisterPrim(reorder(src_wh_f, mkldnn_mems->wh_memory[layer_index]));

DType* user_bias_f = reinterpret_cast<DType *>(mkldnn_mems->bias_memory[layer_index].get_data_handle());
DType* user_bias_f = reinterpret_cast<DType *>(
mkldnn_mems->bias_memory[layer_index].get_data_handle());
if (mode == rnn_enum::kGru) {
const int mx_single_b_sz = ngates * H;
for (int l = 0; l < L; l++) {
Expand All @@ -569,7 +571,8 @@ static void MKLDNNRNNForwardUnidi(bool state_outputs,
#pragma omp parallel for num_threads(omp_threads)
for (int j = 0; j < L * single_b_size; j++) {
int k = j / single_b_size;
user_bias_f[j] = b_ptr[j + k * single_b_size] + b_ptr[j + k * single_b_size + single_b_size];
user_bias_f[j] = b_ptr[j + k * single_b_size] +
b_ptr[j + k * single_b_size + single_b_size];
}
}
}
Expand Down Expand Up @@ -604,7 +607,8 @@ static void MKLDNNRNNForwardUnidi(bool state_outputs,
MKLDNNStream::Get()->Submit();

if (state_outputs) {
DType* dst_hcy = reinterpret_cast<DType *>(mkldnn_mems->hcy_memory[layer_index].get_data_handle());
DType* dst_hcy = reinterpret_cast<DType *>(
mkldnn_mems->hcy_memory[layer_index].get_data_handle());
if (mode == rnn_enum::kLstm) {
for (int l = 0; l < L; l++) {
offset1 = l * single_cell_size;
Expand Down
5 changes: 3 additions & 2 deletions src/operator/rnn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -491,8 +491,9 @@ static void RNNStatefulComputeCPU(const OpStatePtr& state_ptr,
= mkldnn::memory({ user_weight_iter_md, cpu_engine }, weight_iter_n);
op.mkldnn_mems.wh_memory.push_back(user_weight_iter_memory_n);

DType* bias_n = weight_iter_n + (L - 1) * H * ngates * H; // Generally, (L - 1) * ngates * H
// LBR-Gru, (L -1) * (ngates + 1) * H
DType* bias_n = weight_iter_n + (L - 1) * H * ngates * H; // Generally, (L - 1) *
// ngates * H. LBR-Gru,
// (L -1) * (ngates + 1) * H
auto user_bias_memory_n =
mkldnn::memory({ user_bias_md, cpu_engine }, bias_n);
op.mkldnn_mems.bias_memory.push_back(user_bias_memory_n);
Expand Down

0 comments on commit 49ebe01

Please sign in to comment.