Optimize move semantics of NodeEntry
apache/tvm#2576
Making copies of a shared_ptr is more expensive than moving one.
This PR reduces lock contention by using move semantics in NNVM nodes;
the added constructors also make it more convenient to construct
NodeEntry objects throughout the code.

Update NDArray to use the new NodeEntry constructors and refine initializer lists.

Sync gradient.cc with tvm
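
For context only (not part of this commit), here is a minimal standalone sketch of the copy-versus-move difference the message refers to; `Node` and `NodeEntry` are simplified stand-ins for the nnvm types:

```
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

struct Node {};                         // simplified stand-in for nnvm::Node
using NodePtr = std::shared_ptr<Node>;  // nnvm's shared node pointer alias

struct NodeEntry {                      // simplified stand-in for nnvm::NodeEntry
  NodePtr node;
  uint32_t index;
  uint32_t version;
};

int main() {
  NodePtr n = std::make_shared<Node>();
  std::vector<NodeEntry> outputs;

  NodeEntry copied{n, 0, 0};               // copy: atomic ref-count increment
  outputs.push_back(copied);               // another copy, another atomic operation

  NodeEntry moved{std::move(n), 0, 0};     // move: pointer transfer, no ref-count traffic
  outputs.emplace_back(std::move(moved));  // moved again; n and moved are now empty
  return 0;
}
```
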
larroy committed May 21, 2019
1 parent 5854b98 commit 91f8c19
Showing 22 changed files with 167 additions and 131 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/dmlc-core
Submodule dmlc-core updated 1 file
+7 −4 src/io/filesys.cc
2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated from 8518c7 to 21935d
2 changes: 1 addition & 1 deletion contrib/clojure-package/test/good-test-ndarray-api.clj
@@ -106,7 +106,7 @@
Defined in src/operator/nn/batch_norm.cc:L574
Defined in src/operator/nn/batch_norm.cc:L572
`data`: Input data to batch normalization
`gamma`: gamma array
2 changes: 1 addition & 1 deletion contrib/clojure-package/test/good-test-symbol-api.clj
@@ -119,7 +119,7 @@
Defined in src/operator/nn/batch_norm.cc:L574
Defined in src/operator/nn/batch_norm.cc:L572
`data`: Input data to batch normalization (optional)
`gamma`: gamma array (optional)
22 changes: 22 additions & 0 deletions docs/faq/new_op.md
@@ -292,6 +292,28 @@ output or nothing to calculating gradient.
For more complicated patterns, use `MakeGradNode(op_name, n, heads, dict)` to create gradient entries,
where heads are input entries to the backward op, composed from ograds and n->inputs.
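
As an illustration only (not part of this diff), a gradient function for a hypothetical unary operator `my_op` might use `MakeGradNode` roughly as follows; `MyOpGrad` and `_backward_my_op` are made-up names, and the helper's exact signature should be checked against `src/operator/operator_common.h`:

```
// Assumes the usual MXNet operator headers (e.g. operator_common.h) are included.
// Registered on the forward op as .set_attr<nnvm::FGradient>("FGradient", MyOpGrad).
std::vector<nnvm::NodeEntry> MyOpGrad(const nnvm::NodePtr& n,
                                      const std::vector<nnvm::NodeEntry>& ograds) {
  // heads: the output gradient followed by the forward input the backward op needs
  std::vector<nnvm::NodeEntry> heads{ograds[0]};
  heads.emplace_back(n->inputs[0]);
  return MakeGradNode("_backward_my_op", n, heads,
                      std::unordered_map<std::string, std::string>());
}
```
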

When assembling a return vector `std::vector<nnvm::NodeEntry> ret;`, a common pattern is to
either create the entry in place, as in:

```
ret.emplace_back(MakeNode("zeros_like", n->attrs.name + "_xyz_backward",
                          {n->inputs[1]}, nullptr, &n));
```

Or create the node, modify it, and then move it into the NodeEntry constructor if the node is not
used again. This avoids unnecessary copies of the shared_ptr:

```
for (size_t i = 0; i < n->inputs.size(); ++i) {
  nnvm::NodePtr node = nnvm::Node::Create();
  node->attrs.op = copy_op;
  node->inputs = {ograds[0]};
  ret.emplace_back(std::move(node));
}
```

In the first case the `NodeEntry` returned by `MakeNode` is created directly thanks to return-value optimization and then moved into the vector; in the second the entry is constructed in place from the moved `NodePtr`.
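
For reference, the `NodeEntry` constructors these two patterns rely on (introduced on the nnvm side in apache/tvm#2576) look roughly like the sketch below; consult `nnvm/node.h` for the authoritative definitions:

```
#include <cstddef>
#include <cstdint>
#include <memory>
#include <utility>

class Node;                             // forward declaration, as in nnvm
using NodePtr = std::shared_ptr<Node>;  // nnvm's shared node pointer alias

struct NodeEntry {
  // node pointer plus explicit output index and version
  NodeEntry(NodePtr node, uint32_t index, uint32_t version)
      : node(std::move(node)), index(index), version(version) {}

  // node pointer alone: refers to output 0, version 0
  explicit NodeEntry(NodePtr node)
      : node(std::move(node)), index(0), version(0) {}

  // "empty" entry, e.g. an NDArray's default entry_(nullptr)
  explicit NodeEntry(std::nullptr_t)
      : node(nullptr), index(0), version(0) {}

  NodePtr node;      // shared_ptr to the node producing this value
  uint32_t index;    // which output of that node is referenced
  uint32_t version;  // version counter, bumped when the output is mutated
};
```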

#### FCompute\<xpu\>

Simple operators can register FCompute<xpu> with `.set_attr<FCompute>("FCompute<cpu>", ...)` and `.set_attr<FCompute>("FCompute<gpu>", ...)` for both CPU and (optionally) GPU computation.
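
As an illustration only, a hedged sketch of what such a registration might look like; `my_op` and `MyOpForwardCPU` are hypothetical names and the remaining registration attributes are elided:

```
// Assumes the usual MXNet operator registration headers are included.
// FCompute has this signature; the body is a skeleton, not a real operator.
void MyOpForwardCPU(const nnvm::NodeAttrs& attrs,
                    const OpContext& ctx,
                    const std::vector<TBlob>& inputs,
                    const std::vector<OpReqType>& req,
                    const std::vector<TBlob>& outputs) {
  // read from inputs[i].dptr<DType>(), honor req[i], write to outputs[i]
}

NNVM_REGISTER_OP(my_op)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<FCompute>("FCompute<cpu>", MyOpForwardCPU);
```
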
53 changes: 32 additions & 21 deletions include/mxnet/ndarray.h
@@ -82,7 +82,8 @@ class MKLDNNMemory;
class NDArray {
public:
/*! \brief default constructor */
NDArray() {
NDArray()
: entry_(nullptr) {
}
/*!
* \brief constructs a new dynamic NDArray
@@ -94,8 +95,10 @@ class NDArray {
NDArray(const mxnet::TShape &shape, Context ctx,
bool delay_alloc = false, int dtype = mshadow::default_type_flag)
: ptr_(std::make_shared<Chunk>(shape, ctx, delay_alloc, dtype)),
shape_(shape), dtype_(dtype), storage_type_(kDefaultStorage),
entry_({nullptr, 0, 0}) {
shape_(shape),
dtype_(dtype),
storage_type_(kDefaultStorage),
entry_(nullptr) {
}
/*! \brief constructor for NDArray with storage type
*/
@@ -109,11 +112,12 @@
* \param ctx context of NDArray
* \param dtype data type of this ndarray
*/
explicit NDArray(Context ctx, int dtype = mshadow::default_type_flag) {
ptr_ = std::make_shared<Chunk>(mxnet::TShape(mshadow::Shape1(0)), ctx, true, dtype);
dtype_ = dtype;
storage_type_ = kDefaultStorage;
entry_ = {nullptr, 0, 0};
explicit NDArray(Context ctx, int dtype = mshadow::default_type_flag)
: ptr_(std::make_shared<Chunk>(mxnet::TShape(mshadow::Shape1(0)), ctx, true, dtype)),
shape_(),
dtype_(dtype),
storage_type_(kDefaultStorage),
entry_(nullptr) {
}
/*!
* \brief constructing a static NDArray that shares data with TBlob
@@ -123,9 +127,11 @@
* \param dev_id the device id this tensor sits at
*/
NDArray(const TBlob &data, int dev_id)
: ptr_(std::make_shared<Chunk>(data, dev_id)), shape_(data.shape_),
dtype_(data.type_flag_), storage_type_(kDefaultStorage),
entry_({nullptr, 0, 0}) {
: ptr_(std::make_shared<Chunk>(data, dev_id)),
shape_(data.shape_),
dtype_(data.type_flag_),
storage_type_(kDefaultStorage),
entry_(nullptr) {
}

/*!
@@ -137,20 +143,22 @@
* \param deleter the function pointer of custom deleter
*/
NDArray(const TBlob &data, int dev_id, const std::function<void()>& deleter)
: ptr_(new Chunk(data, dev_id),
[deleter](Chunk *p) {
deleter(); // call custom deleter
delete p; // delete Chunk object
: ptr_(new Chunk(data, dev_id), [deleter](Chunk *p) {
deleter(); // call custom deleter
delete p; // delete Chunk object
}),
shape_(data.shape_),
dtype_(data.type_flag_), storage_type_(kDefaultStorage),
entry_({nullptr, 0, 0}) {
entry_(nullptr) {
}

/*! \brief create ndarray from shared memory */
NDArray(int shared_pid, int shared_id, const mxnet::TShape& shape, int dtype)
: ptr_(std::make_shared<Chunk>(shared_pid, shared_id, shape, dtype)), shape_(shape),
dtype_(dtype), storage_type_(kDefaultStorage), entry_({nullptr, 0, 0}) {
: ptr_(std::make_shared<Chunk>(shared_pid, shared_id, shape, dtype)),
shape_(shape),
dtype_(dtype),
storage_type_(kDefaultStorage),
entry_(nullptr) {
}

/*!
@@ -165,8 +173,11 @@
*/
NDArray(const NDArrayStorageType stype, const mxnet::TShape &shape,
const TBlob &data, const std::vector<TBlob> &aux_data, int dev_id)
: ptr_(std::make_shared<Chunk>(stype, data, aux_data, dev_id)), shape_(shape),
dtype_(data.type_flag_), storage_type_(stype), entry_({nullptr, 0, 0}) {
: ptr_(std::make_shared<Chunk>(stype, data, aux_data, dev_id)),
shape_(shape),
dtype_(data.type_flag_),
storage_type_(stype),
entry_(nullptr) {
}
/*!
* \brief initialize the NDArray, assuming it is not assigned a meaningful shape before
@@ -642,7 +653,7 @@
*/
NDArray Detach() const {
NDArray ret(*this);
ret.entry_ = nnvm::NodeEntry{nullptr, 0, 0};
ret.entry_ = nnvm::NodeEntry(nullptr);
return ret;
}

2 changes: 1 addition & 1 deletion src/c_api/c_api_function.cc
@@ -56,7 +56,7 @@ std::vector<nnvm::NodeEntry> Gradient(

std::vector<nnvm::NodeEntry> ret;
for (uint32_t i = 0; i < g->num_outputs(); ++i) {
ret.emplace_back(nnvm::NodeEntry{g, i, 0});
ret.emplace_back(g, i, 0);
}

return ret;
11 changes: 6 additions & 5 deletions src/executor/graph_executor.cc
@@ -223,11 +223,12 @@ nnvm::NodeEntry AggregateGradient(std::vector<nnvm::NodeEntry>&& v) {
ng->attrs.op = Op::Get("_zeros_without_dtype");
ng->attrs.name = "zeros_without_dtype";
ng->attrs.op->attr_parser(&(ng->attrs));
return nnvm::NodeEntry{ng, 0, 0};
return nnvm::NodeEntry(std::move(ng), 0, 0);
}

// remove zero in the sum. at least keep 1.
auto begin = std::remove_if(v.begin(), v.end(), [](const nnvm::NodeEntry& nodeEntry) {
CHECK(nodeEntry.node);
return nodeEntry.node->op() == zeros_op || nodeEntry.node->op() == zeros_like_op;
});
if (begin == v.begin()) ++begin;
@@ -244,7 +245,7 @@ nnvm::NodeEntry AggregateGradient(std::vector<nnvm::NodeEntry>&& v) {
sum_node->attrs.dict["num_args"] = std::to_string(v.size());
sum_node->attrs.op->attr_parser(&(sum_node->attrs));
sum_node->inputs = std::move(v);
return nnvm::NodeEntry{sum_node, 0, 0};
return nnvm::NodeEntry(std::move(sum_node), 0, 0);
} else {
// use a stream line of plus instead
nnvm::NodeEntry ret = v[0];
@@ -274,7 +275,7 @@ nnvm::NodeEntry AggregateGradient(std::vector<nnvm::NodeEntry>&& v) {
x->attrs.op = ewise_plus_op;
x->attrs.name = os.str();
x->inputs = {ret, v[i]};
ret = nnvm::NodeEntry{x, 0, 0};
ret = nnvm::NodeEntry(std::move(x), 0, 0);
}
// identity node is used to avoid exposure of dummy plus node
// when its output get assigned to another space.
@@ -323,15 +324,15 @@ nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol,
}
if (!need_grad_) return g;
for (size_t i = 0; i < g.outputs.size(); ++i) {
NodeEntry ngrad{nnvm::Node::Create(), 0, 0};
NodeEntry ngrad(nnvm::Node::Create(), 0, 0);
head_grad_entry_.emplace_back(AttrHint(ngrad, g.outputs[i]));
head_grad_map_[ngrad.node.get()] = i;
}
std::vector<NodePtr> args = symbol.ListInputs(nnvm::Symbol::kReadOnlyArgs);
std::vector<NodeEntry> xs;
for (size_t i = 0; i < grad_req_types.size(); ++i) {
if (grad_req_types[i] != kNullOp) {
xs.emplace_back(NodeEntry{args[i], 0, 0});
xs.emplace_back(args[i]);
}
}

4 changes: 3 additions & 1 deletion src/executor/infer_graph_attr_pass.cc
@@ -628,7 +628,9 @@ nnvm::Graph InferShapeAttr(nnvm::Graph &&ret,
}
if (dispatch_mode_name) {
for (size_t i = node_start; i < node_end; i++) {
if (dispatch_modes[i] == DispatchMode::kUndefined) ++num_unknown;
if (dispatch_modes[i] == DispatchMode::kUndefined) {
++num_unknown;
}
}
}
++i;
53 changes: 27 additions & 26 deletions src/imperative/cached_op.cc
@@ -98,7 +98,7 @@ CachedOp::CachedOp(
using namespace nnvm;
using namespace imperative;
static const std::vector<const Op*> zero_ops{Op::Get("zeros_like"), Op::Get("_zeros")};
static const auto _copy = Op::Get("_copy");
static const auto _copy_op = Op::Get("_copy");
config_.Init(flags);

if (config_.static_shape) {
@@ -107,21 +107,21 @@

// construct forward graph
{
NodeEntryMap<int> dedup_out;
for (const auto& i : sym.outputs) {
if (dedup_out.count(i)) {
NodeEntryMap<size_t> dedup_out;
for (const NodeEntry& nodeEntry : sym.outputs) {
if (dedup_out.find(nodeEntry) != dedup_out.end()) {
NodePtr copy_node = Node::Create();
copy_node->attrs.op = _copy;
copy_node->attrs.op = _copy_op;
copy_node->attrs.name =
i.node->attrs.name + "_copy" + std::to_string(dedup_out[i]++);
copy_node->inputs.emplace_back(i);
if (_copy->attr_parser != nullptr) {
_copy->attr_parser(&(copy_node->attrs));
nodeEntry.node->attrs.name + "_copy" + std::to_string(dedup_out[nodeEntry]++);
copy_node->inputs.emplace_back(nodeEntry);
if (_copy_op->attr_parser != nullptr) {
_copy_op->attr_parser(&(copy_node->attrs));
}
fwd_graph_.outputs.emplace_back(copy_node, 0, 0);
fwd_graph_.outputs.emplace_back(std::move(copy_node));
} else {
dedup_out.insert({i, 0});
fwd_graph_.outputs.push_back(i);
dedup_out.emplace(nodeEntry, 0);
fwd_graph_.outputs.push_back(nodeEntry);
}
}
const auto& idx = fwd_graph_.indexed_graph();
@@ -143,14 +143,15 @@

// Set params
{
const auto& idx = fwd_graph_.indexed_graph();
const auto& indexed_graph = fwd_graph_.indexed_graph();
if (config_.data_indices.ndim() || config_.param_indices.ndim()) {
CHECK_EQ(config_.data_indices.ndim() + config_.param_indices.ndim(),
idx.input_nodes().size());
indexed_graph.input_nodes().size());
} else {
std::vector<uint32_t> tmp;
for (size_t i = 0; i < idx.input_nodes().size(); ++i) {
tmp.push_back(i);
tmp.reserve(indexed_graph.input_nodes().size());
for (size_t i = 0; i < indexed_graph.input_nodes().size(); ++i) {
tmp.emplace_back(i);
}
config_.data_indices.assign(tmp.begin(), tmp.end());
}
@@ -159,20 +160,20 @@
// construct backward graph
{
ograd_entries_.reserve(fwd_graph_.outputs.size());
for (size_t i = 0; i < fwd_graph_.outputs.size(); ++i) {
ograd_entries_.emplace_back(NodeEntry{Node::Create(), 0, 0});
}
for (size_t i = 0; i < fwd_graph_.outputs.size(); ++i)
ograd_entries_.emplace_back(Node::Create());

std::vector<NodeEntry> xs;
const auto& idx = fwd_graph_.indexed_graph();
for (size_t i = 0; i < idx.input_nodes().size(); ++i) {
auto nid = idx.input_nodes()[i];
if (idx.mutable_input_nodes().count(nid)) continue;
const IndexedGraph& indexed_graph = fwd_graph_.indexed_graph();
for (size_t i = 0; i < indexed_graph.input_nodes().size(); ++i) {
const uint32_t node_id = indexed_graph.input_nodes()[i];
if (indexed_graph.mutable_input_nodes().count(node_id))
continue;
fwd_input_to_grad_output_[i] = xs.size();
xs.emplace_back(NodeEntry{idx[nid].weak_ref.lock(), 0, 0});
xs.emplace_back(indexed_graph[node_id].weak_ref.lock());
}

CHECK_GT(xs.size(), 0)
CHECK(!xs.empty())
<< "There are no inputs in computation graph that require gradients.";

grad_graph_ = pass::MXGradient(
@@ -199,7 +200,7 @@
}

auto full_ref_count = fwd_graph_.GetAttr<std::vector<uint32_t> >("forward_ref_count");
for (size_t i = 0; i < num_forward_entries; ++i) full_ref_count[i] += ref_count[i];
for (size_t i = 0; i < num_forward_entries; ++i) full_ref_count.at(i) += ref_count[i];
fwd_graph_.attrs["full_ref_count"] =
std::make_shared<dmlc::any>(std::move(full_ref_count));

2 changes: 1 addition & 1 deletion src/imperative/imperative.cc
@@ -363,7 +363,7 @@ std::vector<NDArray*> Imperative::Backward(
auto node = Node::Create();
node->attrs.op = copy_op;
node->inputs.push_back(e);
graph.outputs.emplace_back(node, 0, 0);
graph.outputs.emplace_back(std::move(node));
} else {
graph.outputs.push_back(e);
}
6 changes: 3 additions & 3 deletions src/ndarray/ndarray.cc
@@ -53,7 +53,7 @@ namespace mxnet {
NDArray::NDArray(const NDArrayStorageType stype, const mxnet::TShape &shape, Context ctx,
bool delay_alloc, int dtype, std::vector<int> aux_types,
mxnet::ShapeVector aux_shapes, mxnet::TShape storage_shape) : shape_(shape),
dtype_(dtype), storage_type_(stype), entry_({nullptr, 0, 0}) {
dtype_(dtype), storage_type_(stype), entry_(nullptr) {
// Assign default aux types if not given
if (aux_types.size() == 0
&& stype != kDefaultStorage) {
@@ -171,7 +171,7 @@ nnvm::Symbol NDArray::get_autograd_symbol() const {
#if MXNET_USE_MKLDNN == 1

NDArray::NDArray(mkldnn::memory::primitive_desc mem_pd)
: storage_type_(kDefaultStorage), entry_({nullptr, 0, 0}) {
: storage_type_(kDefaultStorage), entry_(nullptr) {
auto mem_desc = mem_pd.desc();
shape_ = mxnet::TShape(mem_desc.data.dims, mem_desc.data.dims + mem_desc.data.ndims);
dtype_ = get_mxnet_type(mem_desc.data.data_type);
@@ -181,7 +181,7 @@ NDArray::NDArray(mkldnn::memory::primitive_desc mem_pd)
}

NDArray::NDArray(const std::shared_ptr<mkldnn::memory> &mkldnn_mem)
: storage_type_(kDefaultStorage), entry_({nullptr, 0, 0}) {
: storage_type_(kDefaultStorage), entry_(nullptr) {
auto mem_pd = mkldnn_mem->get_primitive_desc();
auto mem_desc = mem_pd.desc();
shape_ = mxnet::TShape(mem_desc.data.dims, mem_desc.data.dims + mem_desc.data.ndims);