Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Refactor DynamicForward
Browse files Browse the repository at this point in the history
  • Loading branch information
anirudh2290 committed Jan 18, 2020
1 parent 800847d commit e74ac4d
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 66 deletions.
35 changes: 6 additions & 29 deletions src/imperative/cached_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -712,7 +712,6 @@ OpStatePtr CachedOp::DynamicForward(
}
nnvm::Graph& g = runtime.info.fwd_graph;
const auto& idx = g.indexed_graph();
size_t num_inputs = idx.input_nodes().size();
auto& buff = runtime.buff;
auto& states = runtime.op_states;

Expand All @@ -724,39 +723,17 @@ OpStatePtr CachedOp::DynamicForward(
for (auto& buffered_array : buff) {
arrays.push_back(&buffered_array);
}
for (size_t i = 0; i < num_inputs; ++i) {
arrays[idx.entry_id(idx.input_nodes()[i], 0)] = inputs[i];
}
for (size_t i = 0; i < idx.outputs().size(); ++i) {
auto eid = idx.entry_id(idx.outputs()[i]);
if (!arrays[eid]->is_none()) *outputs[i] = arrays[eid]->Detach();
arrays[eid] = outputs[i];
}

// Allocate NDArrays
std::vector<OpReqType> array_reqs(arrays.size(), kWriteTo);
const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
const std::string& graph_type = recording ? FULL : FORWARD;
std::vector<uint32_t> ref_count =
g.GetAttr<std::vector<uint32_t> >(AddPrefix(graph_type, REF_COUNT));
const auto& mem_plan = g.GetAttr<MemoryPlanVector >(AddPrefix(graph_type, MEM_PLAN));
CollectInputOutputNDRefs(g, inputs, outputs, &arrays);
AllocateGraphOutputs(g, default_ctx, ref_count,
mem_plan, use_naive_run, &array_reqs, &arrays);

std::vector<OpReqType> array_reqs(arrays.size(), kWriteTo);
for (size_t i = 0; i < idx.num_node_entries(); ++i) {
if (ref_count[i] == 0) array_reqs[i] = kNullOp;
}
const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
if (!use_naive_run) {
const auto& mem_plan = g.GetAttr<MemoryPlanVector >(AddPrefix(graph_type, MEM_PLAN));
AllocateMemory(g, idx, default_ctx, 0, idx.num_node_entries(),
mem_plan, arrays, &array_reqs);
const auto& dtypes = g.GetAttr<DTypeVector>("dtype");
const auto& shapes = g.GetAttr<mxnet::ShapeVector>("shape");
const auto& stypes = g.GetAttr<StorageTypeVector>("storage_type");
for (size_t i = 0; i < outputs.size(); ++i) {
auto eid = idx.entry_id(idx.outputs()[i]);
arrays[eid] = outputs[i];
if (!outputs[i]->is_none()) continue;
*outputs[i] = NDArray(static_cast<NDArrayStorageType>(stypes[eid]),
shapes[eid], default_ctx, true, dtypes[eid]);
}
// If CachedOp is running in the inline mode, it uses RunGraph to record
// computation; otherwise, CachedOp records computation itself.
// So if it's not the inline mode, we disable recording.
Expand Down
52 changes: 52 additions & 0 deletions src/imperative/cached_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,58 @@ std::string AddPrefix(const std::string& prefix,
return prefix + "_" + s;
}

/* \brief collect pointers to input and output ndarrays
* into a single data structure, this data structure can
* be used for Memory allocation pass*/
/*! \brief Collect pointers to the graph's input and output NDArrays
 * into a single per-entry-id structure (`arrays`) that a later memory
 * allocation pass can consume.
 *
 * \param g graph whose indexed form supplies the entry ids
 * \param inputs one NDArray* per graph input node, in idx.input_nodes() order
 * \param outputs one NDArray* per graph output entry, in idx.outputs() order
 * \param arrays in/out: per-entry-id NDArray pointers; input and output
 *        slots are (re)bound by this function
 *
 * NOTE(review): this definition lives in a header; without `inline` every
 * translation unit that includes it (both cached_op.cc and
 * cached_op_threadsafe.cc do) emits its own definition, violating the ODR
 * and failing at link time — hence the `inline` below.
 */
inline void CollectInputOutputNDRefs(const nnvm::Graph& g,
                                     const std::vector<NDArray*>& inputs,
                                     const std::vector<NDArray*>& outputs,
                                     std::vector<NDArray*>* arrays) {
  const auto& idx = g.indexed_graph();
  size_t num_inputs = idx.input_nodes().size();
  // Bind each graph input entry to the caller-provided input array.
  for (size_t i = 0; i < num_inputs; ++i) {
    (*arrays)[idx.entry_id(idx.input_nodes()[i], 0)] = inputs[i];
  }
  for (size_t i = 0; i < idx.outputs().size(); ++i) {
    auto eid = idx.entry_id(idx.outputs()[i]);
    // If a previous run left data at this output entry, detach it from the
    // graph and hand it to the caller's output before rebinding the slot.
    if (!(*arrays)[eid]->is_none())
      *outputs[i] = (*arrays)[eid]->Detach();
    (*arrays)[eid] = outputs[i];
  }
}

/* \brief create ndarrays for the intermediate outputs and final outputs
* from the allocated storage (happens in MXPlanMemory NNVM pass)*/
/*! \brief Create NDArrays for the intermediate outputs and final outputs
 * from the storage allocated by the MXPlanMemory NNVM pass.
 *
 * \param g graph carrying the "dtype"/"shape"/"storage_type" attributes
 * \param default_ctx context on which output NDArrays are created
 * \param ref_count per-entry-id reference counts; entries with count 0 get
 *        their write request downgraded to kNullOp
 * \param mem_plan memory plan produced by MXPlanMemory
 * \param use_naive_run when true, skip memory allocation and output
 *        creation entirely (the naive runner allocates on the fly)
 * \param array_reqs in/out: per-entry op request types, adjusted here
 * \param arrays in/out: per-entry-id NDArray pointers; graph-output slots
 *        that are still empty are materialized here
 *
 * NOTE(review): the call sites visible in this commit invoke
 * `AllocateGraphOutputs` while this definition is named `CreateGraphNDs` —
 * confirm the intended name against the full tree.
 *
 * NOTE(review): defined in a header included from multiple translation
 * units; `inline` is required to avoid an ODR violation at link time.
 */
inline void CreateGraphNDs(const nnvm::Graph& g,
                           const mxnet::Context& default_ctx,
                           const std::vector<uint32_t>& ref_count,
                           const mxnet::imperative::MemoryPlanVector& mem_plan,
                           bool use_naive_run,
                           std::vector<OpReqType>* array_reqs,
                           std::vector<NDArray*>* arrays) {
  const auto& idx = g.indexed_graph();
  // Entries nobody reads never need to be written.
  for (size_t i = 0; i < idx.num_node_entries(); ++i) {
    if (ref_count[i] == 0)
      (*array_reqs)[i] = kNullOp;
  }

  if (!use_naive_run) {
    mxnet::imperative::AllocateMemory(g, idx, default_ctx, 0,
                                      idx.num_node_entries(), mem_plan, *arrays,
                                      array_reqs);
    const auto &dtypes = g.GetAttr<nnvm::DTypeVector>("dtype");
    const auto &shapes = g.GetAttr<mxnet::ShapeVector>("shape");
    const auto &stypes = g.GetAttr<mxnet::StorageTypeVector>("storage_type");
    for (size_t i = 0; i < idx.outputs().size(); ++i) {
      auto eid = idx.entry_id(idx.outputs()[i]);
      // Already materialized (e.g. bound by the caller) — leave it alone.
      if (!(*arrays)[eid]->is_none())
        continue;
      // Create a delayed-allocation NDArray with the planned type/shape.
      *((*arrays)[eid]) = NDArray(static_cast<NDArrayStorageType>(stypes[eid]),
                                  shapes[eid], default_ctx, true, dtypes[eid]);
    }
  }
}

/* \brief create a forward graph from the Symbol */
void CreateForwardGraph(const nnvm::Symbol &sym, nnvm::Graph *fwd_graph) {
using namespace nnvm;
Expand Down
44 changes: 7 additions & 37 deletions src/imperative/cached_op_threadsafe.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ OpStatePtr CachedOpThreadSafe::DynamicForward(const Context& default_ctx,
using namespace nnvm;
using namespace imperative;

{
auto state_ptr = GetCachedOpState(default_ctx);
auto op_state = OpStatePtr::Create<DynamicRuntime>();
auto &runtime = op_state.get_state<DynamicRuntime>();
Expand All @@ -106,7 +105,6 @@ OpStatePtr CachedOpThreadSafe::DynamicForward(const Context& default_ctx,
}
nnvm::Graph &g = runtime.info.fwd_graph;
const auto &idx = g.indexed_graph();
size_t num_inputs = idx.input_nodes().size();
size_t max_nodes = runtime.info.fwd_graph.indexed_graph().num_nodes();
runtime.op_states.resize(max_nodes);
auto &states = runtime.op_states;
Expand All @@ -121,46 +119,18 @@ OpStatePtr CachedOpThreadSafe::DynamicForward(const Context& default_ctx,
for (auto &buffered_array : buff) {
arrays.push_back(&buffered_array);
}
for (size_t i = 0; i < num_inputs; ++i) {
arrays[idx.entry_id(idx.input_nodes()[i], 0)] = inputs[i];
}
for (size_t i = 0; i < idx.outputs().size(); ++i) {
auto eid = idx.entry_id(idx.outputs()[i]);
if (!arrays[eid]->is_none())
*outputs[i] = arrays[eid]->Detach();
arrays[eid] = outputs[i];
}
// Allocate NDArrays
std::vector<uint32_t> ref_count = g.GetAttr<std::vector<uint32_t>>(
"forward_ref_count");

std::vector<OpReqType> array_reqs(arrays.size(), kWriteTo);
for (size_t i = 0; i < idx.num_node_entries(); ++i) {
if (ref_count[i] == 0)
array_reqs[i] = kNullOp;
}
const auto &dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
const auto &mem_plan = g.GetAttr<MemoryPlanVector>("forward_mem_plan");
AllocateMemory(g, idx, default_ctx, 0, idx.num_node_entries(), mem_plan,
arrays, &array_reqs);
const auto &dtypes = g.GetAttr<DTypeVector>("dtype");
const auto &shapes = g.GetAttr<mxnet::ShapeVector>("shape");
const auto &stypes = g.GetAttr<StorageTypeVector>("storage_type");
for (size_t i = 0; i < outputs.size(); ++i) {
auto eid = idx.entry_id(idx.outputs()[i]);
arrays[eid] = outputs[i];
if (!outputs[i]->is_none())
continue;
*outputs[i] = NDArray(static_cast<NDArrayStorageType>(stypes[eid]),
shapes[eid], default_ctx, true, dtypes[eid]);
}
// If CachedOp is running in the inline mode, it uses RunGraph to record
// computation; otherwise, CachedOp records computation itself.
// So if it's not the inline mode, we disable recording.
std::vector<uint32_t> ref_count = g.GetAttr<std::vector<uint32_t>>(
"forward_ref_count");
const MemoryPlanVector& mem_plan = g.GetAttr<MemoryPlanVector>("forward_mem_plan");
const std::string& graph_type = FORWARD;
CollectInputOutputNDRefs(g, inputs, outputs, &arrays);
AllocateGraphOutputs(g, default_ctx, ref_count,
mem_plan, false, &array_reqs, &arrays);
RunGraph(false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs),
std::move(ref_count), &states, dispatch_modes, false);
return op_state;
}
}

OpStatePtr CachedOpThreadSafe::Forward(const std::shared_ptr<CachedOp>& op_ptr,
Expand Down

0 comments on commit e74ac4d

Please sign in to comment.