improve log & modify definition location of args_map_
wuxun-zhang committed Sep 24, 2019
1 parent 47f72bc commit 8239215
Showing 4 changed files with 11 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/operator/nn/mkldnn/mkldnn_expand_dims.cc
@@ -61,7 +61,7 @@ void MKLDNNExpandDimsForward(const nnvm::NodeAttrs &attrs,
     ws_ptr = reinterpret_cast<void*>(ws.dptr_);
   }
 
-  fwd.Execute(input, output, ws_ptr);
+  fwd.Execute(input, output, req, ws_ptr);
 }
 
 }  // namespace op
2 changes: 1 addition & 1 deletion src/operator/nn/mkldnn/mkldnn_flatten.cc
@@ -70,7 +70,7 @@ void MKLDNNFlattenForward(const nnvm::NodeAttrs &attrs,
     ws_ptr = reinterpret_cast<void*>(ws.dptr_);
   }
 
-  fwd.Execute(input, output, ws_ptr);
+  fwd.Execute(input, output, req, ws_ptr);
 }
 
 }  // namespace op
2 changes: 1 addition & 1 deletion src/operator/nn/mkldnn/mkldnn_reshape-inl.h
@@ -38,7 +38,6 @@ class MKLDNNReshapeFwd {
  protected:
   std::shared_ptr<mkldnn::memory> out_;
   std::shared_ptr<mkldnn::memory> temp_;
-  std::vector<mkldnn_args_map_t> args_map_;
   std::vector<mkldnn::primitive> prims_;
   bool needInvalidateInput = false;
 
@@ -49,6 +48,7 @@ class MKLDNNReshapeFwd {
   int GetWorkspaceSize();
   void Execute(const NDArray &input,
                const NDArray &output,
+               const OpReqType &req,
                void* workspace = nullptr);
 };

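The header change above drops args_map_ as a member of MKLDNNReshapeFwd: the cached forward object keeps only the primitives, while the argument map that binds a particular call's buffers is rebuilt inside Execute() (see the mkldnn_reshape.cc hunk below). A minimal, MKLDNN-free sketch of that caching pattern follows; CachedFwd, Primitive, and ArgsMap are illustrative stand-ins, not MXNet or MKLDNN types.

// Sketch: primitives are cached across calls, argument maps are per call.
#include <cstdio>
#include <map>
#include <vector>

using ArgsMap = std::map<int, const void*>;     // arg id -> buffer for one call

struct Primitive {                              // stand-in for mkldnn::primitive
  void run(const ArgsMap &args) const { std::printf("run with %zu args\n", args.size()); }
};

class CachedFwd {
 public:
  explicit CachedFwd(size_t n) : prims_(n) {}   // primitives built once and cached
  void Execute(const void *src, void *dst) {
    std::vector<ArgsMap> args;                  // local: rebuilt on every call
    for (size_t i = 0; i < prims_.size(); ++i)
      args.push_back({{0, src}, {1, dst}});
    for (size_t i = 0; i < prims_.size(); ++i)
      prims_[i].run(args[i]);
  }
 private:
  std::vector<Primitive> prims_;                // kept as a member, reused
};

int main() {
  int in = 0, out = 0;
  CachedFwd fwd(2);
  fwd.Execute(&in, &out);   // first call
  fwd.Execute(&in, &out);   // second call: no stale entries carried over
}

One reading of this relocation, consistent with the diff, is that a per-call map cannot accumulate or reuse stale memory handles when the same cached forward object serves repeated calls.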
10 changes: 8 additions & 2 deletions src/operator/nn/mkldnn/mkldnn_reshape.cc
@@ -93,24 +93,30 @@ int MKLDNNReshapeFwd::GetWorkspaceSize() {
 
 void MKLDNNReshapeFwd::Execute(const NDArray &input,
                                const NDArray &output,
+                               const OpReqType &req,
                                void* workspace) {
   auto stream = MKLDNNStream::Get();
   auto in_mem = input.GetMKLDNNData();
   auto in_md = in_mem->get_desc();
   // register primitives and arguments
+  std::vector<mkldnn_args_map_t> args_map_;
   size_t prims_size = prims_.size();
   if (prims_size == 1) {
     args_map_.push_back({{MKLDNN_ARG_FROM, *in_mem},
                          {MKLDNN_ARG_TO, *output.GetMKLDNNData()}});
   } else if (prims_size == 2) {
     auto temp_md = GetDesc(in_md, GetDefaultFormat(in_md));
     temp_ = std::make_shared<mkldnn::memory>(temp_md, CpuEngine::Get()->get_engine(),
-                                             workspace);
+                                             workspace);
     args_map_.push_back({{MKLDNN_ARG_FROM, *in_mem},
                          {MKLDNN_ARG_TO, *temp_}});
     args_map_.push_back({{MKLDNN_ARG_FROM, *temp_},
                          {MKLDNN_ARG_TO, *output.GetMKLDNNData()}});
+  } else {
+    CHECK(prims_size == 0 && req != kWriteTo)
+        << "kWriteTo should never reach here.";
   }
+
   for (size_t i = 0; i < prims_size; i++) {
     stream->RegisterPrimArgs(prims_[i], args_map_[i]);
   }
@@ -142,7 +148,7 @@ void MKLDNNReshapeForward(const nnvm::NodeAttrs& attrs,
     ws_ptr = reinterpret_cast<void*>(ws.dptr_);
   }
 
-  fwd.Execute(input, output, ws_ptr);
+  fwd.Execute(input, output, req, ws_ptr);
 }
 }  // namespace op
 }  // namespace mxnet

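With the changes above, Execute() distinguishes three cases by the number of cached primitives, and the new CHECK ties the zero-primitive case to the request type: registering no reorder is only legal when the request is not an explicit kWriteTo. The sketch below is an assumed, simplified restatement of that dispatch, not the MXNet implementation; only the branch conditions and the check message mirror the diff.

#include <cassert>
#include <cstdio>

// Simplified stand-in for a few values of MXNet's OpReqType.
enum OpReqType { kNullOp, kWriteTo, kWriteInplace };

// Illustrative dispatch mirroring the three branches of MKLDNNReshapeFwd::Execute().
void Execute(size_t prims_size, OpReqType req) {
  if (prims_size == 1) {
    std::puts("one reorder: input -> output");
  } else if (prims_size == 2) {
    std::puts("two reorders: input -> temp (default format) -> output");
  } else {
    // Mirrors the added CHECK: a kWriteTo request must have produced primitives.
    assert(prims_size == 0 && req != kWriteTo && "kWriteTo should never reach here");
    std::puts("no primitives registered: nothing to copy");
  }
}

int main() {
  Execute(1, kWriteTo);        // plain reorder into the output
  Execute(2, kWriteTo);        // goes through a temporary default-format buffer
  Execute(0, kWriteInplace);   // in-place request, no work registered
  return 0;
}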