diff --git a/src/imperative/imperative_utils.cc b/src/imperative/imperative_utils.cc index 6cb4a70324b5..c7204c1d85e6 100644 --- a/src/imperative/imperative_utils.cc +++ b/src/imperative/imperative_utils.cc @@ -20,62 +20,61 @@ #include "./imperative_utils.h" #include "./cached_op.h" -namespace mxnet { -namespace imperative { +namespace { -inline std::vector NodeInputs(const nnvm::IndexedGraph& idx, - const int node_idx, - const std::vector arrays) { +std::vector NodeInputs(const nnvm::IndexedGraph& idx, + const int node_idx, + const std::vector& arrays) { const nnvm::IndexedGraph::Node& node = idx[node_idx]; const size_t num_inputs = node.inputs.size(); std::vector ndinputs; ndinputs.reserve(num_inputs); for (const auto& j : node.inputs) { - size_t eid = idx.entry_id(j); + const size_t eid = idx.entry_id(j); ndinputs.emplace_back(arrays[eid]); } return ndinputs; } -inline std::vector NodeOutputs(const nnvm::IndexedGraph& idx, - const int node_idx, - const std::vector arrays) { +std::vector NodeOutputs(const nnvm::IndexedGraph& idx, + const int node_idx, + const std::vector& arrays) { const nnvm::IndexedGraph::Node& node = idx[node_idx]; const size_t num_outputs = node.source->num_outputs(); std::vector ndoutputs; ndoutputs.reserve(num_outputs); for (size_t j = 0; j < num_outputs; ++j) { - size_t eid = idx.entry_id(node_idx, j); + const size_t eid = idx.entry_id(node_idx, j); ndoutputs.emplace_back(arrays[eid]); } return ndoutputs; } -inline std::vector NodeReq(const nnvm::IndexedGraph& idx, - const int node_idx, - const std::vector array_reqs) { +std::vector NodeReq(const nnvm::IndexedGraph& idx, + const int node_idx, + const std::vector& array_reqs) { const nnvm::IndexedGraph::Node& node = idx[node_idx]; const size_t num_outputs = node.source->num_outputs(); std::vector req; req.reserve(num_outputs); for (size_t j = 0; j < num_outputs; ++j) { - size_t eid = idx.entry_id(node_idx, j); + const size_t eid = idx.entry_id(node_idx, j); req.push_back(array_reqs[eid]); } return req; } -inline void InvokeOperator(const nnvm::IndexedGraph& idx, - const int node_idx, - const bool retain_graph, - const std::vector arrays, - Context ctx, - std::vector* p_states, - std::vector ndinputs, - std::vector ndoutputs, - std::vector *p_req, - std::vector *p_ref_count, - std::function invoke) { +void InvokeOperator(const nnvm::IndexedGraph& idx, + const int node_idx, + const bool retain_graph, + const std::vector& arrays, + Context ctx, + std::vector* p_states, + const std::vector& ndinputs, + const std::vector& ndoutputs, + std::vector *p_req, + std::vector *p_ref_count, + std::function invoke) { static const auto bwd_cached_op = Op::Get("_backward_CachedOp"); static auto& createop = nnvm::Op::GetAttr("FCreateOpState"); static auto& is_layer_backward = Op::GetAttr("TIsLayerOpBackward"); @@ -122,10 +121,15 @@ inline void InvokeOperator(const nnvm::IndexedGraph& idx, } } +} // namespace + +namespace mxnet { +namespace imperative { + void RunGraph( const bool retain_graph, const nnvm::IndexedGraph& idx, - const std::vector arrays, + const std::vector& arrays, size_t node_start, size_t node_end, std::vector&& array_reqs, std::vector&& ref_count, @@ -161,7 +165,7 @@ void NaiveRunGraph( const bool retain_graph, const Context& default_ctx, const nnvm::IndexedGraph& idx, - const std::vector arrays, + const std::vector& arrays, size_t node_start, size_t node_end, std::vector&& array_reqs, std::vector&& ref_count, diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h index 071f4fa9dd0b..d134d47c55cf 100644 --- a/src/imperative/imperative_utils.h +++ b/src/imperative/imperative_utils.h @@ -999,7 +999,7 @@ inline void CreateEngineOpSeg( void RunGraph(const bool retain_graph, const nnvm::IndexedGraph& idx, - const std::vector arrays, + const std::vector& arrays, size_t node_start, size_t node_end, std::vector&& array_reqs, std::vector&& ref_count, @@ -1011,7 +1011,7 @@ void RunGraph(const bool retain_graph, void NaiveRunGraph(const bool retain_graph, const Context& default_ctx, const nnvm::IndexedGraph& idx, - const std::vector arrays, + const std::vector& arrays, size_t node_start, size_t node_end, std::vector&& array_reqs, std::vector&& ref_count,