This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Backport #17002, #17068 and #17114 to 1.6 branch (#17137)
* Improve the speed of the pointwise fusion graph pass (#17114)

* Debug the long startup time

* Optimize backward fusion

* Figure out why the fusion pass is called twice

* Cleaning

* Small optimization

* [BUGFIX] Fix trainer param order (#17068)

* fix trainer param order

* Update trainer.py

* Update trainer.py

* Update trainer.py

* [reproducibility] multi_sum_sq review, AtomicAdd removal (#17002)

* Update multi_sum_sq to avoid AtomicAdd

* Add specific test for multi_sum_sq

* Add a determinism test and fix lint issues

* Better test for checking that the op is deterministic

* Follow MXNet letter-case format

* Reduce dimensions of tensors in the test

Co-authored-by: Haibin Lin <[email protected]>
Co-authored-by: MoisesHer <[email protected]>
3 people committed Dec 20, 2019
1 parent ac6ed2c commit 0c6f49f
Showing 8 changed files with 214 additions and 97 deletions.
python/mxnet/gluon/trainer.py (5 changes: 4 additions & 1 deletion)

@@ -71,8 +71,11 @@ class Trainer(object):
     """
     def __init__(self, params, optimizer, optimizer_params=None, kvstore='device',
                  compression_params=None, update_on_kvstore=None):
+        param_list = []
         if isinstance(params, (dict, ParameterDict)):
-            params = list(params.values())
+            for key in sorted(list(params.keys())):
+                param_list.append(params[key])
+            params = param_list
         if not isinstance(params, (list, tuple)):
             raise ValueError(
                 "First argument must be a list or dict of Parameters, " \
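Why sorting matters (a hedged illustration, not part of the diff): the Trainer assigns kvstore keys by position in the parameter list, so two workers that build the same parameter dict in different insertion orders could previously disagree on which index belongs to which parameter. The names below are hypothetical.

# Minimal sketch of the fix's effect: sorted traversal gives every
# process the same parameter order regardless of dict insertion order.
params_a = {'dense0_weight': 'w0', 'dense0_bias': 'b0'}
params_b = {'dense0_bias': 'b0', 'dense0_weight': 'w0'}  # same params, other order

def param_order(params):
    # Mirrors the fixed Trainer.__init__ loop above.
    return [params[key] for key in sorted(list(params.keys()))]

assert param_order(params_a) == param_order(params_b)  # identical indices on every worker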
src/executor/simple_partition_pass.h (98 changes: 64 additions & 34 deletions)

@@ -102,57 +102,87 @@ class BidirectionalGraph {
   std::vector<std::unordered_set<Node*>> get_subsets(FCompatible is_compatible) {
     std::vector<std::unordered_set<Node*>> subgraphs;
     std::unordered_set<Node*> incomp_set;
-    std::unordered_set<Node*> all_set(nodes.size());
-    std::vector<PairSet> separation_sets;
+    std::vector<std::pair<bool, PairSet>> separation_sets;
     // Check each node for compatibility
     // and, if it is incompatible, mark nodes
     // on each side of it as not possible to be
     // in the same subset
     for (Node& node : nodes) {
       if (!is_compatible(node.nnvmptr)) {
         incomp_set.insert(&node);
-        std::unordered_set<Node*> in_graph;
-        std::unordered_set<Node*> out_graph;
-        std::vector<Node*> dummy_head;
-        dummy_head.emplace_back(&node);
-        DFS(dummy_head, false, [&out_graph, &is_compatible](Node* node) {
-          if (is_compatible(node->nnvmptr))
-            out_graph.insert(node);
-        });
-        DFS(dummy_head, true, [&in_graph, is_compatible](Node* node) {
-          if (is_compatible(node->nnvmptr))
-            in_graph.insert(node);
-        });
-        if (!(in_graph.empty() || out_graph.empty()))
-          separation_sets.push_back(std::make_pair(in_graph, out_graph));
       }
-      all_set.emplace(&node);
     }
-    IncompMap incomp_map;
-    std::unordered_set<Node*> comp_set;
-    comp_set.insert(all_set.begin(), all_set.end());
-    for (Node* n : incomp_set) {
-      comp_set.erase(n);
+    for (Node& node : nodes) {
+      if (incomp_set.count(&node) != 0) {
+        // Check if all your inputs are incompatible too.
+        // If so, then your separation set does not matter,
+        // because it will be covered by the sets of your inputs
+        bool inside_node = true;
+        for (Node* input : node.inputs) {
+          if (incomp_set.count(input) == 0) {
+            inside_node = false;
+          }
+        }
+        if (!inside_node) {
+          std::unordered_set<Node*> in_graph;
+          std::unordered_set<Node*> out_graph;
+          std::vector<Node*> dummy_head;
+          dummy_head.emplace_back(&node);
+          DFS(dummy_head, false, [&out_graph](Node* node) {
+            out_graph.insert(node);
+          });
+          DFS(dummy_head, true, [&in_graph](Node* node) {
+            in_graph.insert(node);
+          });
+          separation_sets.push_back(std::make_pair(true,
+                                                   std::make_pair(in_graph, out_graph)));
+        } else {
+          separation_sets.push_back(std::make_pair(false, PairSet()));
+        }
+      } else {
+        separation_sets.push_back(std::make_pair(false, PairSet()));
+      }
     }
+    IncompMap incomp_map;
     // For each node construct the map of nodes that cannot be in
     // the same subset
-    for (Node* n : comp_set) {
-      for (PairSet p : separation_sets) {
-        if (p.first.count(n)) {
-          incomp_map[n].insert(p.second.begin(), p.second.end());
-        } else if (p.second.count(n)) {
-          incomp_map[n].insert(p.first.begin(), p.first.end());
+    index_t num_nodes = nodes.size();
+    for (index_t i = 0; i < num_nodes; ++i) {
+      const auto n = &(nodes[i]);
+      if (incomp_set.count(n) == 0) {
+        for (index_t j = i + 1; j < num_nodes; ++j) {
+          const auto& sep_set_pair = separation_sets[j];
+          if (sep_set_pair.first && incomp_map[n].count(&nodes[j]) == 0) {
+            const auto& p = sep_set_pair.second;
+            if (p.first.count(n)) {
+              incomp_map[n].insert(p.second.begin(), p.second.end());
+            } else if (p.second.count(n)) {
+              incomp_map[n].insert(p.first.begin(), p.first.end());
+            }
+          }
+        }
+        for (index_t j = i - 1; j >= 0; --j) {
+          const auto& sep_set_pair = separation_sets[j];
+          if (sep_set_pair.first && incomp_map[n].count(&nodes[j]) == 0) {
+            const auto& p = sep_set_pair.second;
+            if (p.first.count(n)) {
+              incomp_map[n].insert(p.second.begin(), p.second.end());
+            } else if (p.second.count(n)) {
+              incomp_map[n].insert(p.first.begin(), p.first.end());
+            }
+          }
+        }
+        for (Node* incomp_n : incomp_set) {
+          incomp_map[n].erase(incomp_n);
+        }
       }
-      for (Node* incomp_n : incomp_set) {
-        incomp_map[n].erase(incomp_n);
-      }
     }
     std::unordered_set<Node*> unused_set;
-    unused_set.reserve(comp_set.size());
 
-    for (auto& n : comp_set) {
-      unused_set.insert(n);
+    for (auto& n : nodes) {
+      if (incomp_set.count(&n) == 0) {
+        unused_set.insert(&n);
+      }
     }
     std::unordered_set<Node*> visited;
     std::deque<Node*> stack(outputs.begin(), outputs.end());
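What the rewrite above is doing, in outline: an incompatible node separates the graph, so anything upstream of it and anything downstream of it can never share a fused subset. The old pass ran two DFS walks for every incompatible node and then scanned every separation set for every node; the new pass skips incompatible nodes whose inputs are all incompatible (their separation sets are covered by their inputs' sets) and indexes separation sets by node so the pairing loops can bail out early. A rough Python sketch of the separation-set idea, assuming a toy adjacency-dict graph rather than the real nnvm structures:

def reachable(start, neighbors):
    # Iterative DFS collecting everything reachable from `start`.
    seen, stack = set(), [start]
    while stack:
        for nxt in neighbors[stack.pop()]:
            if nxt not in seen:
                seen.add(nxt)
                stack.append(nxt)
    return seen

def separation_sets(nodes, preds, succs, is_compatible):
    incompatible = {n for n in nodes if not is_compatible(n)}
    pairs = []
    for n in incompatible:
        # The #17114 shortcut: if every input is incompatible too, this
        # node's separation set is already covered by its inputs' sets.
        if all(p in incompatible for p in preds[n]):
            continue
        pairs.append((reachable(n, preds), reachable(n, succs)))
    return pairs

Nodes that land on opposite sides of any returned pair must end up in different subsets.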
src/imperative/cached_op.cc (22 changes: 12 additions & 10 deletions)

@@ -1032,17 +1032,19 @@ OpStatePtr CachedOp::Forward(
   CHECK_EQ(inputs.size(), num_inputs());
 
   Context default_ctx = inputs[0]->ctx();
-  auto state_ptr = GetCachedOpState(default_ctx);
-  auto& state = state_ptr.get_state<CachedOpState>();
+  {
+    auto state_ptr = GetCachedOpState(default_ctx);
+    auto& state = state_ptr.get_state<CachedOpState>();
 
-  const auto& idx = state.info.fwd_graph.indexed_graph();
-  for (size_t i = 0; i < inputs.size(); ++i) {
-    CHECK_EQ(inputs[i]->ctx(), default_ctx)
-        << "CachedOp requires all inputs to live on the same context. But "
-        << idx[idx.input_nodes()[0]].source->attrs.name
-        << " is on " << default_ctx << " while "
-        << idx[idx.input_nodes()[i]].source->attrs.name
-        << " is on " << inputs[i]->ctx();
+    const auto& idx = state.info.fwd_graph.indexed_graph();
+    for (size_t i = 0; i < inputs.size(); ++i) {
+      CHECK_EQ(inputs[i]->ctx(), default_ctx)
+          << "CachedOp requires all inputs to live on the same context. But "
+          << idx[idx.input_nodes()[0]].source->attrs.name
+          << " is on " << default_ctx << " while "
+          << idx[idx.input_nodes()[i]].source->attrs.name
+          << " is on " << inputs[i]->ctx();
+    }
   }
 
   int prev_bulk_size = Engine::Get()->set_bulk_size(config_.forward_bulk_size);
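The cached_op.cc change is purely a scoping one: the cached-op state handle is now acquired inside a block, used only to validate that every input lives on the same context, and released before the rest of the forward machinery runs. A loose Python analogy using a context manager; all names here are illustrative, not MXNet API:

from contextlib import contextmanager

@contextmanager
def cached_op_state(ctx):
    state = {'ctx': ctx}   # stand-in for GetCachedOpState(default_ctx)
    try:
        yield state
    finally:
        pass               # handle released at block exit, as in the C++ scope

def forward(inputs, default_ctx):
    with cached_op_state(default_ctx) as state:  # narrow scope
        for i, inp in enumerate(inputs):
            assert inp['ctx'] == default_ctx, 'input %d is on the wrong context' % i
    # ... proceed with execution; the state handle is no longer held ...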
src/operator/contrib/multi_sum_sq-inl.h (10 changes: 7 additions & 3 deletions)

@@ -21,7 +21,7 @@
  * Copyright (c) 2019 by Contributors
  * \file multi_l2_norm-inl.h
  * \brief vectorized L2 norm over multiple arrays operators
- * \author Clement Fuji Tsang, Andrei Ivanov
+ * \author Clement Fuji Tsang, Andrei Ivanov, Moises Hernandez
  */
@@ -32,6 +32,10 @@
 #include <vector>
 #include "../operator_common.h"
 
+namespace multi_sum_sq {
+enum MultiSumSqUpdateResource {kTempSpace};
+}  // namespace multi_sum_sq
+
 namespace mxnet {
 namespace op {
@@ -80,7 +84,7 @@ inline bool MultiSumSqType(const NodeAttrs& attrs,
 
 template<typename xpu>
 void MultiSumSqRun(const std::vector<TBlob> &inputs, int nInputs,
-                   float *out_ptr, mshadow::Stream<xpu> *s);
+                   float *out_ptr, const OpContext &ctx);
 
 template<typename xpu>
 void MultiSumSq(const nnvm::NodeAttrs& attrs,
@@ -91,7 +95,7 @@ void MultiSumSq(const nnvm::NodeAttrs& attrs,
   auto s = ctx.get_stream<xpu>();
   const auto& p = dmlc::get<MultiSumSqParam>(attrs.parsed);
   float* out_ptr = outputs[0].FlatTo2D<xpu, float>(s).dptr_;
-  MultiSumSqRun<xpu>(inputs, p.num_arrays, out_ptr, s);
+  MultiSumSqRun<xpu>(inputs, p.num_arrays, out_ptr, ctx);
 }
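Passing the whole OpContext instead of a bare stream lets the GPU implementation request the kTempSpace scratch buffer declared above, which is what makes the AtomicAdd-free reduction possible. A user-level check in the spirit of the determinism test added by #17002; this assumes a build that exposes the operator as mx.nd.multi_sum_sq (bitwise reproducibility was the concern on GPU, where atomic float addition completes in a nondeterministic order):

import mxnet as mx

arrays = [mx.nd.random.uniform(shape=(100, 100)) for _ in range(4)]
first = mx.nd.multi_sum_sq(*arrays, num_arrays=len(arrays)).asnumpy()
for _ in range(10):
    again = mx.nd.multi_sum_sq(*arrays, num_arrays=len(arrays)).asnumpy()
    assert (first == again).all()  # same bits on every run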
src/operator/contrib/multi_sum_sq.cc (20 changes: 12 additions & 8 deletions)

@@ -21,7 +21,7 @@
  * Copyright (c) 2019 by Contributors
  * \file multi_sum_sq.cc
  * \brief vectorized sum or squared over multiple arrays operators
- * \author Clement Fuji Tsang, Andrei Ivanov
+ * \author Clement Fuji Tsang, Andrei Ivanov, Moises Hernandez
  */
 
 #include "./multi_sum_sq-inl.h"
@@ -52,31 +52,35 @@ NNVM_REGISTER_OP(multi_sum_sq)
     return ret;
   })
 .set_attr<FCompute>("FCompute<cpu>", MultiSumSq<cpu>)
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
 .add_argument("data", "NDArray-or-Symbol[]", "Arrays")
 .add_arguments(MultiSumSqParam::__FIELDS__());
 
 template<typename DType>
-inline void CalcSumSq(const std::vector<TBlob> &inputs, int nInputs,
+inline void CalcSumSq(const std::vector<TBlob> &inputs, int n_inputs,
                       float *out_ptr, mshadow::Stream<cpu> *s) {
   int i;
   size_t j;
 #pragma omp parallel for private(i, j)
-  for (i = 0; i < nInputs; ++i) {  // array index in inputs
+  for (i = 0; i < n_inputs; ++i) {  // array index in inputs
     float sum = 0;
     const auto address = inputs[i].FlatTo2D<cpu, DType>(s).dptr_;
-    const auto jMax = inputs[i].shape_.Size();
-    for (j = 0; j < jMax; ++j)
+    const auto j_max = inputs[i].shape_.Size();
+    for (j = 0; j < j_max; ++j)
       sum += address[j] * address[j];
 
     out_ptr[i] = sum;
   }
 }
 
 template<>
-void MultiSumSqRun<cpu>(const std::vector<TBlob> &inputs, int nInputs,
-                        float *out_ptr, mshadow::Stream<cpu> *s) {
+void MultiSumSqRun<cpu>(const std::vector<TBlob> &inputs, int n_inputs,
+                        float *out_ptr, const OpContext &ctx) {
   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType,
-    CalcSumSq<DType>(inputs, nInputs, out_ptr, s);
+    CalcSumSq<DType>(inputs, n_inputs, out_ptr, ctx.get_stream<cpu>());
   )
 }
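The CPU path above was already deterministic (one thread per array); on GPU the AtomicAdd removal amounts to a two-phase reduction: each block writes its partial sum of squares to its own slot in the temp workspace, then the slots are combined in a fixed order. A plain-Python stand-in for that idea (the real implementation is a CUDA kernel; numpy is used only for illustration):

import numpy as np

def sum_sq_deterministic(arr, num_blocks=4):
    # Phase 1: one private slot per "block", no cross-block atomics.
    chunks = np.array_split(arr.ravel(), num_blocks)
    partials = np.array([float(np.sum(c * c)) for c in chunks])
    # Phase 2: reduce the slots in a fixed order -> identical result every run.
    return partials.sum()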