From 4827de8e89bf45a91295d81aba21a95bfa8ababd Mon Sep 17 00:00:00 2001 From: Bojian Zheng Date: Thu, 21 May 2020 12:14:55 -0400 Subject: [PATCH] Improve the backward mirroring implementation (#18228) --- ci/windows/test_py3_cpu.ps1 | 6 + ci/windows/test_py3_gpu.ps1 | 7 + docs/static_site/src/pages/api/faq/env_var.md | 6 +- example/image-classification/README.md | 11 +- python/mxnet/rnn/rnn_cell.py | 5 + src/executor/exec_pass.h | 37 +- src/executor/graph_executor.cc | 128 +++- src/executor/graph_executor.h | 8 +- src/imperative/cached_op.h | 2 +- src/imperative/imperative.cc | 2 +- src/nnvm/gradient.cc | 709 ++++++++++++++---- src/nnvm/plan_memory.cc | 15 +- src/operator/nn/activation-inl.h | 9 +- src/operator/nn/activation.cc | 50 +- src/operator/nn/activation.cu | 46 +- src/operator/nn/cudnn/cudnn_batch_norm-inl.h | 16 +- tests/python/unittest/test_memory_opt.py | 202 +++++ 17 files changed, 1009 insertions(+), 250 deletions(-) create mode 100644 tests/python/unittest/test_memory_opt.py diff --git a/ci/windows/test_py3_cpu.ps1 b/ci/windows/test_py3_cpu.ps1 index 870236a421fd..5121a53d4314 100644 --- a/ci/windows/test_py3_cpu.ps1 +++ b/ci/windows/test_py3_cpu.ps1 @@ -36,3 +36,9 @@ if ($LastExitCode -ne 0) { Throw ("Error running serial train tests, python exit $env:MXNET_SAFE_ACCUMULATION=1 C:\Python37\python.exe -m pytest -v --durations=50 --cov-report xml:tests_unittest.xml --cov-append tests\python\unittest\test_operator.py::test_norm if ($LastExitCode -ne 0) { Throw ("Error running unittest, python exited with status code " + ('{0:X}' -f $LastExitCode)) } + +# Similar to the MXNET_SAFE_ACCUMULATION test case above. Need to explicitly +# set the environment variable for MXNET_MEMORY_OPT. +$env:MXNET_MEMORY_OPT=1 +C:\Python37\python.exe -m pytest -v --durations=50 --cov-report xml:tests_unittest.xml --cov-append tests\python\unittest\test_memory_opt.py +if ($LastExitCode -ne 0) { Throw ("Error running unittest, python exited with status code " + ('{0:X}' -f $LastExitCode)) } diff --git a/ci/windows/test_py3_gpu.ps1 b/ci/windows/test_py3_gpu.ps1 index 9e0e7edf9c05..9f200f31b540 100644 --- a/ci/windows/test_py3_gpu.ps1 +++ b/ci/windows/test_py3_gpu.ps1 @@ -44,9 +44,16 @@ C:\Python37\python.exe -m pytest -v -m 'not serial' -n 4 --durations=50 --cov-re if ($LastExitCode -ne 0) { Throw ("Error running parallel tests, python exited with status code " + ('{0:X}' -f $LastExitCode)) } C:\Python37\python.exe -m pytest -v -m 'serial' --durations=50 --cov-report xml:tests_train.xml --cov-append tests\python\train if ($LastExitCode -ne 0) { Throw ("Error running serial tests, python exited with status code " + ('{0:X}' -f $LastExitCode)) } + # Adding this extra test since it's not possible to set env var on the fly in Windows. $env:MXNET_SAFE_ACCUMULATION=1 C:\Python37\python.exe -m pytest -v --durations=50 --cov-report xml:tests_operator.xml --cov-append tests\python\gpu\test_operator_gpu.py::test_norm if ($LastExitCode -ne 0) { Throw ("Error running tests, python exited with status code " + ('{0:X}' -f $LastExitCode)) } C:\Python37\python.exe -m pytest -v --durations=50 --cov-report xml:tests_tvm_op.xml tests\python\gpu\test_tvm_op_gpu.py if ($LastExitCode -ne 0) { Throw ("Error running TVM op tests, python exited with status code " + ('{0:X}' -f $LastExitCode)) } + +# Similar to the MXNET_SAFE_ACCUMULATION test case above. Need to explicitly +# set the environment variable for MXNET_MEMORY_OPT. 
+$env:MXNET_MEMORY_OPT=1 +C:\Python37\python.exe -m pytest -v --durations=50 --cov-report xml:tests_unittest.xml --cov-append tests\python\unittest\test_memory_opt.py +if ($LastExitCode -ne 0) { Throw ("Error running memory optimization tests, python exited with status code " + ('{0:X}' -f $LastExitCode)) } diff --git a/docs/static_site/src/pages/api/faq/env_var.md b/docs/static_site/src/pages/api/faq/env_var.md index 75255210933d..9c5546b2a654 100644 --- a/docs/static_site/src/pages/api/faq/env_var.md +++ b/docs/static_site/src/pages/api/faq/env_var.md @@ -189,7 +189,7 @@ $env:MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0 - The maximum size of an NDArray slice in terms of number of parameters. - This parameter is used to slice an NDArray before synchronizing through P3Store (dist_p3). -## Memonger +## Memory Optimizations * MXNET_BACKWARD_DO_MIRROR - Values: 0(false) or 1(true) ```(default=0)``` @@ -199,6 +199,10 @@ $env:MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0 - `MXNET_BACKWARD_DO_MIRROR=1` will save 30%~50% of device memory, but retains about 95% of running speed. - One extension of `mirror` in MXNet is called [memonger technology](https://arxiv.org/abs/1604.06174), it will only use O(sqrt(N)) memory at 75% running speed. Checkout the code [here](https://github.com/dmlc/mxnet-memonger). +* MXNET_MEMORY_OPT + - Values: 0(no optimizations) or 1(highest optimization level) ```(default=0)``` + - If set to '1', various optimizations on memory consumption will be enabled. + ## Control the profiler The following environments can be used to profile the application without changing code. Execution options may affect the granularity of profiling result. If you need profiling result of every operator, please set `MXNET_EXEC_BULK_EXEC_INFERENCE`, `MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN` and `MXNET_EXEC_BULK_EXEC_TRAIN` to 0. diff --git a/example/image-classification/README.md b/example/image-classification/README.md index 78ea94eeb440..4b4a48b33ae4 100644 --- a/example/image-classification/README.md +++ b/example/image-classification/README.md @@ -366,11 +366,12 @@ An over sized batch size may result in out of GPU memory. The common error message is `cudaMalloc failed: out of memory`. Now we can - Reduce the batch size -- Set the environment variable `MXNET_BACKWARD_DO_MIRROR` to 1. It trades off - computation for memory consumption. For example, with batch size 64, - inception-v3 uses 10G memory and trains 30 image/sec on a single K80 GPU. When - mirroring is enabled, with 10G GPU memory consumption, we can run inception-v3 - using batch size 128. The cost is that the speed reduces to 27 images/sec. +- Set the environment variable `MXNET_MEMORY_OPT=1` to perform a series of + memory optimizations (e.g., trades off computation for memory consumption). + For example, with batch size 64, inception-v3 uses 10G memory and trains 30 + image/sec on a single K80 GPU. When mirroring is enabled, with 10G GPU memory + consumption, we can run inception-v3 using batch size 128. The cost is that + the speed reduces to 27 images/sec. ## History diff --git a/python/mxnet/rnn/rnn_cell.py b/python/mxnet/rnn/rnn_cell.py index ceb33d7dcf0a..f6f99e892782 100644 --- a/python/mxnet/rnn/rnn_cell.py +++ b/python/mxnet/rnn/rnn_cell.py @@ -459,6 +459,11 @@ def __call__(self, inputs, states): name='%so'%name) next_c = symbol._internal._plus(forget_gate * states[1], in_gate * in_transform, name='%sstate'%name) + next_c._set_attr(force_mirroring='0') + # Cell states are excluded from being mirrored. 
The reason is that + # they do not pass through the fully-connected layers and will + # significantly increase the overall mirroring depth, incurring large + # performance overhead. next_h = symbol._internal._mul(out_gate, symbol.Activation(next_c, act_type="tanh"), name='%sout'%name) diff --git a/src/executor/exec_pass.h b/src/executor/exec_pass.h index e3d2fa459bc3..270c546f0f49 100644 --- a/src/executor/exec_pass.h +++ b/src/executor/exec_pass.h @@ -273,16 +273,17 @@ namespace pass { /*! * \brief Get the gradient graph whose outputs are gradients of xs wrt to ys. * \param graph The input graph. - * \param ys The entries we want to take gradient from. - * \param xs The input to take gradient with respect to. - * \param ys_out_grad The symbol for additional gradient to be propagate back to y. - * \param aggregate_fun Aggregation function applied to aggregate the inputs. - * \param mirror_fun Optional mirror function to do mirror optimization and save memory. - * \param attr_hint_fun Optional, hint function to output a node that like src, but its attr is same as like. - * \param zero_ops Optional, list of operators that outputs a single zero array. The first one - * must be zeros_like. - * \param copy_op_str Optional, name of the copy operation required to handle duplicates - * on the edge of the graph + * \param ys The entries to take gradient from. + * \param xs The entries to take gradient with respect to. + * \param ys_out_grad The output gradients of ys. + * \param aggregate_fun The aggregation function used for summing gradients. + * \param mirror_fun The backward mirroring function that does mirroring to save memory. + * \param zero_ops The list of operators that output a single zero array, used + * for generating zero gradient nodes. The first operator must + * be zeros_like. + * \param copy_op_str The name of the copy operator that handles gradient duplicates. + * \param in_arg_shapes The shapes of input arguments, used for shape inference. + * \param in_arg_dtypes The data types of input arguments, used for data type inference. * \return A new graph, whose outputs correspond to inputs of xs. 
*/ inline Graph MXGradient( @@ -292,27 +293,27 @@ inline Graph MXGradient( std::vector ys_out_grad, std::function&& inputs)> aggregate_fun = nullptr, std::function mirror_fun = nullptr, - std::function - attr_hint_fun = nullptr, std::vector zero_ops = std::vector(), - std::string copy_op_str = std::string()) { + std::string copy_op_str = std::string(), + mxnet::ShapeVector in_arg_shapes = mxnet::ShapeVector(), + DTypeVector in_arg_dtypes = DTypeVector()) { graph.attrs["grad_ys"] = std::make_shared(std::move(ys)); graph.attrs["grad_xs"] = std::make_shared(std::move(xs)); graph.attrs["grad_ys_out_grad"] = std::make_shared(std::move(ys_out_grad)); + graph.attrs["in_arg_shapes"] = std::make_shared(std::move(in_arg_shapes)); + graph.attrs["in_arg_dtypes"] = std::make_shared(std::move(in_arg_dtypes)); + if (aggregate_fun != nullptr) { graph.attrs["grad_aggregate_fun"] = std::make_shared(aggregate_fun); } if (mirror_fun != nullptr) { - graph.attrs["grad_mirror_fun"] = std::make_shared(mirror_fun); - } - if (attr_hint_fun != nullptr) { - graph.attrs["attr_hint_fun"] = std::make_shared(attr_hint_fun); + graph.attrs["mirror_fun"] = std::make_shared(mirror_fun); } if (zero_ops.size()) { graph.attrs["zero_ops"] = std::make_shared(std::move(zero_ops)); } if (copy_op_str != std::string()) { - graph.attrs["copy_op"] = std::make_shared(std::move(copy_op_str)); + graph.attrs["copy_op_str"] = std::make_shared(std::move(copy_op_str)); } return ApplyPass(std::move(graph), "MXGradient"); } diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index 02ce818d7fa8..a213ff040074 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -302,28 +302,15 @@ nnvm::NodeEntry AggregateGradient(std::vector&& v) { } } -template -inline ValueType get_node_attr( - const nnvm::Node& node, - const std::string& key, ValueType default_value) { - auto it = node.attrs.dict.find(key); - if (it == node.attrs.dict.end()) { - return default_value; - } else { - ValueType ret; - dmlc::parameter::FieldEntry e; - e.Init(key, &ret, ret); - e.Set(&ret, it->second); - return ret; - } -} /*! * \brief Create the graph for backward pass. * This is triggered by both simple_bind and bind flows. */ nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol, - const std::vector& grad_req_types) { + const std::vector& grad_req_types, + const ShapeVector& in_arg_shapes, + const nnvm::DTypeVector& in_arg_dtypes) { using nnvm::ObjectPtr; using nnvm::NodeEntry; // initial information @@ -356,19 +343,28 @@ nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol, } } - int do_mirror = dmlc::GetEnv("MXNET_BACKWARD_DO_MIRROR", 0); - auto need_mirror = [do_mirror](const nnvm::Node& node) -> int { - if (node.is_variable()) return 0; - const std::string& type = node.attrs.op->name; - if (type == "Dropout") return false; - if (get_node_attr(node, "__force_mirroring__", false)) return true; - if (do_mirror == 0) return false; - if (type == "Convolution") return false; - if (type == "FullyConnected") return false; - if (type == "Concat") return false; - if (type == "SoftmaxOutput") return false; - return true; - }; + std::function need_mirror = + [](const nnvm::Node& node) -> int { + if (node.is_variable()) return false; + const std::string& type = node.attrs.op->name; + if (type == "Dropout") return false; + // We follow the hidden key attribute "force_mirroring" if it is + // explicitly set. 
+ auto iter = node.attrs.dict.find("__force_mirroring__"); + if (iter != node.attrs.dict.end()) { + bool do_mirror; + dmlc::parameter::FieldEntry e; + e.Init("__force_mirroring__", &do_mirror, do_mirror); + e.Set(&do_mirror, iter->second); + return do_mirror; + } + if (type == "Embedding") return false; + if (type == "Convolution") return false; + if (type == "FullyConnected") return false; + if (type == "Concat") return false; + if (type == "SoftmaxOutput") return false; + return true; + }; std::vector zero_ops; zero_ops.push_back(nnvm::Op::Get("zeros_like")); @@ -377,8 +373,12 @@ nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol, // take gradient nnvm::Graph g_grad = nnvm::pass::MXGradient( g, symbol.outputs, xs, head_grad_entry_, - AggregateGradient, need_mirror, nullptr, - zero_ops, "_copy"); + AggregateGradient, + (dmlc::GetEnv("MXNET_BACKWARD_DO_MIRROR", 0) || + dmlc::GetEnv("MXNET_MEMORY_OPT", 0)) ? need_mirror : nullptr, + zero_ops, "_copy", + in_arg_shapes, in_arg_dtypes); + CHECK_EQ(g_grad.outputs.size(), xs.size()); for (const auto &e : g_grad.outputs) { g.outputs.push_back(e); @@ -414,8 +414,37 @@ void GraphExecutor::Init(nnvm::Symbol symbol, std::vector aux_state_ctxes(aux_states.size()); std::transform(aux_states.begin(), aux_states.end(), aux_state_ctxes.begin(), get_ctx1); + // Record the shapes and data types of the input arguments in the source graph + // (i.e., the graph prior to the Gradient pass). Such information is needed by + // the backward mirroring algorithm for shape and data type inference. + nnvm::Graph src; + src.outputs = symbol.outputs; + const nnvm::IndexedGraph& src_idx = src.indexed_graph(); + const std::unordered_set& src_mutable_nodes = src_idx.mutable_input_nodes(); + size_t src_arg_top = 0, src_aux_top = 0; + ShapeVector src_arg_shapes; + nnvm::DTypeVector src_arg_dtypes; + const size_t src_num_forward_inputs = symbol.ListInputs(nnvm::Symbol::kAll).size(); + + for (size_t i = 0; i < src_num_forward_inputs; ++i) { + const uint32_t nid = src_idx.input_nodes().at(i); + + if (src_mutable_nodes.count(nid)) { + CHECK_LT(src_aux_top, aux_states.size()); + src_arg_shapes.push_back(aux_states[src_aux_top].shape()); + src_arg_dtypes.push_back(aux_states[src_aux_top].dtype()); + ++src_aux_top; + } else { + CHECK_LT(src_arg_top, in_args.size()); + src_arg_shapes.push_back(in_args[src_arg_top].shape()); + src_arg_dtypes.push_back(in_args[src_arg_top].dtype()); + ++src_arg_top; + } + } + nnvm::Graph g = InitGraph(symbol, default_ctx, ctx_map, in_arg_ctxes, - arg_grad_ctxes, aux_state_ctxes, grad_req_types); + arg_grad_ctxes, aux_state_ctxes, grad_req_types, + src_arg_shapes, src_arg_dtypes); // create arg_shapes and arg_dtypes for shape and type inferences const auto& idx = g.indexed_graph(); @@ -811,8 +840,34 @@ void GraphExecutor::Init(nnvm::Symbol symbol, std::unordered_map* shared_buffer, Executor* shared_exec, const nnvm::NodeEntryMap& feed_dict) { + // Record the shapes and data types of the input arguments in the source graph + // (i.e., the graph prior to the Gradient pass). Such information is needed by + // the backward mirroring algorithm for shape and data type inference. 
+ nnvm::Graph src; + src.outputs = symbol.outputs; + const nnvm::IndexedGraph& src_idx = src.indexed_graph(); + ShapeVector src_arg_shapes(src_idx.input_nodes().size(), TShape()); + nnvm::DTypeVector src_arg_dtypes(src_idx.input_nodes().size(), -1); + const size_t src_num_forward_inputs = symbol.ListInputs(nnvm::Symbol::kAll).size(); + + for (size_t i = 0; i < src_num_forward_inputs; ++i) { + const uint32_t nid = src_idx.input_nodes().at(i); + const std::string& name = src_idx[nid].source->attrs.name; + std::unordered_map::const_iterator + arg_shape_iter = arg_shape_map.find(name); + std::unordered_map::const_iterator + arg_dtype_iter = arg_dtype_map.find(name); + if (arg_shape_iter != arg_shape_map.end()) { + src_arg_shapes[i] = arg_shape_iter->second; + } + if (arg_dtype_iter != arg_dtype_map.end()) { + src_arg_dtypes[i] = arg_dtype_iter->second; + } + } + nnvm::Graph g = InitGraph(symbol, default_ctx, ctx_map, in_arg_ctxes, arg_grad_ctxes, - aux_state_ctxes, grad_req_types); + aux_state_ctxes, grad_req_types, + src_arg_shapes, src_arg_dtypes); // The following code of shape and dtype inferences and argument // initialization is for simple_bind only. Regular bind operation @@ -1007,9 +1062,12 @@ Graph GraphExecutor::InitGraph(nnvm::Symbol symbol, const std::vector& in_arg_ctxes, const std::vector& arg_grad_ctxes, const std::vector& aux_state_ctxes, - const std::vector& grad_req_types) { + const std::vector& grad_req_types, + const ShapeVector& in_arg_shapes, + const nnvm::DTypeVector& in_arg_dtypes) { // setup gradient - nnvm::Graph g = InitFullGraph(symbol, grad_req_types); + nnvm::Graph g = InitFullGraph(symbol, grad_req_types, + in_arg_shapes, in_arg_dtypes); #if MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC && !defined(_WIN32) if (default_ctx.dev_mask() == Context::kGPU && dmlc::GetEnv("MXNET_USE_FUSION", true)) { diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h index 4164bb758376..ed6eeaa11f4f 100644 --- a/src/executor/graph_executor.h +++ b/src/executor/graph_executor.h @@ -188,10 +188,14 @@ class GraphExecutor : public Executor { const std::vector& in_arg_ctxes, const std::vector& arg_grad_ctxes, const std::vector& aux_state_ctxes, - const std::vector& grad_req_types); + const std::vector& grad_req_types, + const ShapeVector& in_arg_shapes, + const nnvm::DTypeVector& in_arg_dtypes); // intialize the full graph for simple bind, including gradient Graph InitFullGraph(nnvm::Symbol symbol, - const std::vector& grad_req_types); + const std::vector& grad_req_types, + const ShapeVector& in_arg_shapes, + const nnvm::DTypeVector& in_arg_dtypes); // initialize the cached operator void InitCachedOps(); // initialize the opr segments for bulk exec diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h index 1a395574176f..702a5734b51a 100644 --- a/src/imperative/cached_op.h +++ b/src/imperative/cached_op.h @@ -167,7 +167,7 @@ void CreateBackwardGraph(nnvm::Graph* fwd_graph, *grad_graph = pass::MXGradient( *fwd_graph, fwd_graph->outputs, xs, *ograd_entries, - exec::AggregateGradient, nullptr, nullptr, + exec::AggregateGradient, nullptr, zero_ops, "_copy"); } diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc index 14fedc93351c..c12fcee1910e 100644 --- a/src/imperative/imperative.cc +++ b/src/imperative/imperative.cc @@ -455,7 +455,7 @@ std::vector Imperative::Backward( Graph g_graph = pass::MXGradient( graph, graph.outputs, xs, ograd_entries, - exec::AggregateGradient, nullptr, nullptr, + exec::AggregateGradient, nullptr, zero_ops, 
"_copy"); CHECK_EQ(g_graph.outputs.size(), xs.size()); for (const auto& e : g_graph.outputs) { diff --git a/src/nnvm/gradient.cc b/src/nnvm/gradient.cc index 74cec1623800..c9dc67be74a1 100644 --- a/src/nnvm/gradient.cc +++ b/src/nnvm/gradient.cc @@ -23,178 +23,604 @@ * \brief Passes that takes gradient of the graph * This code code was modified based on mxnet codebase by Min Lin */ +#include #include #include #include + #include +#include +#include #include +#include +#include +#include +#include + +#include "../executor/exec_pass.h" + namespace nnvm { namespace pass { -namespace { -// default aggregate gradient function -// require operator zeros and elemwise_sum to be presented. -NodeEntry DefaultAggregateGradient(std::vector&& v) { - if (v.size() == 1) { - return std::move(v[0]); - } else if (v.size() == 0) { - ObjectPtr zero_node = Node::Create(); - zero_node->attrs.op = Op::Get("zeros"); - zero_node->attrs.name = "zero_grad"; - zero_node->attrs.op->attr_parser(&(zero_node->attrs)); - return NodeEntry{zero_node, 0, 0}; - } else { - ObjectPtr sum_node = Node::Create(); - sum_node->attrs.op = Op::Get("elemwise_sum"); - sum_node->inputs = std::move(v); - sum_node->attrs.name = "grad_sum"; - sum_node->attrs.dict["num_args"] = std::to_string(sum_node->inputs.size()); - sum_node->attrs.op->attr_parser(&(sum_node->attrs)); - return NodeEntry{sum_node, 0, 0}; - } -} +extern size_t MXGetDTypeSize(const int type_flag); // defined in plan_memory.cc -bool CheckGradAllZero(const std::vector& grads, - const std::vector& zero_ops) { - if (!grads.size() || !zero_ops.size()) return false; - for (const auto& g : grads) { - bool found = false; - for (const auto& op : zero_ops) { - if (g.node->op() == op) { - found = true; - break; - } - } - if (!found) return false; - } - return true; -} +namespace { -// helper entry + +/*! Auxiliary Data Structure for Gradient Entries */ struct GradEntry { -#ifdef _MSC_VER - NodeEntry sum = NodeEntry{nullptr, 0, 0}; -#else - NodeEntry sum{nullptr, 0, 0}; -#endif + NodeEntry sum = NodeEntry(nullptr, 0, 0); std::vector grads; - bool need_attr_hint{true}; }; -Graph Gradient(Graph src) { - using nnvm::FGradient; - using MirrorFun = std::function; - using AttrHintFun = std::function; +/*! + * \brief Build the backward graph from the mirror map. This function will be + * invoked twice if backward mirroring has been enabled. + */ +Graph BuildGradientGraph( + const Graph& src, + const std::vector& xs, + const std::vector& topo_order, + std::unordered_map > output_grads, + std::function mirror_fun, + const std::unordered_map& mirror_map); + +/*! + * \brief Auxiliary function that maps the forward node of the source graph to + * its corresponding node on the mirror path. 
+ */ +inline const ObjectPtr& MapFwdNodeToMirrorPath( + const ObjectPtr& n, + const std::unordered_map& mirror_map) { + auto iter = mirror_map.find(n.get()); + if (iter == mirror_map.end() || + iter->second == nullptr) { + return n; + } + return iter->second; +} + + +Graph Gradient(Graph src) { CHECK_NE(src.attrs.count("grad_ys"), 0U) << "Gradient require grad_ys to be presented."; - CHECK_NE(src.attrs.count("grad_ys_out_grad"), 0U) - << "Gradient require grad_ys_out_grad to be presented."; CHECK_NE(src.attrs.count("grad_xs"), 0U) << "Gradient require grad_xs to be presented."; + CHECK_NE(src.attrs.count("grad_ys_out_grad"), 0U) + << "Gradient require grad_ys_out_grad to be presented."; + const std::vector& xs = + src.GetAttr >("grad_xs"); const std::vector& ys = src.GetAttr >("grad_ys"); const std::vector& ys_out_grad = src.GetAttr >("grad_ys_out_grad"); - const std::vector& xs = - src.GetAttr >("grad_xs"); - using AggFun = std::function&& inputs)>; - AggFun agg_fun = DefaultAggregateGradient; - if (src.attrs.count("grad_aggregate_fun") != 0) { - agg_fun = src.GetAttr("grad_aggregate_fun"); - } - MirrorFun mirror_fun = nullptr; - if (src.attrs.count("grad_mirror_fun") != 0) { - mirror_fun = src.GetAttr("grad_mirror_fun"); - } - AttrHintFun attr_hint_fun = nullptr; - if (src.attrs.count("attr_hint_fun") != 0) { - attr_hint_fun = src.GetAttr("attr_hint_fun"); - } - std::vector zero_ops; - if (src.attrs.count("zero_ops") != 0) { - zero_ops = src.GetAttr >("zero_ops"); - } - const Op* copy_op = (src.attrs.count("copy_op") != 0) ? - Op::Get(src.GetAttr("copy_op")) : - nullptr; + CHECK_EQ(ys.size(), ys_out_grad.size()); - // topo sort + // initialize a topological order of the graph nodes and `output_grads` + // that maps every operator node to its gradient entries std::vector topo_order; - std::unordered_map > output_grads; + std::unordered_map > output_grads; - DFSVisit(ys, [&](const ObjectPtr& node) { - if (output_grads.count(node.get()) == 0) { - output_grads[node.get()].resize(node->num_outputs()); - } - topo_order.push_back(node); - }); + DFSVisit(ys, + [&](const ObjectPtr& node) { + if (output_grads.count(node.get()) == 0) { + output_grads[node.get()].resize(node->num_outputs()); + } + topo_order.push_back(node); + }); - CHECK_EQ(ys.size(), ys_out_grad.size()); for (size_t i = 0; i < ys.size(); ++i) { - NodeEntry ograd = ys_out_grad[i]; - output_grads[ys[i].node.get()][ys[i].index].grads = { ograd }; + output_grads[ys[i].node.get()][ys[i].index].grads = {ys_out_grad[i]}; } - // Check that all xs are reachable from ys + // check that all xs are reachable from ys for (size_t i = 0; i < xs.size(); ++i) { CHECK(output_grads.find(xs[i].node.get()) != output_grads.end()) - << "Cannot differentiate with respect to the " << i+1 << "-th variable " + << "Cannot differentiate with respect to the " + << (i + 1) << "-th variable " << "because it is unreachable from the outputs."; } - // construct mirror as memory reduction strategy if needed - std::unordered_map mirror_map; - if (mirror_fun != nullptr) { - for (const ObjectPtr& node_ptr : topo_order) { - if (mirror_fun(*node_ptr)) { - ObjectPtr new_node = Node::Create(); - *new_node = *node_ptr; - new_node->attrs.name += "_mirror"; - for (auto& e : new_node->inputs) { - e.node = mirror_map.at(e.node.get()); + using MirrorFun = std::function; + MirrorFun mirror_fun = nullptr; + if (src.attrs.count("mirror_fun") != 0) { + mirror_fun = src.GetAttr("mirror_fun"); + } + std::unordered_map mirror_map; + + // complete the backward graph of the src, but 
without backward mirroring nnvm::Graph gsrc = BuildGradientGraph(src, xs, topo_order, + output_grads, + nullptr, mirror_map); + if (mirror_fun == nullptr) { + return gsrc; // Gradient pass without mirroring ends here. + } + const IndexedGraph& idx = src.indexed_graph(), + & gidx = gsrc.indexed_graph(); + // =========================================================================== + // ----- Gradient Pass w/ Backward Mirroring ----- + // =========================================================================== + // Record, for each node entry ∈ gsrc, the nodes that reference it as inputs. + // It is important to note that since the node entry reference mapping is + // constructed from the gradient graph, it can only be indexed using gidx entry ID. + std::vector > node_entry_ref_map( + gidx.num_node_entries()); + static const auto& fignore_inputs = Op::GetAttr("FIgnoreInputs"); + for (uint32_t gnid = 0; gnid < gidx.num_nodes(); ++gnid) { + const IndexedGraph::Node& inode = gidx[gnid]; + if (inode.source->is_variable()) { + continue; + } + for (uint32_t i = 0; i < inode.inputs.size(); ++i) { + if (fignore_inputs.count(inode.source->op()) != 0) { + std::vector ignore_inputs = + fignore_inputs[inode.source->op()](inode.source->attrs); + if (std::find(ignore_inputs.begin(), ignore_inputs.end(), i) + != ignore_inputs.end()) { + continue; } + } + node_entry_ref_map[gidx.entry_id(inode.inputs[i])].insert(inode.source); + } + } // for (gnid ∈ gidx.num_nodes()) + // Infer the shapes and data types of the source graph. This + // information is needed in later stages to determine whether putting a node + // on the mirror path can be beneficial or not. + using mxnet::ShapeVector; + ShapeVector in_arg_shapes = std::move(src.GetAttr("in_arg_shapes")); + DTypeVector in_arg_dtypes = std::move(src.GetAttr("in_arg_dtypes")); + src = mxnet::exec::InferShape(std::move(src), std::move(in_arg_shapes), "__shape__"); + src = mxnet::exec::InferType(std::move(src), std::move(in_arg_dtypes), "__dtype__"); + CHECK(src.GetAttr("shape_num_unknown_nodes") == 0U); + CHECK(src.GetAttr("dtype_num_unknown_nodes") == 0U); + const ShapeVector& src_shapes = src.GetAttr("shape"); + const DTypeVector& src_dtypes = src.GetAttr("dtype"); + + std::queue worklist; + // initialize the worklist to the output nodes + for (const NodeEntry& e : src.outputs) { + worklist.push(e.node.get()); + } + for (; !worklist.empty(); worklist.pop()) { + const Node* const workitem = worklist.front(); + // skip the current node if it has already been recorded in the mirror map + if (mirror_map.find(workitem) != mirror_map.end()) { + continue; + } + + // subgraph and its frontier and topological-sorted view + std::unordered_set subgraph; + // The associated boolean variable is used for marking forward propagation. + std::unordered_map subgraph_frontier; + std::deque subgraph_topo_order; + // ========================================================================= + // --- Backward Pass --- + // ========================================================================= + // The sub-worklist is used for constructing the subgraph. It is initialized + // to have the current workitem node. + std::queue subworklist; + subworklist.push(workitem); + // Local auxiliary function that does backpropagation on the subworklist + // items to construct the subgraph. E.g., + // A subworklist = {A} + // ↑ + // B + // After invoking this function, `subgraph` will become {A, B}. 
+ // Note that this function will be invoked multiple times. + auto subworklist_backprop = [&subworklist, &subgraph, + &subgraph_topo_order, + &mirror_fun, &worklist]() { + std::deque subworklist_topo_order; + for (; !subworklist.empty(); subworklist.pop()) { + const Node* const subworkitem = subworklist.front(); + if (subgraph.find(subworkitem) == subgraph.end()) { + subgraph.insert(subworkitem); + subworklist_topo_order.push_front(subworkitem); + } + for (const NodeEntry& e : subworkitem->inputs) { + if (!mirror_fun(*(e.node))) { + worklist.push(e.node.get()); + } else { + subworklist.push(e.node.get()); + } + } + for (const ObjectPtr& n : subworkitem->control_deps) { + if (!mirror_fun(*n)) { + worklist.push(n.get()); + } else { + subworklist.push(n.get()); + } + } + } // for (subworkitem ∈ subworklist) + // please refer to later comments for why the topological order of the + // subworklist should be directly appended to that of the subgraph + subgraph_topo_order.insert(subgraph_topo_order.end(), + subworklist_topo_order.begin(), + subworklist_topo_order.end()); + }; + // Start propagating from the current workitem node backward until the + // mirroring function returns false (indicating that a compute-heavy layer + // has been hit), in which case we put the node that fails the mirroring + // function into the worklist as the new head. During the traversal, we + // build up the subgraph and its topological order at the same time. + subworklist_backprop(); + + // Forward propagate the subgraph nodes in topological order and make sure + // that all the node entries that are part of the forward propagation belong + // to the same subgraph. This process continues until all the node entries + // have been included, in which case we say that the subgraph has converged. + // + // The reason why this step is needed is because, consider the example below: + // A B C subworklist = {A} + // ↑ ↑ ↑ + // ↖ ↑ ↗ + // D + // Without loss of generality, suppose that the previous backpropagation + // starts from node A, then the subgraph will only contain branch D → A. + // However, we want to include branch D → B adn D → C as well since all + // three branches share the same node entries (i.e., the outputs of D) and + // hence they are all affected by the decision on whether D should be put + // onto the mirror path or not. + bool has_subgraph_converged; + do { + has_subgraph_converged = true; + for (const Node* const subgraph_node : subgraph_topo_order) { + for (const NodeEntry& subgraph_node_entry : + subgraph_node->inputs) { + const std::unordered_set ref_nodes = + node_entry_ref_map[gidx.entry_id(subgraph_node_entry)]; + + for (const Node* const ref_node : ref_nodes) { + // If there are other nodes that reference the node entry and that + // node satisfies the following conditions: + // (1) belongs to the forward graph, and + // (2) is not part of the subgraph + // We add that node to the subgraph and adjust the topological order + // accordingly. + if (ref_node != subgraph_node && idx.exist(ref_node) && + subgraph.find(ref_node) == subgraph.end()) { + // Forward propagate from the reference node until the mirroring + // function returns false. This indicates that the head of the + // branch has been reached (i.e., B or C in our previously + // illustrated example), and we add it to the subworklist for + // another backpropagation. 
+ std::queue ref_node_heads; + ref_node_heads.push(ref_node); + for (; !ref_node_heads.empty(); ref_node_heads.pop()) { + const Node* const ref_node_head = ref_node_heads.front(); + bool is_ref_node_head_output = false; + for (const NodeEntry& y : ys) { + if (ref_node_head == y.node.get()) { + is_ref_node_head_output = true; + } + } + if (!mirror_fun(*ref_node_head) || is_ref_node_head_output) { + subworklist.push(ref_node_head); + continue; + } + + uint32_t gnid = gidx.node_id(ref_node_head); + for (uint32_t oid = 0; oid < ref_node_head->num_outputs(); ++oid) { + uint32_t geid = gidx.entry_id(gnid, oid); + for (const Node* const n : node_entry_ref_map[geid]) { + if (idx.exist(n)) { + ref_node_heads.push(n); + } + } + } // for (oid ∈ [0, ref_node_head->num_outputs())) + } // for (ref_node_head ∈ ref_node_heads) + // Do the backpropagation again. The topological order of the + // subworklist can be directly appended to the end of the existing + // order. E.g., in our previous example, we expect to have + // `subgraph_topo_order` = {D, A} + {B} + {C} + subworklist_backprop(); + // indicate that the subgraph has changed and quit the loop + has_subgraph_converged = false; + break; + } // if (ref_node != subgraph_node && idx.exist(ref_node) && + // subgraph.find(ref_node) == subgraph.end() + } // for (ref_node ∈ ref_nodes) + if (!has_subgraph_converged) { + break; + } + } // for (subgraph_node_entry ∈ subgraph_node->inputs) + if (!has_subgraph_converged) { + break; } + } // for (subgraph_node ∈ subgraph_topo_order) + } while (!has_subgraph_converged); + // ========================================================================= + // --- Forward Pass --- + // ========================================================================= + // Now that the subgraph is complete, we start by assuming that all the + // nodes in the subgraph can be mirrored, and forward propagate starting + // from the subgraph frontier. The propagation is successful if the amount + // of storage released by removing the frontier nodes off the mirror path is + // greater than or equal to the storage allocated. + do { + has_subgraph_converged = true; + // Obtain the subgraph frontier. The subgraph frontier denotes a group of + // nodes whose inputs satisfy the following conditions: + // (1) fail the mirroring function, or + // (2) have been marked as NOT on the mirror path, i.e., + // `mirror_map[input_node] == nullptr` + // E.g., consider the subgraph below: + // A + // ↑ + // B + // ↑ + // C + // The subgraph frontier in this example is {C}, as C is the place where + // the mirror path (and hence the forward propagation) starts. 
+ subgraph_frontier.clear(); + for (const Node* const subgraph_node : subgraph) { + if (!mirror_fun(*subgraph_node)) { + mirror_map[subgraph_node] = nullptr; + continue; + } + if (mirror_map.find(subgraph_node) != mirror_map.end()) { + continue; + } + bool is_frontier = true; + for (const NodeEntry& e : subgraph_node->inputs) { + auto iter = mirror_map.find(e.node.get()); + if (mirror_fun(*(e.node)) && + !(iter != mirror_map.end() && iter->second == nullptr)) { + is_frontier = false; + } + } + for (const ObjectPtr& n : subgraph_node->control_deps) { + auto iter = mirror_map.find(n.get()); + if (mirror_fun(*n) && + !(iter != mirror_map.end() && iter->second == nullptr)) { + is_frontier = false; + } + } + if (is_frontier) { + subgraph_frontier.emplace(subgraph_node, false); + } + } // for (subgraph_node ∈ subgraph) + for (std::pair& frontier_node : subgraph_frontier) { + if (frontier_node.second) { + // If the frontier node has been marked as true, then this indicates + // that the node has been forward propagated before (by other nodes + // that share the same input). + continue; + } + // As we do the forward propagation, we not only propagate the current + // frontier node individually, but all the frontier nodes that share the + // same input with the current one. This is a recursive progress because + // it is possible for A to share the same input with B and B, at the + // same time, to share the same input with C, like in the graph below: + // D E + // ↑ ↑ + // ↗ ↖ ↗ ↖ + // A B C + std::unordered_set forward_candidates{frontier_node.first}; + frontier_node.second = true; + bool has_forward_candidates_converged; + do { + has_forward_candidates_converged = true; + for (const Node* const candidate : forward_candidates) { + for (const NodeEntry& candidate_input : candidate->inputs) { + uint32_t geid = gidx.entry_id(candidate_input); + const std::unordered_set& ref_nodes = node_entry_ref_map[geid]; + for (const Node* const ref_node : ref_nodes) { + if (ref_node != frontier_node.first && + subgraph_frontier.find(ref_node) != subgraph_frontier.end() && + forward_candidates.find(ref_node) == forward_candidates.end()) { + subgraph_frontier[ref_node] = true; + forward_candidates.insert(ref_node); + has_forward_candidates_converged = false; + } + } // for (ref_node ∈ ref_nodes) + if (!has_forward_candidates_converged) { + break; + } + } // for (candidate_input ∈ candidate->inputs) + if (!has_forward_candidates_converged) { + break; + } + } // for (candidate ∈ forward_candidates) + } while (!has_forward_candidates_converged); + // Record the node entries that are newly allocated and those that are + // released. A node entry can be released if all its referencing nodes + // are part of the subgraph frontier. Otherwise, it is in the allocated set. 
+ std::unordered_set newly_allocated_node_entries, + released_node_entries; + for (const Node* const candidate : forward_candidates) { + uint32_t nid = idx.node_id(candidate), + gnid = gidx.node_id(candidate); + for (uint32_t oid = 0; oid < candidate->num_outputs(); ++oid) { + uint32_t eid = idx.entry_id(nid, oid), + geid = gidx.entry_id(gnid, oid); + if (node_entry_ref_map[geid].size() != 0) { + newly_allocated_node_entries.insert(eid); + } + } + for (const NodeEntry& candidate_input : candidate->inputs) { + uint32_t eid = idx.entry_id(candidate_input), + geid = gidx.entry_id(candidate_input); + const std::unordered_set& ref_nodes = node_entry_ref_map[geid]; + bool can_be_released = true; + for (const Node* const ref_node : ref_nodes) { + if (subgraph_frontier.find(ref_node) == subgraph_frontier.end()) { + newly_allocated_node_entries.insert(eid); + can_be_released = false; + } + } + if (can_be_released) { + released_node_entries.insert(eid); + } + } // for (candidate_input ∈ candidate->input) + } // for (candidate ∈ forward_candidates) + + // Now, compare the total amount of newly allocated storage versus the + // released storage, if the latter is greater or equal to the former, + // then we remove the current node from the frontier. Otherwise all the + // forward candidate nodes are marked as on the mirror path. + size_t newly_allocated_storage = 0, released_storage = 0; + for (const uint32_t eid : newly_allocated_node_entries) { + newly_allocated_storage += src_shapes[eid].Size() * + MXGetDTypeSize(src_dtypes[eid]); + } + for (const uint32_t eid : released_node_entries) { + released_storage += src_shapes[eid].Size() * MXGetDTypeSize(src_dtypes[eid]); + } + if (released_storage >= newly_allocated_storage) { + for (const Node* const candidate : forward_candidates) { + CHECK(subgraph_frontier.find(candidate) != subgraph_frontier.end()); + subgraph_frontier.erase(candidate); + mirror_map[candidate] = nullptr; + } + has_subgraph_converged = false; + break; + } // if (released_storage >= newly_allocated_storage) + } // for (frontier_node ∈ subgraph_frontier) + } while (!has_subgraph_converged); + + // Finally, mark all the remaining nodes of the subgraph as on the mirror path. + for (const Node* const subgraph_node : subgraph_topo_order) { + if (mirror_map.find(subgraph_node) != mirror_map.end()) { + continue; + } + ObjectPtr subgraph_node_mirror = Node::Create(); + *subgraph_node_mirror = *subgraph_node; + subgraph_node_mirror->attrs.name += "_mirror"; + for (NodeEntry& e : subgraph_node_mirror->inputs) { + e.node = MapFwdNodeToMirrorPath(e.node, mirror_map); } + for (ObjectPtr& n : subgraph_node_mirror->control_deps) { + n = MapFwdNodeToMirrorPath(n, mirror_map); + } + mirror_map[subgraph_node] = subgraph_node_mirror; } + } // for (workitem ∈ worklist) + DFSVisit(ys, + [&](const ObjectPtr& node) { + if (mirror_map.at(node.get()) != nullptr) { + node->attrs.dict["__mirror_stage__"] = "1"; + } else { + node->attrs.dict["__mirror_stage__"] = "0"; + } + }); + return BuildGradientGraph(src, xs, topo_order, + output_grads, + mirror_fun, mirror_map); +} + + +/*! + * \brief Auxiliary function that checks whether all the gradients are zero or not. 
+ */ +inline bool CheckGradAllZero(const std::vector& grads, + const std::vector& zero_ops) { + if (!grads.size() || !zero_ops.size()) return false; + for (const auto& g : grads) { + bool found = false; + for (const auto& op : zero_ops) { + if (g.node->op() == op) { + found = true; + break; + } + } + if (!found) return false; + } + return true; +} + + +Graph BuildGradientGraph( + const Graph& src, + const std::vector& xs, + const std::vector& topo_order, + std::unordered_map > output_grads, + std::function mirror_fun, + const std::unordered_map& mirror_map) { + static auto& grad_fun_map = Op::GetAttr("FGradient"); + + // gradient aggregation function + using AggFun = std::function&&)>; + AggFun agg_fun = [](std::vector&& v)->NodeEntry { + if (v.size() == 1) { + return std::move(v[0]); + } else if (v.size() == 0) { + ObjectPtr zero_grad_node = Node::Create(); + zero_grad_node->attrs.op = Op::Get("zeros"); + zero_grad_node->attrs.name = "zero_grad"; + zero_grad_node->attrs.op->attr_parser(&(zero_grad_node->attrs)); + return NodeEntry{zero_grad_node, 0, 0}; + } else { + ObjectPtr grad_sum_node = Node::Create(); + grad_sum_node->attrs.op = Op::Get("elemwise_sum"); + grad_sum_node->inputs = std::move(v); + grad_sum_node->attrs.name = "grad_sum"; + grad_sum_node->attrs.dict["num_args"] = + std::to_string(grad_sum_node->inputs.size()); + grad_sum_node->attrs.op->attr_parser(&(grad_sum_node->attrs)); + return NodeEntry{grad_sum_node, 0, 0}; + } + }; + if (src.attrs.count("grad_aggregate_fun") != 0) { + agg_fun = src.GetAttr("grad_aggregate_fun"); } - // traverse backward - static auto& grad_fun_map = Op::GetAttr("FGradient"); - static auto& finfer_shape = Op::GetAttr("FInferShape"); + // zero and copy operators + std::vector zero_ops; + if (src.attrs.count("zero_ops") != 0) { + zero_ops = src.GetAttr >("zero_ops"); + } + const Op* copy_op = (src.attrs.count("copy_op_str") != 0) ? + Op::Get(src.GetAttr("copy_op_str")) : nullptr; std::vector out_agg_grads; - for (auto rit = topo_order.rbegin(); rit != topo_order.rend(); ++rit) { - const ObjectPtr& ptr = *rit; - if (ptr->is_variable()) continue; + for (auto topo_order_rit = topo_order.rbegin(); + topo_order_rit != topo_order.rend(); ++topo_order_rit) { + const ObjectPtr& src_fwd_node = *topo_order_rit; + if (src_fwd_node->is_variable()) continue; + + // gather all the output gradient entries and apply the aggregation function out_agg_grads.clear(); - auto& out_grad_vec = output_grads.at(ptr.get()); + auto& out_grad_vec = output_grads.at(src_fwd_node.get()); for (uint32_t i = 0; i < out_grad_vec.size(); ++i) { GradEntry& e = out_grad_vec[i]; e.sum = agg_fun(std::move(e.grads)); - if (e.need_attr_hint && attr_hint_fun != nullptr) { - e.sum = attr_hint_fun(e.sum, NodeEntry{ptr, 0, i}); - } out_agg_grads.push_back(e.sum); } - if ((*rit)->inputs.size() != 0) { - ObjectPtr fwd_node = (mirror_map.size() == 0 ? ptr : mirror_map.at(ptr.get())); + if (src_fwd_node->inputs.size() != 0) { + // If the current node has inputs, the gradients need to be further + // propagated backward. 
+ ObjectPtr fwd_node = MapFwdNodeToMirrorPath(src_fwd_node, mirror_map); + // calculate the input gradients std::vector input_grads; - // Check for FGradient - if (grad_fun_map.contains(ptr->op())) { - input_grads = grad_fun_map[ptr->op()](fwd_node, out_agg_grads); - CHECK_EQ((*rit)->inputs.size(), input_grads.size()) - << "Gradient function not returning enough gradient"; + if (grad_fun_map.count(src_fwd_node->op())) { + input_grads = grad_fun_map[src_fwd_node->op()](fwd_node, out_agg_grads); + CHECK_EQ(src_fwd_node->inputs.size(), input_grads.size()) + << "The Gradient function is not returning enough gradients."; + // If the operator node fails the mirror function, it is however still + // possible for its feature maps to be recomputed without incurring + // significant runtime overhead. The reason is because some operators + // have their feature maps sit on the inputs rather than the outputs. + // E.g., the fully-connected layer (Y=XW^T), whose gradients are given + // by dX = dYW, dW = dY^TX and hence have no data dependency on Y. + if (mirror_fun != nullptr && !mirror_fun(*fwd_node)) { + for (NodeEntry& input_grad : input_grads) { + for (NodeEntry& grad_input : input_grad.node->inputs) { + const ObjectPtr& grad_input_node_mirrored = MapFwdNodeToMirrorPath( + grad_input.node, mirror_map); + grad_input = NodeEntry( + grad_input_node_mirrored, + grad_input.index, + grad_input.version); + } // for (grad_input ∈ input_grad.node->inputs) + } // for (input_grad ∈ input_grads) + } // if (mirror_fun != nullptr && !mirror_fun(*fwd_node)) } else if (CheckGradAllZero(out_agg_grads, zero_ops)) { - for (size_t i = 0; i < fwd_node->num_inputs(); ++i) { + for (size_t i = 0; i < src_fwd_node->num_inputs(); ++i) { std::ostringstream os; - if (1 == fwd_node->num_inputs()) { + if (1 == src_fwd_node->num_inputs()) { os << fwd_node->attrs.name << "_backward"; } else { os << fwd_node->attrs.name << "_in" << i << "_backward"; @@ -208,25 +634,25 @@ Graph Gradient(Graph src) { p->op()->attr_parser(&(p->attrs)); } input_grads.emplace_back(p, 0, 0); - } + } // for (i ∈ src_fwd_node->num_inputs()) } else { - LOG(FATAL) << "Operator " << fwd_node->op()->name << " is non-differentiable " + LOG(FATAL) << "Operator " << src_fwd_node->op()->name << " is non-differentiable " << "because it didn't register FGradient attribute."; } - for (const auto& nodeEntry : input_grads) - CHECK(nodeEntry.node); - auto git = input_grads.begin(); - CHECK((*rit)->inputs.size() <= input_grads.size()); - for (auto it = (*rit)->inputs.begin(); it != (*rit)->inputs.end(); ++it, ++git) { - auto& output_grad_entry = output_grads[it->node.get()][it->index]; - // if any of the backward op can do shape inference, the hint is not necessary. 
- if (finfer_shape.contains(git->node->op())) { - output_grad_entry.need_attr_hint = false; - } - output_grad_entry.grads.emplace_back(std::move(*git)); + for (const auto& e : input_grads) { + CHECK(e.node); } - } - } + auto input_grad_iter = input_grads.begin(); + CHECK(src_fwd_node->inputs.size() <= input_grads.size()); + for (auto input_iter = src_fwd_node->inputs.begin(); + input_iter != src_fwd_node->inputs.end(); + ++input_iter, ++input_grad_iter) { + // propagate the input gradients to the output gradients of the input nodes + output_grads[input_iter->node.get()][input_iter->index] + .grads.emplace_back(std::move(*input_grad_iter)); + } + } // if (src_fwd_node->inputs.size() != 0) + } // for (topo_order_rit ∈ reverse(topo_order)) // take out the xs' grads Graph ret; ret.outputs.resize(xs.size()); @@ -237,9 +663,6 @@ Graph Gradient(Graph src) { // aggregate sum if there haven't been if (entry.sum.node.get() == nullptr) { entry.sum = agg_fun(std::move(entry.grads)); - if (entry.need_attr_hint && attr_hint_fun != nullptr) { - entry.sum = attr_hint_fun(entry.sum, e); - } } if (copy_op != nullptr) { auto kv = unique_grads.find(entry.sum); @@ -254,15 +677,16 @@ Graph Gradient(Graph src) { copy_node->attrs.name = os.str(); copy_node->inputs.emplace_back(entry.sum); if (copy_node->attrs.op->attr_parser != nullptr) { - copy_node->attrs.op->attr_parser(&(copy_node->attrs)); + copy_node->attrs.op->attr_parser(&(copy_node->attrs)); } - unique_grads.emplace(NodeEntry{std::move(copy_node), 0, 0}, std::make_pair(1, counter)); + unique_grads.emplace(NodeEntry{std::move(copy_node), 0, 0}, + std::make_pair(1, counter)); } } else { - ret.outputs[counter] = entry.sum; + ret.outputs[counter] = entry.sum; } ++counter; - } + } // for (e ∈ xs) if (copy_op != nullptr) { for (const auto& kv : unique_grads) { ret.outputs[kv.second.second] = kv.first; @@ -271,6 +695,7 @@ Graph Gradient(Graph src) { return ret; } + // register pass NNVM_REGISTER_PASS(MXGradient) .describe("Return a gradient graph of src.attrs[\"ys\"] wrt src.attrs[\"xs\"]") @@ -278,6 +703,8 @@ NNVM_REGISTER_PASS(MXGradient) .set_change_graph(true) .depend_graph_attr("grad_ys") .depend_graph_attr("grad_xs") +.depend_graph_attr("in_arg_shapes") +.depend_graph_attr("in_arg_dtypes") .depend_graph_attr("grad_ys_out_grad"); } // namespace diff --git a/src/nnvm/plan_memory.cc b/src/nnvm/plan_memory.cc index 3815f239f88c..f6c13e247e6a 100644 --- a/src/nnvm/plan_memory.cc +++ b/src/nnvm/plan_memory.cc @@ -34,9 +34,10 @@ namespace nnvm { namespace pass { -namespace { -// Return bytes of data flag. -static int MXGetDTypeSize(int type_flag) { +/*! + * \brief Return the storage in bytes for the corresponding data flag. + */ +size_t MXGetDTypeSize(const int type_flag) { switch (type_flag) { case mshadow::kUint8: case mshadow::kInt8: @@ -61,6 +62,8 @@ static int MXGetDTypeSize(int type_flag) { } } +namespace { + // simple graph based allocator. 
class MXGraphAllocator { public: @@ -78,8 +81,7 @@ class MXGraphAllocator { StorageID Request(int dev_id, int dtype, mxnet::TShape shape, uint32_t node_id) { if (!mxnet::shape_is_known(shape)) return kBadStorageID; // search memory block in [size / match_range_, size * match_range_) - // TODO(tqchen) add size of the dtype, assume 4 bytes for now - size_t size = shape.Size() * 4; + size_t size = shape.Size() * MXGetDTypeSize(dtype); if (match_range_ == 0) return this->Alloc(dev_id, size); auto begin = free_.lower_bound(size / match_range_); auto mid = free_.lower_bound(size); @@ -373,7 +375,8 @@ Graph MXPlanMemory(Graph ret) { size_t min_allocated_bytes = -1; size_t max_match_range = dmlc::GetEnv("NNVM_EXEC_MATCH_RANGE", 16); size_t min_match_range = - dmlc::GetEnv("NNVM_AUTO_SEARCH_MATCH_RANGE", false) ? 1 : max_match_range; + dmlc::GetEnv("MXNET_MEMORY_OPT", 0) || + dmlc::GetEnv("NNVM_AUTO_SEARCH_MATCH_RANGE", false) ? 1 : max_match_range; for (size_t match_range = min_match_range; match_range <= max_match_range; match_range *= 2) { // Make a copy of related fields StorageVector storage_vec(storage); diff --git a/src/operator/nn/activation-inl.h b/src/operator/nn/activation-inl.h index 1d8e4c2b6cda..06ff1fe1bedb 100644 --- a/src/operator/nn/activation-inl.h +++ b/src/operator/nn/activation-inl.h @@ -176,8 +176,13 @@ void ActivationGradComputeImpl(const nnvm::NodeAttrs& attrs, const OpContext &ct ctx, inputs[0], inputs[1], req[0], outputs[0]); break; case activation::kSoftSign: - ActivationBackward( - ctx, inputs[0], inputs[2], req[0], outputs[0]); + if (dmlc::GetEnv("MXNET_MEMORY_OPT", 0)) { + ActivationBackward( + ctx, inputs[0], inputs[1], req[0], outputs[0]); + } else { + ActivationBackward( + ctx, inputs[0], inputs[2], req[0], outputs[0]); + } break; default: LOG(FATAL) << "unknown activation type"; diff --git a/src/operator/nn/activation.cc b/src/operator/nn/activation.cc index 1259ceb7d9b3..622d3464371f 100644 --- a/src/operator/nn/activation.cc +++ b/src/operator/nn/activation.cc @@ -41,16 +41,19 @@ namespace activation { int GradNumInputs(int act_type) { // check activation.cu \sa ActivationGradCompute + if (dmlc::GetEnv("MXNET_MEMORY_OPT", 0)) { + return 2; + } switch (act_type) { - case kReLU: - return 2; - case kSoftReLU: - case kSoftSign: - case kTanh: - case kSigmoid: - return 3; - default: - CHECK(false) << "missing activation type"; + case kReLU: + return 2; + case kSoftReLU: + case kSoftSign: + case kTanh: + case kSigmoid: + return 3; + default: + CHECK(false) << "missing activation type"; } // unreachable return -1; @@ -65,27 +68,34 @@ struct ActivationGrad { const char *op_name; std::vector operator()(const nnvm::ObjectPtr& n, const std::vector& ograds) const { - // ograds, output... + // ograds std::vector heads(ograds.begin(), ograds.end()); - heads.emplace_back(n, activation::kOut, 0); - const NodeAttrs& attrs = n->attrs; using namespace activation; int act_type = dmlc::get(attrs.parsed).act_type; - // for ReLU, no need to pass input data. This enables inplace optimization during the - // forward pass. - // check activation.cu \sa ActivationGradCompute - switch (act_type) { + + if (dmlc::GetEnv("MXNET_MEMORY_OPT", 0)) { + if (act_type == kSoftSign) { + heads.push_back(n->inputs[activation::kData]); + } else { + heads.emplace_back(n, activation::kOut, 0); + } + } else { + heads.emplace_back(n, activation::kOut, 0); // output + // for ReLU, no need to pass input data. This enables inplace optimization + // during the forward pass. 
check activation.cu \sa ActivationGradCompute + switch (act_type) { case kReLU: - break; + break; case kSoftReLU: case kSoftSign: case kTanh: case kSigmoid: - heads.push_back(n->inputs[activation::kData]); - break; + heads.push_back(n->inputs[activation::kData]); + break; default: - CHECK(false) << "missing activation type"; + CHECK(false) << "missing activation type"; + } } return MakeGradNode(op_name, n, heads, n->attrs.dict); } diff --git a/src/operator/nn/activation.cu b/src/operator/nn/activation.cu index ec7db844b100..1116cf20165b 100644 --- a/src/operator/nn/activation.cu +++ b/src/operator/nn/activation.cu @@ -82,24 +82,48 @@ void ActivationGradCompute(const nnvm::NodeAttrs& attrs, CHECK_EQ(outputs.size(), 1U); CHECK_EQ(req.size(), 1U); + bool do_memory_opt = dmlc::GetEnv("MXNET_MEMORY_OPT", 0); + // both SoftReLU and SoftSign not supported by CUDNN yet if (act_type == activation::kSoftReLU) { ActivationBackward( ctx, inputs.at(0), inputs.at(1), req[0], outputs[0]); } else if (act_type == activation::kSoftSign) { - ActivationBackward( - ctx, inputs.at(0), inputs.at(2), req[0], outputs[0]); + if (do_memory_opt) { + ActivationBackward( + ctx, inputs.at(0), inputs.at(1), req[0], outputs[0]); + } else { + ActivationBackward( + ctx, inputs.at(0), inputs.at(2), req[0], outputs[0]); + } } else if (act_type == activation::kReLU) { - MSHADOW_REAL_TYPE_SWITCH(inputs.at(0).type_flag_, DType, { - // XXX: for y = relu(x), y is passed as "in_data" to Backward() - get_cudnn_op(param).Backward(ctx, inputs.at(0), inputs.at(1), - inputs.at(1), req[0], outputs[0]); - }); + if (do_memory_opt) { + ActivationBackward( + ctx, inputs.at(0), inputs.at(1), req[0], outputs[0]); + } else { + MSHADOW_REAL_TYPE_SWITCH(inputs.at(0).type_flag_, DType, { + // XXX: for y = relu(x), y is passed as "in_data" to Backward() + get_cudnn_op(param).Backward(ctx, inputs.at(0), inputs.at(1), + inputs.at(1), req[0], outputs[0]); + }); + } } else { - MSHADOW_REAL_TYPE_SWITCH(inputs.at(0).type_flag_, DType, { - get_cudnn_op(param).Backward(ctx, inputs.at(0), inputs.at(2), - inputs.at(1), req[0], outputs[0]); - }); + if (do_memory_opt) { + if (act_type == activation::kTanh) { + ActivationBackward( + ctx, inputs.at(0), inputs.at(1), req[0], outputs[0]); + } else if (act_type == activation::kSigmoid) { + ActivationBackward( + ctx, inputs.at(0), inputs.at(1), req[0], outputs[0]); + } else { + LOG(FATAL) << "unknown activation type"; + } + } else { + MSHADOW_REAL_TYPE_SWITCH(inputs.at(0).type_flag_, DType, { + get_cudnn_op(param).Backward(ctx, inputs.at(0), inputs.at(2), + inputs.at(1), req[0], outputs[0]); + }); + } // if (do_memory_opt) } } #endif diff --git a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h index 881d3d2247da..94a1572b8db3 100644 --- a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h +++ b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h @@ -123,13 +123,15 @@ class CuDNNBatchNormOp { Tensor save_inv_var = out_data[cudnnbatchnorm::kInvVar] .get_with_shape(Shape1(shape_[1]), s); - // If the lock on the auxiliary states is set, - // then this implies that the preceding call is also a `Forward()` call, - // which further indicates that we are in the backward mirroring mode, - // and therefore update to the auxiliary states is disabled. - // This is done by setting the `momentum` to `1` (or `factor` to `0`). - float factor = (dmlc::GetEnv("MXNET_BACKWARD_DO_MIRROR", 0) && internal_aux_states_lock_) ? 
-                       0 : (1 - param_.momentum);
+        // If the lock on the auxiliary states is set, then this implies that
+        // the preceding call is also a `Forward()` call, which further
+        // indicates that we are in the backward mirroring mode, and therefore
+        // update to the auxiliary states is disabled. This is done by setting
+        // the `momentum` to `1` (or `factor` to `0`).
+        float factor = ((dmlc::GetEnv("MXNET_BACKWARD_DO_MIRROR", 0) ||
+                         dmlc::GetEnv("MXNET_MEMORY_OPT", 0))
+                        && internal_aux_states_lock_) ?
+                       0 : (1 - param_.momentum);
         CUDNN_CALL(cudnnBatchNormalizationForwardTraining(s->dnn_handle_,
                                                           mode,
                                                           &a,
diff --git a/tests/python/unittest/test_memory_opt.py b/tests/python/unittest/test_memory_opt.py
new file mode 100644
index 000000000000..d31eee0d4a99
--- /dev/null
+++ b/tests/python/unittest/test_memory_opt.py
@@ -0,0 +1,202 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: skip-file
+import mxnet as mx
+import os
+import sys
+
+
+num_hidden = 4096
+
+
+def memory_opt_env_check(test_func):
+    # This decorator checks for the MXNET_MEMORY_OPT environment variable and
+    # only runs the wrapped test case when the memory optimizations can be
+    # enabled.
+    def test_memory_opt_wrapper():
+        # Check whether the underlying OS is Windows or not. Windows does not
+        # support setting environment variables on the fly. In other words, the
+        # statement
+        #
+        #     os.environ["MXNET_MEMORY_OPT"] = '1'
+        #
+        # will have NO effect because the C++ backend still sees
+        # `os.environ["MXNET_MEMORY_OPT"]` as a NULL pointer.
+        #
+        # \sa test_operator.py:test_norm
+        is_windows = sys.platform.startswith('win')
+        do_memory_opt = True
+        if is_windows:
+            if "MXNET_MEMORY_OPT" not in os.environ:
+                do_memory_opt = False
+            else:
+                do_memory_opt = os.environ["MXNET_MEMORY_OPT"] == '1'
+        else:
+            os.environ["MXNET_MEMORY_OPT"] = '1'
+
+        if do_memory_opt:
+            test_func()
+            os.environ["MXNET_MEMORY_OPT"] = '0'
+    return test_memory_opt_wrapper
+
+
+@memory_opt_env_check
+def test_rnn_cell():
+    # x →→→ + →→→ tanh ⇒⇒⇒
+    #       ↑
+    # y →→→→
+    #
+    # ⇒⇒⇒ : Backward Dependency
+    # In this example, there is no benefit in mirroring the elementwise-add
+    # operator and the tanh operator.
+ x = mx.sym.Variable("x") + x = mx.sym.FullyConnected(x, num_hidden=num_hidden) + y = mx.sym.Variable("y") + y = mx.sym.FullyConnected(y, num_hidden=num_hidden) + tmp = mx.sym._internal._plus(x, y) + z = mx.sym.Activation(tmp, act_type='tanh') + exec = z.simple_bind(mx.cpu(), 'write', x=(num_hidden,), y=(num_hidden,)) + exec_debug_str = exec.debug_str().split('\n') + op_checklist = 0 + for i, line in enumerate(exec_debug_str): + if "Op:elemwise_add" in line: + op_checklist += 1 + assert exec_debug_str[i + 5] == "\t__mirror_stage__=0" + if "Op:Activation" in line: + op_checklist += 1 + assert exec_debug_str[i + 4] == "\t__mirror_stage__=0" + assert op_checklist == 2, \ + "Not all operator nodes have been verified on the mirror stage" + + +@memory_opt_env_check +def test_mlp_attn(): + # x →→→ + →→→ tanh ⇒⇒⇒ + # ↑ + →→→ tanh ⇒⇒⇒ + # y_1 →→ ↑ + →→→ tanh ⇒⇒⇒ + # y_2 →→→→ ↑ ⋱ + # y_3 →→→→→→ + →→→ tanh ⇒⇒⇒ + # ↑ + # y_n →→→→→→→→→→ + x = mx.sym.Variable("x") + tmp, z = [], [] + num_steps = 5 + in_arg_shapes = {'x': (num_steps, num_hidden,)} + for i in range(num_steps): + y = mx.sym.Variable("y_t%d"%i) + tmp.append(mx.sym.broadcast_add(x, y, name="broadcast_add%d"%i)) + z.append(mx.sym.Activation(tmp[-1], act_type='tanh', + name="activation%d"%i)) + in_arg_shapes["y_t%d"%i] = (1, num_hidden,) + z = mx.sym.Group(z) + exec = z.simple_bind(mx.cpu(), 'write', **in_arg_shapes) + exec_debug_str = exec.debug_str().split('\n') + op_checklist = 0 + for i, line in enumerate(exec_debug_str): + for t in range(num_steps): + if line == "Op:broadcast_add, Name=broadcast_add%d"%t: + op_checklist += 1 + assert exec_debug_str[i + 5] == "\t__mirror_stage__=1" + if line == "Op:Activation, Name=activation%d"%t: + op_checklist += 1 + assert exec_debug_str[i + 4] == "\t__mirror_stage__=1" + assert op_checklist == 2 * num_steps, \ + "Not all operator nodes have been verified on the mirror stage" + + +@memory_opt_env_check +def test_fc(): + # x →→→ tanh ⇒⇒⇒ tanh ⇒⇒⇒ FC + # →→→ tanh_ →→→ + # ↓ + # FC' + x = mx.sym.Variable("x") + y = mx.sym.Activation(x, act_type='tanh', name='y') + z = mx.sym.Activation(y, act_type='tanh', name='z') + z = mx.sym.FullyConnected(z, num_hidden=num_hidden) + exec = z.simple_bind(mx.cpu(), 'write', x=(num_hidden,)) + exec_debug_str = exec.debug_str().split('\n') + op_checklist = 0 + for i, line in enumerate(exec_debug_str): + if line == "Op:Activation, Name=y": + op_checklist += 1 + assert exec_debug_str[i + 4] == "\t__mirror_stage__=0" + if line == "Op:Activation, Name=z": + op_checklist += 1 + assert exec_debug_str[i + 4] == "\t__mirror_stage__=1" + if "Op:FullyConnected" in line: + op_checklist += 1 + assert exec_debug_str[i + 6] == "\t__mirror_stage__=0" + if "Op:_backward_FullyConnected" in line: + op_checklist += 1 + assert exec_debug_str[i + 3] == "\targ[1]=z_mirror(0)" + assert op_checklist == 4, \ + "Not all operator nodes have been verified on the mirror stage" + + +def grep_exec_memory_consumption(exec): + # Grep the memory consumption (in MB) from the executor debug string. + # + # It is important to note that, due to various reasons, the memory + # consumption reported by the executor debug string might be very different + # when compared with the real numbers reported by nvidia-smi. These reasons + # include: + # - Allocations by the CUDA Library (e.g., cuDNN, cuBLAS) + # - Fragmentation (of the MXNet Memory Allocator and cudaMalloc) + exec_debug_str = exec.debug_str().split('\n') + + import re # We will be using regular expressions for grepping the model + # memory consumption. 
+    alloc_line_pattern = re.compile(r"Total \d+ MB allocated")
+    for line in exec_debug_str:
+        if alloc_line_pattern.match(line) is not None:
+            return int(line.split()[1])
+    assert False, "Unable to grep the memory consumption numbers from the executor " \
+                  "debug string: %s" % exec_debug_str
+
+
+@memory_opt_env_check
+def test_resnet152():
+    # Verify the memory allocation behavior on ResNet-152, a state-of-the-art
+    # model for image classification.
+
+    # Import the network, similar to what we did in
+    # ${MXNET_ROOT_DIR}/example/image-classification/train_imagenet.py
+    from importlib import import_module
+    sys.path.append(os.path.join(os.path.dirname(__file__),
+                                 '..', '..', '..', 'example', 'image-classification'))
+    resnet_mod = import_module('symbols.resnet')
+    resnet_152 = resnet_mod.get_symbol(num_classes=1000,
+                                       num_layers=152,
+                                       image_shape='3,224,224')
+    # We do the binding twice, once with the memory optimizations and once
+    # without. The memory consumption of the former is expected to be roughly
+    # half that of the latter.
+    memory_opt_exec = resnet_152.simple_bind(mx.cpu(), 'write',
+                                             data=(32, 3, 224, 224))
+    os.environ["MXNET_MEMORY_OPT"] = '0'
+    no_opt_exec = resnet_152.simple_bind(mx.cpu(), 'write', data=(32, 3, 224, 224))
+    os.environ["MXNET_MEMORY_OPT"] = '1'
+    memory_opt_alloc = grep_exec_memory_consumption(memory_opt_exec)
+    no_opt_alloc = grep_exec_memory_consumption(no_opt_exec)
+    assert memory_opt_alloc / no_opt_alloc < 0.6, \
+        "The ratio between the memory consumption with the memory optimizations " \
+        "enabled and disabled (%d vs. %d MB) is expected to be smaller than 0.6" \
+        % (memory_opt_alloc, no_opt_alloc)
+
+
+if __name__ == "__main__":
+    import nose
+    nose.runmodule()
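
As a usage reference, below is a minimal, illustrative sketch of how the new `MXNET_MEMORY_OPT` flag is exercised from user code; it assumes nothing beyond the symbol API already used by the tests above, and the layer size and input shape are arbitrary. The flag is consulted while the graph is planned during `simple_bind`, so it must be set before binding (and, on Windows, before the Python process starts, as noted in the test decorator).

import os
os.environ["MXNET_MEMORY_OPT"] = '1'  # enable the memory optimizations before binding

import mxnet as mx

x = mx.sym.Variable("x")
y = mx.sym.Activation(x, act_type='tanh', name='y')
z = mx.sym.FullyConnected(y, num_hidden=4096)
# Binding plans the memory with the optimizations enabled; the per-node
# __mirror_stage__ attributes and the "Total ... MB allocated" line can be
# inspected through the executor debug string.
exe = z.simple_bind(mx.cpu(), 'write', x=(4096,))
print(exe.debug_str())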