From e74ac4dda441770b9352828ad840ec89236559a7 Mon Sep 17 00:00:00 2001 From: Anirudh Subramanian Date: Sat, 18 Jan 2020 03:08:44 +0000 Subject: [PATCH] Refactor DynamicForward --- src/imperative/cached_op.cc | 35 +++-------------- src/imperative/cached_op.h | 52 ++++++++++++++++++++++++++ src/imperative/cached_op_threadsafe.cc | 44 ++++------------------ 3 files changed, 65 insertions(+), 66 deletions(-) diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc index 8908f3d44df9..9ebb96659c2f 100644 --- a/src/imperative/cached_op.cc +++ b/src/imperative/cached_op.cc @@ -712,7 +712,6 @@ OpStatePtr CachedOp::DynamicForward( } nnvm::Graph& g = runtime.info.fwd_graph; const auto& idx = g.indexed_graph(); - size_t num_inputs = idx.input_nodes().size(); auto& buff = runtime.buff; auto& states = runtime.op_states; @@ -724,39 +723,17 @@ OpStatePtr CachedOp::DynamicForward( for (auto& buffered_array : buff) { arrays.push_back(&buffered_array); } - for (size_t i = 0; i < num_inputs; ++i) { - arrays[idx.entry_id(idx.input_nodes()[i], 0)] = inputs[i]; - } - for (size_t i = 0; i < idx.outputs().size(); ++i) { - auto eid = idx.entry_id(idx.outputs()[i]); - if (!arrays[eid]->is_none()) *outputs[i] = arrays[eid]->Detach(); - arrays[eid] = outputs[i]; - } - - // Allocate NDArrays + std::vector array_reqs(arrays.size(), kWriteTo); + const auto& dispatch_modes = g.GetAttr("dispatch_mode"); const std::string& graph_type = recording ? 
FULL : FORWARD; std::vector ref_count = g.GetAttr >(AddPrefix(graph_type, REF_COUNT)); + const auto& mem_plan = g.GetAttr(AddPrefix(graph_type, MEM_PLAN)); + CollectInputOutputNDRefs(g, inputs, outputs, &arrays); + CreateGraphNDs(g, default_ctx, ref_count, + mem_plan, use_naive_run, &array_reqs, &arrays); - std::vector array_reqs(arrays.size(), kWriteTo); - for (size_t i = 0; i < idx.num_node_entries(); ++i) { - if (ref_count[i] == 0) array_reqs[i] = kNullOp; - } - const auto& dispatch_modes = g.GetAttr("dispatch_mode"); if (!use_naive_run) { - const auto& mem_plan = g.GetAttr(AddPrefix(graph_type, MEM_PLAN)); - AllocateMemory(g, idx, default_ctx, 0, idx.num_node_entries(), - mem_plan, arrays, &array_reqs); - const auto& dtypes = g.GetAttr("dtype"); - const auto& shapes = g.GetAttr("shape"); - const auto& stypes = g.GetAttr("storage_type"); - for (size_t i = 0; i < outputs.size(); ++i) { - auto eid = idx.entry_id(idx.outputs()[i]); - arrays[eid] = outputs[i]; - if (!outputs[i]->is_none()) continue; - *outputs[i] = NDArray(static_cast(stypes[eid]), - shapes[eid], default_ctx, true, dtypes[eid]); - } // If CachedOp is running in the inline mode, it uses RunGraph to record // computation; otherwise, CachedOp records computation itself. // So if it's not the inline mode, we disable recording. 
diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h index 7f0d109b3420..f70083afb05f 100644 --- a/src/imperative/cached_op.h +++ b/src/imperative/cached_op.h @@ -46,6 +46,58 @@ std::string AddPrefix(const std::string& prefix, return prefix + "_" + s; } +/* \brief collect pointers to input and output ndarrays + * into a single data structure, this data structure can + * be used for Memory allocation pass*/ +void CollectInputOutputNDRefs(const nnvm::Graph& g, + const std::vector& inputs, + const std::vector& outputs, + std::vector* arrays) { + const auto& idx = g.indexed_graph(); + size_t num_inputs = idx.input_nodes().size(); + for (size_t i = 0; i < num_inputs; ++i) { + (*arrays)[idx.entry_id(idx.input_nodes()[i], 0)] = inputs[i]; + } + for (size_t i = 0; i < idx.outputs().size(); ++i) { + auto eid = idx.entry_id(idx.outputs()[i]); + if (!(*arrays)[eid]->is_none()) + *outputs[i] = (*arrays)[eid]->Detach(); + (*arrays)[eid] = outputs[i]; + } +} + +/* \brief create ndarrays for the intermediate outputs and final outputs + * from the allocated storage (happens in MXPlanMemory NNVM pass)*/ +void CreateGraphNDs(const nnvm::Graph& g, + const mxnet::Context& default_ctx, + const std::vector& ref_count, + const mxnet::imperative::MemoryPlanVector& mem_plan, + bool use_naive_run, + std::vector* array_reqs, + std::vector* arrays) { + const auto& idx = g.indexed_graph(); + for (size_t i = 0; i < idx.num_node_entries(); ++i) { + if (ref_count[i] == 0) + (*array_reqs)[i] = kNullOp; + } + + if (!use_naive_run) { + mxnet::imperative::AllocateMemory(g, idx, default_ctx, 0, + idx.num_node_entries(), mem_plan, *arrays, + array_reqs); + const auto &dtypes = g.GetAttr("dtype"); + const auto &shapes = g.GetAttr("shape"); + const auto &stypes = g.GetAttr("storage_type"); + for (size_t i = 0; i < idx.outputs().size(); ++i) { + auto eid = idx.entry_id(idx.outputs()[i]); + if (!(*arrays)[eid]->is_none()) + continue; + *((*arrays)[eid]) = NDArray(static_cast(stypes[eid]), + 
shapes[eid], default_ctx, true, dtypes[eid]); + } + } +} + /* \brief create a forward graph from they Symbol */ void CreateForwardGraph(const nnvm::Symbol &sym, nnvm::Graph *fwd_graph) { using namespace nnvm; diff --git a/src/imperative/cached_op_threadsafe.cc b/src/imperative/cached_op_threadsafe.cc index d17b9b2cdfae..c98187e0d8b8 100644 --- a/src/imperative/cached_op_threadsafe.cc +++ b/src/imperative/cached_op_threadsafe.cc @@ -90,7 +90,6 @@ OpStatePtr CachedOpThreadSafe::DynamicForward(const Context& default_ctx, using namespace nnvm; using namespace imperative; - { auto state_ptr = GetCachedOpState(default_ctx); auto op_state = OpStatePtr::Create(); auto &runtime = op_state.get_state(); @@ -106,7 +105,6 @@ OpStatePtr CachedOpThreadSafe::DynamicForward(const Context& default_ctx, } nnvm::Graph &g = runtime.info.fwd_graph; const auto &idx = g.indexed_graph(); - size_t num_inputs = idx.input_nodes().size(); size_t max_nodes = runtime.info.fwd_graph.indexed_graph().num_nodes(); runtime.op_states.resize(max_nodes); auto &states = runtime.op_states; @@ -121,46 +119,18 @@ OpStatePtr CachedOpThreadSafe::DynamicForward(const Context& default_ctx, for (auto &buffered_array : buff) { arrays.push_back(&buffered_array); } - for (size_t i = 0; i < num_inputs; ++i) { - arrays[idx.entry_id(idx.input_nodes()[i], 0)] = inputs[i]; - } - for (size_t i = 0; i < idx.outputs().size(); ++i) { - auto eid = idx.entry_id(idx.outputs()[i]); - if (!arrays[eid]->is_none()) - *outputs[i] = arrays[eid]->Detach(); - arrays[eid] = outputs[i]; - } - // Allocate NDArrays - std::vector ref_count = g.GetAttr>( - "forward_ref_count"); - std::vector array_reqs(arrays.size(), kWriteTo); - for (size_t i = 0; i < idx.num_node_entries(); ++i) { - if (ref_count[i] == 0) - array_reqs[i] = kNullOp; - } const auto &dispatch_modes = g.GetAttr("dispatch_mode"); - const auto &mem_plan = g.GetAttr("forward_mem_plan"); - AllocateMemory(g, idx, default_ctx, 0, idx.num_node_entries(), mem_plan, - arrays, 
&array_reqs); - const auto &dtypes = g.GetAttr("dtype"); - const auto &shapes = g.GetAttr("shape"); - const auto &stypes = g.GetAttr("storage_type"); - for (size_t i = 0; i < outputs.size(); ++i) { - auto eid = idx.entry_id(idx.outputs()[i]); - arrays[eid] = outputs[i]; - if (!outputs[i]->is_none()) - continue; - *outputs[i] = NDArray(static_cast(stypes[eid]), - shapes[eid], default_ctx, true, dtypes[eid]); - } - // If CachedOp is running in the inline mode, it uses RunGraph to record - // computation; otherwise, CachedOp records computation itself. - // So if it's not the inline mode, we disable recording. + std::vector ref_count = g.GetAttr>( "forward_ref_count"); + const MemoryPlanVector& mem_plan = g.GetAttr("forward_mem_plan"); + const std::string& graph_type = FORWARD; + std::vector<OpReqType> array_reqs(arrays.size(), kWriteTo); + CollectInputOutputNDRefs(g, inputs, outputs, &arrays); + CreateGraphNDs(g, default_ctx, ref_count, + mem_plan, false, &array_reqs, &arrays); RunGraph(false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs), std::move(ref_count), &states, dispatch_modes, false); return op_state; - } } OpStatePtr CachedOpThreadSafe::Forward(const std::shared_ptr& op_ptr,