Skip to content

Commit

Permalink
Optimize move semantics of NodeEntry reducing copies of shared_ptr wh…
Browse files Browse the repository at this point in the history
…ich causes atomic contention (#2576)
  • Loading branch information
larroy authored and tqchen committed Feb 8, 2019
1 parent 390b744 commit 2da23bd
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 12 deletions.
25 changes: 23 additions & 2 deletions nnvm/include/nnvm/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,18 @@ using NodePtr = std::shared_ptr<Node>;

/*! \brief an entry that represents output data from a node */
struct NodeEntry {
NodeEntry(NodePtr node, uint32_t index, uint32_t version):
node(std::move(node)),
index(index),
version(version)
{}

NodeEntry():
node(),
index(),
version()
{}

/*! \brief the source node of this data */
NodePtr node;
/*! \brief index of output from the source. */
Expand Down Expand Up @@ -113,6 +125,11 @@ struct NodeAttrs {
*/
class NNVM_DLL Node {
public:
Node() = default;
Node(const Op* op, const std::string& name) {
this->attrs.op = op;
this->attrs.name = name;
}
/*! \brief The attributes in the node. */
NodeAttrs attrs;
/*! \brief inputs to this node */
Expand Down Expand Up @@ -142,7 +159,10 @@ class NNVM_DLL Node {
* \brief create a new empty shared_ptr of Node.
* \return a created empty node.
*/
static NodePtr Create();
template<class ...Args>
static NodePtr Create(Args&&... args) {
return std::make_shared<Node>(std::forward<Args>(args)...);
}
};

/*!
Expand All @@ -167,13 +187,14 @@ inline NodeEntry MakeNode(
p->attrs.op->attr_parser(&(p->attrs));
}
p->inputs = std::move(inputs);
return NodeEntry{p, 0, 0};
return NodeEntry(p, 0, 0);
}

// implementation of functions.
inline const Op* Node::op() const {
return this->attrs.op;
}

inline bool Node::is_variable() const {
return this->op() == nullptr;
}
Expand Down
4 changes: 0 additions & 4 deletions nnvm/src/core/node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,4 @@ Node::~Node() {
}
}

NodePtr Node::Create() {
return std::make_shared<Node>();
}

} // namespace nnvm
8 changes: 4 additions & 4 deletions nnvm/src/core/symbolic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -601,8 +601,8 @@ Symbol Symbol::CreateFunctor(const Op* op,
if (fnum_vis_output.count(n->op())) {
nout = fnum_vis_output[n->op()](n->attrs);
}
for (uint32_t i = 0; i < nout; ++i) {
s.outputs.emplace_back(NodeEntry{n, i, 0});
for (size_t i = 0; i < nout; i++) {
s.outputs.emplace_back(n, i, 0);
}
return s;
}
Expand All @@ -618,7 +618,7 @@ Symbol Symbol::CreateFunctor(const NodeAttrs& attrs) {
nout = fnum_vis_output[n->op()](n->attrs);
}
for (uint32_t i = 0; i < nout; ++i) {
s.outputs.emplace_back(NodeEntry{n, i, 0});
s.outputs.emplace_back(n, i, 0);
}
return s;
}
Expand All @@ -633,7 +633,7 @@ Symbol Symbol::CreateGroup(const std::vector<Symbol> &symbols) {

Symbol Symbol::CreateVariable(const std::string& name) {
Symbol s;
s.outputs.emplace_back(NodeEntry{CreateVariableNode(name), 0, 0});
s.outputs.emplace_back(CreateVariableNode(name), 0, 0);
return s;
}

Expand Down
4 changes: 2 additions & 2 deletions nnvm/src/pass/correct_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,10 @@ nnvm::Graph CorrectLayout(nnvm::Graph src) {
nnvm::NodePtr tnode = CreateLayoutTransformNode(produce, request);
tnode->attrs.name = idx[e.node_id].source->attrs.name + "_" + request.name();
tnode->inputs.emplace_back(new_node->inputs[i]);
nnvm::NodeEntry tnode_output{tnode, 0, 0};
nnvm::NodeEntry tnode_output(std::move(tnode), 0, 0);
new_node->inputs[i] = tnode_output;
// layout produced by LayoutTransformNode
new_layouts[tnode.get()] = {request};
new_layouts[tnode_output.node.get()] = {request};
} else if (!produce.defined()) {
// do reverse infer
new_layouts[in.get()][e.index] = request;
Expand Down

0 comments on commit 2da23bd

Please sign in to comment.