
Commit

fix comments
wuxun-zhang committed Sep 25, 2019
1 parent 67dd824 commit 827b5e1
Showing 1 changed file with 11 additions and 15 deletions.
src/operator/nn/mkldnn/mkldnn_reshape.cc — 26 changes: 11 additions & 15 deletions
@@ -55,10 +55,10 @@ MKLDNNReshapeFwd::MKLDNNReshapeFwd(const OpReqType &req,
 
   // temp_
   auto temp_md = GetDesc(in_md, GetDefaultFormat(in_md));
-  temp_ = std::make_shared<mkldnn::memory>(temp_md, engine);
+  temp_ = std::make_shared<mkldnn::memory>(temp_md, engine, nullptr);
 
   // destination
-  out_ = std::make_shared<mkldnn::memory>(temp_md, engine);
+  out_ = std::make_shared<mkldnn::memory>(temp_md, engine, nullptr);
 
   if (req == kWriteInplace) {
     // If the input has MKL-DNN internal layout, we need reorder it to a temporal buffer with
@@ -68,18 +68,15 @@ MKLDNNReshapeFwd::MKLDNNReshapeFwd(const OpReqType &req,
     if (input.IsMKLDNNData()) {
       prims_.push_back(mkldnn::reorder(*in_mem, *temp_));  // reorder to default
       prims_.push_back(mkldnn::reorder(*temp_, *out_));    // copy back
-
       needInvalidateInput = true;
     }
   } else if (req == kWriteTo) {
     if (input.IsMKLDNNData()) {
       prims_.push_back(mkldnn::reorder(*in_mem, *temp_));  // reorder to default
       prims_.push_back(mkldnn::reorder(*temp_, *out_));    // copy to the output buffer
-
       needInvalidateInput = false;
     } else {
       prims_.push_back(mkldnn::reorder(*in_mem, *out_));   // copy directly from input to output
-
       needInvalidateInput = false;
     }
   } else {
@@ -97,28 +94,27 @@ void MKLDNNReshapeFwd::Execute(const NDArray &input,
                                void* workspace) {
   auto stream = MKLDNNStream::Get();
   auto in_mem = input.GetMKLDNNData();
-  auto in_md = in_mem->get_desc();
   // register primitives and arguments
-  std::vector<mkldnn_args_map_t> args_map_;
+  std::vector<mkldnn_args_map_t> args_map;
   size_t prims_size = prims_.size();
   if (prims_size == 1) {
-    args_map_.push_back({{MKLDNN_ARG_FROM, *in_mem},
-                         {MKLDNN_ARG_TO, *output.GetMKLDNNData()}});
+    args_map.push_back({{MKLDNN_ARG_FROM, *in_mem},
+                        {MKLDNN_ARG_TO, *output.GetMKLDNNData()}});
   } else if (prims_size == 2) {
-    auto temp_md = GetDesc(in_md, GetDefaultFormat(in_md));
-    temp_ = std::make_shared<mkldnn::memory>(temp_md, CpuEngine::Get()->get_engine(),
-                                             workspace);
-    args_map_.push_back({{MKLDNN_ARG_FROM, *in_mem},
+    if (workspace) {
+      temp_->set_data_handle(workspace);
+    }
+    args_map.push_back({{MKLDNN_ARG_FROM, *in_mem},
                         {MKLDNN_ARG_TO, *temp_}});
-    args_map_.push_back({{MKLDNN_ARG_FROM, *temp_},
+    args_map.push_back({{MKLDNN_ARG_FROM, *temp_},
                         {MKLDNN_ARG_TO, *output.GetMKLDNNData()}});
   } else {
     CHECK(prims_size == 0 && req != kWriteTo)
        << "kWriteTo should never reach here.";
   }
 
   for (size_t i = 0; i < prims_size; i++) {
-    stream->RegisterPrimArgs(prims_[i], args_map_[i]);
+    stream->RegisterPrimArgs(prims_[i], args_map[i]);
   }
   stream->Submit();
   // invalidate mkldnn memory in input
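For reference, the pattern this change relies on is to create mkldnn::memory objects with a nullptr data handle up front and attach the real buffer later through set_data_handle() just before execution. The standalone sketch below illustrates the idea; it is not part of the commit, it assumes the mkldnn (DNNL 1.x) C++ API that MXNet bundled at the time, and names such as src_buf and dst_buf are purely illustrative.

// Minimal sketch (not MXNet code): build a reorder over "empty" memories,
// then bind real buffers right before executing it.
#include <mkldnn.hpp>

#include <vector>

int main() {
  mkldnn::engine eng(mkldnn::engine::kind::cpu, 0);
  mkldnn::stream strm(eng);

  // One plain (default-layout) descriptor shared by source and destination.
  mkldnn::memory::desc md({1, 8, 4, 4},
                          mkldnn::memory::data_type::f32,
                          mkldnn::memory::format_tag::nchw);

  // Handles are nullptr here, mirroring how temp_ and out_ are created
  // in the MKLDNNReshapeFwd constructor after this commit.
  mkldnn::memory src_mem(md, eng, nullptr);
  mkldnn::memory dst_mem(md, eng, nullptr);

  // The primitive can be constructed before any real buffers exist.
  mkldnn::reorder copy_prim(src_mem, dst_mem);

  // At execution time, attach actual storage (the "workspace" in Execute()).
  std::vector<float> src_buf(1 * 8 * 4 * 4, 1.0f);
  std::vector<float> dst_buf(src_buf.size(), 0.0f);
  src_mem.set_data_handle(src_buf.data());
  dst_mem.set_data_handle(dst_buf.data());

  copy_prim.execute(strm, {{MKLDNN_ARG_FROM, src_mem},
                           {MKLDNN_ARG_TO, dst_mem}});
  strm.wait();
  return 0;
}

The commit applies the same idea to the cached forward object: temp_ and out_ are built once in the constructor without a backing buffer, and Execute() only rebinds the caller-provided workspace when one is supplied.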
