Skip to content

Commit

Permalink
Naive engine end to end run through
Browse files Browse the repository at this point in the history
  • Loading branch information
reminisce committed Jun 20, 2018
1 parent c95e912 commit 7e96475
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 66 deletions.
4 changes: 0 additions & 4 deletions src/engine/naive_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -165,10 +165,6 @@ class NaiveEngine final : public Engine {
}
CHECK(this->req_completed_)
<< "NaiveEngine only support synchronize Push so far";
// increment mutable_vars version
for (auto var : opr->mutable_vars) {
++var->version_;
}
if (profiling) {
opr->opr_profile->stop();
}
Expand Down
44 changes: 4 additions & 40 deletions src/operator/subgraph/partition_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,47 +111,11 @@ void ResetSubgraphNodes(std::vector<SimpleNode*>* subgraph_nodes) {
/*
* This function traverses the nodes in a computation graph from a starting
* node following the input links and output links, and marks all nodes that
* can be accessed from the starting node.
* can be accessed from the starting node. Before the function returns,
it checks whether there is a loop between the potential subgraph
and the outside nodes. If so, it adds the node that should break the loop
to excluded_nodes and returns false. Otherwise, it returns true.
*/
#if 0
void LabelSubgraph(const Graph&g,
SubgraphSelectorPtr select_func,
const int label,
const size_t snid, // simple node id, this is a seed
const std::vector<SimpleNodePtr>& simple_nodes,
std::vector<SimpleNode*>* subgraph_nodes) {
const auto& indexed_graph = g.indexed_graph();
std::queue<SimpleNode*> node_queue;
node_queue.push(simple_nodes[snid].get());
while (!node_queue.empty()) {
SimpleNode* cur_node = node_queue.front();
node_queue.pop();
cur_node->label = label;
subgraph_nodes->push_back(cur_node);
// get qualified adjacent input nodes
for (auto& e : cur_node->node->inputs) {
if (select_func->SelectInput(*cur_node->node, *e.node)) {
const auto nid = indexed_graph.node_id(e.node.get());
CHECK_LT(nid, simple_nodes.size());
// this node has not been visited yet
if (simple_nodes[nid]->label == -1)
node_queue.push(simple_nodes[nid].get());
}
}
// get qualified output nodes
for (auto it = cur_node->outputs.begin(); it != cur_node->outputs.end(); ++it) {
if (select_func->SelectOutput(*cur_node->node, *it->first)) {
const auto nid = indexed_graph.node_id(it->first);
CHECK_LT(nid, simple_nodes.size());
// this node has not been visited yet
if (simple_nodes[nid]->label == -1)
node_queue.push(simple_nodes[nid].get());
}
}
}
}
#endif

bool LabelSubgraph(const Graph& g,
SubgraphSelectorPtr select_func,
const int label,
Expand Down
61 changes: 39 additions & 22 deletions src/operator/subgraph/subgraph_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,25 +27,29 @@
#include "../operator_common.h"
#include "../elemwise_op_common.h"
#include "../../imperative/imperative_utils.h"
#include "../../imperative/cached_op.h"
#include "./subgraph_op.h"

namespace mxnet {
namespace op {

class DefaultSubgraphOperator {
public:
// TODO: initialize uuid
DefaultSubgraphOperator(const Symbol& sym) : subgraph_uuid_("dfasdfadsmxdfw324"),
immutable_data_names_(sym.ListInputNames(Symbol::kReadOnlyArgs)),
mutable_data_names_(sym.ListInputNames(Symbol::kAuxiliaryStates)),
//input_data_names_(sym.ListInputNames(Symbol::kAll)),
output_data_names_(sym.ListOutputNames()) {
this->subg_sym = sym;
const std::vector<std::string> input_data_names = sym.ListInputNames(Symbol::kAll);
DefaultSubgraphOperator(const Symbol& sym) :
//subgraph_uuid_("dfasdfadsmxdfw324"),
//immutable_data_names_(sym.ListInputNames(Symbol::kReadOnlyArgs)),
//mutable_data_names_(sym.ListInputNames(Symbol::kAuxiliaryStates)),
subgraph_sym_(sym),
input_names_(sym.ListInputNames(Symbol::kAll)),
output_names_(sym.ListOutputNames()) {
//subgraph_exec_.reset(new CachedOp(sym, {{"static_alloc", "true"}}));
subgraph_exec_.reset(new CachedOp(sym, {}));
//const std::vector<std::string> input_data_names = sym.ListInputNames(Symbol::kAll);
//const std::vector<std::string> immutable_data_names = sym.ListInputNames(Symbol::kReadOnlyArgs);
//const std::vector<std::string> mutable_data_names = sym.ListInputNames(Symbol::kAuxiliaryStates);
immutable_data_indices_.resize(immutable_data_names_.size());
mutable_data_indices_.resize(mutable_data_names_.size());
//immutable_data_indices_.resize(immutable_data_names_.size());
//mutable_data_indices_.resize(mutable_data_names_.size());
#if 0
for (uint32_t i = 0, j1 = 0, j2 = 0; i < input_data_names.size(); ++i) {
if (input_data_names[i] == immutable_data_names_[j1]) {
immutable_data_indices_[j1++] = i;
Expand All @@ -55,9 +59,9 @@ class DefaultSubgraphOperator {
LOG(FATAL) << "Should not happen";
}
}

// initialize var versions to -1
ndarray_var_versions_.resize(input_data_names.size(), -1);
#endif
}

void Forward(const OpContext& ctx,
Expand All @@ -68,28 +72,41 @@ class DefaultSubgraphOperator {
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
// TODO we don't support backward yet.
// Many of the subgraphs don't need backward computation.
LOG(FATAL) << "Not implemented";
}

private:
nnvm::Symbol subg_sym;
std::string subgraph_uuid_;
nnvm::Symbol subgraph_sym_;
//std::string subgraph_uuid_;
// this variable records the NDArrays' var versions of the last run.
std::vector<int64_t> ndarray_var_versions_;
std::vector<uint32_t> immutable_data_indices_;
std::vector<uint32_t> mutable_data_indices_;
std::vector<std::string> immutable_data_names_;
std::vector<std::string> mutable_data_names_;
//std::vector<uint32_t> immutable_data_indices_;
//std::vector<uint32_t> mutable_data_indices_;
//std::vector<std::string> immutable_data_names_;
//std::vector<std::string> mutable_data_names_;
//std::vector<std::string> input_data_names_;
std::vector<std::string> output_data_names_;
std::shared_ptr<Executor> subgraph_executor_;
std::vector<std::string> input_names_;
std::vector<std::string> output_names_;
//std::vector<int64_t> ndarray_var_versions_;
//std::shared_ptr<Executor> subgraph_executor_;
CachedOpPtr subgraph_exec_;
};

void DefaultSubgraphOperator::Forward(const OpContext& ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
std::vector<NDArray> tmp_inputs = inputs;
std::vector<NDArray*> input_ptrs;
input_ptrs.reserve(inputs.size());
for (auto& nd : tmp_inputs) {
input_ptrs.push_back(&nd);
}
std::vector<NDArray> tmp_outputs = outputs;
std::vector<NDArray*> output_ptrs;
for (auto& nd : tmp_outputs) {
output_ptrs.push_back(&nd);
}
subgraph_exec_->Forward(subgraph_exec_, input_ptrs, output_ptrs);
}

OpStatePtr CreateSubgraphOpState(const NodeAttrs& attrs,
Expand Down

0 comments on commit 7e96475

Please sign in to comment.