Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Avoid uneccesary vector copies in imperative_utils.cc
Browse files Browse the repository at this point in the history
  • Loading branch information
larroy committed Apr 10, 2019
1 parent fde4963 commit d04400a
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 13 deletions.
22 changes: 11 additions & 11 deletions src/imperative/imperative_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,41 +25,41 @@ namespace imperative {

inline std::vector<NDArray*> NodeInputs(const nnvm::IndexedGraph& idx,
const int node_idx,
const std::vector<NDArray*> arrays) {
const std::vector<NDArray*>& arrays) {
const nnvm::IndexedGraph::Node& node = idx[node_idx];
const size_t num_inputs = node.inputs.size();
std::vector<NDArray*> ndinputs;
ndinputs.reserve(num_inputs);
for (const auto& j : node.inputs) {
size_t eid = idx.entry_id(j);
const size_t eid = idx.entry_id(j);
ndinputs.emplace_back(arrays[eid]);
}
return ndinputs;
}

inline std::vector<NDArray*> NodeOutputs(const nnvm::IndexedGraph& idx,
const int node_idx,
const std::vector<NDArray*> arrays) {
const std::vector<NDArray*>& arrays) {
const nnvm::IndexedGraph::Node& node = idx[node_idx];
const size_t num_outputs = node.source->num_outputs();
std::vector<NDArray*> ndoutputs;
ndoutputs.reserve(num_outputs);
for (size_t j = 0; j < num_outputs; ++j) {
size_t eid = idx.entry_id(node_idx, j);
const size_t eid = idx.entry_id(node_idx, j);
ndoutputs.emplace_back(arrays[eid]);
}
return ndoutputs;
}

inline std::vector<OpReqType> NodeReq(const nnvm::IndexedGraph& idx,
const int node_idx,
const std::vector<OpReqType> array_reqs) {
const std::vector<OpReqType>& array_reqs) {
const nnvm::IndexedGraph::Node& node = idx[node_idx];
const size_t num_outputs = node.source->num_outputs();
std::vector<OpReqType> req;
req.reserve(num_outputs);
for (size_t j = 0; j < num_outputs; ++j) {
size_t eid = idx.entry_id(node_idx, j);
const size_t eid = idx.entry_id(node_idx, j);
req.push_back(array_reqs[eid]);
}
return req;
Expand All @@ -68,11 +68,11 @@ inline std::vector<OpReqType> NodeReq(const nnvm::IndexedGraph& idx,
inline void InvokeOperator(const nnvm::IndexedGraph& idx,
const int node_idx,
const bool retain_graph,
const std::vector<NDArray*> arrays,
const std::vector<NDArray*>& arrays,
Context ctx,
std::vector<OpStatePtr>* p_states,
std::vector<NDArray*> ndinputs,
std::vector<NDArray*> ndoutputs,
std::vector<NDArray*>& ndinputs,
std::vector<NDArray*>& ndoutputs,
std::vector<OpReqType> *p_req,
std::vector<uint32_t> *p_ref_count,
std::function<void(const OpStatePtr &state)> invoke) {
Expand Down Expand Up @@ -125,7 +125,7 @@ inline void InvokeOperator(const nnvm::IndexedGraph& idx,
void RunGraph(
const bool retain_graph,
const nnvm::IndexedGraph& idx,
const std::vector<NDArray*> arrays,
const std::vector<NDArray*>& arrays,
size_t node_start, size_t node_end,
std::vector<OpReqType>&& array_reqs,
std::vector<uint32_t>&& ref_count,
Expand Down Expand Up @@ -161,7 +161,7 @@ void NaiveRunGraph(
const bool retain_graph,
const Context& default_ctx,
const nnvm::IndexedGraph& idx,
const std::vector<NDArray*> arrays,
const std::vector<NDArray*>& arrays,
size_t node_start, size_t node_end,
std::vector<OpReqType>&& array_reqs,
std::vector<uint32_t>&& ref_count,
Expand Down
4 changes: 2 additions & 2 deletions src/imperative/imperative_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -999,7 +999,7 @@ inline void CreateEngineOpSeg(

void RunGraph(const bool retain_graph,
const nnvm::IndexedGraph& idx,
const std::vector<NDArray*> arrays,
const std::vector<NDArray*>& arrays,
size_t node_start, size_t node_end,
std::vector<OpReqType>&& array_reqs,
std::vector<uint32_t>&& ref_count,
Expand All @@ -1011,7 +1011,7 @@ void RunGraph(const bool retain_graph,
void NaiveRunGraph(const bool retain_graph,
const Context& default_ctx,
const nnvm::IndexedGraph& idx,
const std::vector<NDArray*> arrays,
const std::vector<NDArray*>& arrays,
size_t node_start, size_t node_end,
std::vector<OpReqType>&& array_reqs,
std::vector<uint32_t>&& ref_count,
Expand Down

0 comments on commit d04400a

Please sign in to comment.