From 05d1b26520d0dcaf3bf9d55343b9792eccb63a6f Mon Sep 17 00:00:00 2001
From: KexinFeng
Date: Thu, 23 Sep 2021 16:41:08 -0500
Subject: [PATCH] [FEATURE] Add feature of retain_grad (#20500)

* Replace "CloneGradient" with "ElemwiseGradUseNone"
* fix issue elemwise_add
* fix elemwise_add issue with `ElemwiseGradUseNone`
* reverse_to_CloneGradient
* add_retain_grad
* unit_test
* tidy_up
* tidy_up
* sanity
* const_reference
* const_ref
* merge_rg_to_ag
* sanity
* sanity
* add_drop_grad
* sanity_check
* sanity_check
* sanity_check
* build_err
* build_err
* skip_remark_variable
* repetitive_mark
* ReInit_in_dropgrad
* ReInit_in_dropgrad
* sanity_check
* add drop and tests to gluon
* sanity
* update exec_pass.h

Co-authored-by: Zhenghui Jin <69359374+barry-jin@users.noreply.github.com>
---
 include/mxnet/c_api.h                  |   8 ++
 include/mxnet/imperative.h             |   4 +
 python/mxnet/ndarray/ndarray.py        |   5 ++
 python/mxnet/numpy/multiarray.py       |   5 ++
 src/c_api/c_api_ndarray.cc             |  12 +++
 src/imperative/exec_pass.h             |   4 +-
 src/imperative/imperative.cc           | 101 ++++++++++++++++++++-----
 src/nnvm/gradient.cc                   |  30 ++++++--
 tests/python/unittest/test_autograd.py |  67 +++++++++++++++-
 9 files changed, 211 insertions(+), 25 deletions(-)
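Usage sketch (adapted from the tests added in tests/python/unittest/test_autograd.py below;
the array values and variable names are illustrative only):

    import mxnet as mx
    from mxnet import nd

    x = nd.array([1., 2., 3., 4.])
    x.attach_grad()                         # mark a leaf variable
    with mx.autograd.record():
        u = x * x                           # nonleaf (intermediate) result
        z = u * x
    u.attach_grad()                         # retain the gradient of the nonleaf array
    z.backward(retain_graph=True)
    assert (u.grad == x).asnumpy().all()    # dz/du = x is now kept on u
    u.drop_grad()                           # free the retained gradient buffer
    assert u.grad is None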
diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index e759b767227d..926b31e08356 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -1274,6 +1274,14 @@ MXNET_DLL int MXAutogradMarkVariables(uint32_t num_var,
                                       NDArrayHandle *var_handles,
                                       uint32_t *reqs_array,
                                       NDArrayHandle *grad_handles);
+/*!
+ * \brief unmark nonleaf NDArrays to free the memory
+ * \param num_var number of variable NDArrays
+ * \param var_handles variable NDArrays
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXAutogradDropGrads(uint32_t num_var,
+                                  NDArrayHandle *var_handles);
 /*!
  * \brief compute the gradient of outputs w.r.t variabels
  * \param num_output number of output NDArray
diff --git a/include/mxnet/imperative.h b/include/mxnet/imperative.h
index d998a74fde48..76ccf253d904 100644
--- a/include/mxnet/imperative.h
+++ b/include/mxnet/imperative.h
@@ -272,12 +272,16 @@ class Imperative {
   void MarkVariables(const std::vector<NDArray*>& variables,
                      const std::vector<uint32_t>& grad_reqs,
                      const std::vector<NDArray*>& gradients);
+  /*! \brief unmark nonleaf variables to free the memory. */
+  void DropGrads(const std::vector<NDArray*>& variables);
   /*! \brief compute the gradient of outputs w.r.t variables. */
   std::vector<NDArray*> Backward(const std::vector<NDArray*>& outputs,
                                  const std::vector<NDArray*>& ograds,
                                  const std::vector<NDArray*>& variables,
                                  bool is_train, bool retain_graph,
                                  bool create_graph);
+  /*! \brief Return the marked nonleaf nodes. */
+  std::vector<nnvm::ObjectPtr> ListNonleafVariables(const nnvm::Symbol& sym) const;
   /*! \return AutogradRuntime singleton */
   static Imperative* Get();
   /*! \brief Should op execution bulking be employed during inference. */
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index cbd0c51d8431..1f49fc566c70 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -2885,6 +2885,11 @@ def attach_grad(self, grad_req='write', stype=None):
             ctypes.pointer(mx_uint(grad_req)),
             ctypes.pointer(grad.handle)))
 
+    def drop_grad(self):
+        """Free the memory of the marked ndarray."""
+        check_call(_LIB.MXAutogradDropGrads(
+            1, ctypes.pointer(self.handle)))
+
     @property
     def grad(self):
         """Returns gradient buffer attached to this NDArray."""
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index 5cca1fa9225a..c2d9db95f471 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -1410,6 +1410,11 @@ def attach_grad(self, grad_req='write'):  # pylint: disable=arguments-differ
             ctypes.pointer(mx_uint(grad_req)),
             ctypes.pointer(grad.handle)))
 
+    def drop_grad(self):
+        """Free the memory of the marked ndarray."""
+        check_call(_LIB.MXAutogradDropGrads(
+            1, ctypes.pointer(self.handle)))
+
     @property
     def grad(self):
         """Returns gradient buffer attached to this ndarray."""
diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc
index d967ae6e12b3..3d66996c4e32 100644
--- a/src/c_api/c_api_ndarray.cc
+++ b/src/c_api/c_api_ndarray.cc
@@ -335,6 +335,18 @@ int MXAutogradMarkVariables(uint32_t num_var,
   API_END();
 }
 
+int MXAutogradDropGrads(uint32_t num_var,
+                        NDArrayHandle *var_handles) {
+  API_BEGIN();
+  std::vector<NDArray*> variables;
+  variables.reserve(num_var);
+  for (uint32_t i = 0; i < num_var; ++i) {
+    variables.emplace_back(static_cast<NDArray*>(var_handles[i]));
+  }
+  Imperative::Get()->DropGrads(variables);
+  API_END();
+}
+
 int MXAutogradComputeGradient(uint32_t num_output, NDArrayHandle* output_handles) {
   return MXAutogradBackward(num_output, output_handles, nullptr, 0);
 }
diff --git a/src/imperative/exec_pass.h b/src/imperative/exec_pass.h
index 440fc6b937ca..5f27a16fe695 100644
--- a/src/imperative/exec_pass.h
+++ b/src/imperative/exec_pass.h
@@ -287,12 +287,14 @@ inline Graph MXGradient(
     std::vector<const nnvm::Op*> zero_ops = std::vector<const nnvm::Op*>(),
     std::string copy_op_str = std::string(),
     mxnet::ShapeVector in_arg_shapes = mxnet::ShapeVector(),
-    DTypeVector in_arg_dtypes = DTypeVector()) {
+    DTypeVector in_arg_dtypes = DTypeVector(),
+    std::vector<nnvm::NodeEntry> us = std::vector<nnvm::NodeEntry>()) {
   graph.attrs["grad_ys"] = std::make_shared<nnvm::any>(std::move(ys));
   graph.attrs["grad_xs"] = std::make_shared<nnvm::any>(std::move(xs));
   graph.attrs["grad_ys_out_grad"] = std::make_shared<nnvm::any>(std::move(ys_out_grad));
   graph.attrs["in_arg_shapes"] = std::make_shared<nnvm::any>(std::move(in_arg_shapes));
   graph.attrs["in_arg_dtypes"] = std::make_shared<nnvm::any>(std::move(in_arg_dtypes));
+  graph.attrs["grad_us"] = std::make_shared<nnvm::any>(std::move(us));
 
   if (aggregate_fun != nullptr) {
     graph.attrs["grad_aggregate_fun"] = std::make_shared<nnvm::any>(aggregate_fun);
diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc
index 0ec5ae579dce..af1ee097ac1e 100644
--- a/src/imperative/imperative.cc
+++ b/src/imperative/imperative.cc
@@ -142,29 +142,54 @@ void Imperative::MarkVariables(const std::vector<NDArray*>& variables,
                                const std::vector<uint32_t>& grad_reqs,
                                const std::vector<NDArray*>& gradients) {
   for (uint32_t i = 0; i < variables.size(); ++i) {
-    std::string str_c(std::to_string(variable_count_++));
-
-    variables[i]->autograd_entry_ =
-        nnvm::NodeEntry{nnvm::Symbol::CreateVariable("var" + str_c).outputs[0].node, 0, 0};
-    AGInfo& info = AGInfo::Create(variables[i]->autograd_entry_.node);
-    info.outputs.emplace_back(variables[i]->Detach());
-    info.out_grads.emplace_back(gradients[i]->Detach());
-    info.grad_req = static_cast<OpReqType>(grad_reqs[i]);
-    info.ctx = variables[i]->ctx();
-
-    gradients[i]->autograd_entry_ =
-        nnvm::NodeEntry{nnvm::Symbol::CreateVariable("grad" + str_c).outputs[0].node, 0, 0};
-    AGInfo& grad_info = AGInfo::Create(gradients[i]->autograd_entry_.node);
-    grad_info.outputs.emplace_back(gradients[i]->Detach());
-    grad_info.ctx = gradients[i]->ctx();
+    // Unmarked leaf nodes have null autograd_entry_, while marked nonleaf nodes don't.
+    if (!variables[i]->autograd_entry_.node || variables[i]->autograd_entry_.node->is_variable()) {
+      std::string str_c(std::to_string(variable_count_++));
+      variables[i]->autograd_entry_ =
+          nnvm::NodeEntry{nnvm::Symbol::CreateVariable("var" + str_c).outputs[0].node, 0, 0};
+      AGInfo& info = AGInfo::Create(variables[i]->autograd_entry_.node);
+      info.outputs.emplace_back(variables[i]->Detach());
+      info.out_grads.emplace_back(gradients[i]->Detach());
+      info.grad_req = static_cast<OpReqType>(grad_reqs[i]);
+      info.ctx = variables[i]->ctx();
+
+      gradients[i]->autograd_entry_ =
+          nnvm::NodeEntry{nnvm::Symbol::CreateVariable("grad" + str_c).outputs[0].node, 0, 0};
+      AGInfo& grad_info = AGInfo::Create(gradients[i]->autograd_entry_.node);
+      grad_info.outputs.emplace_back(gradients[i]->Detach());
+      grad_info.ctx = gradients[i]->ctx();
+    } else {
+      AGInfo& info = AGInfo::Get(variables[i]->autograd_entry_.node);
+      CHECK_EQ(info.out_grads.size(), 0)
+          << "The node has already been marked. Cannot mark it again.";
+      info.out_grads.emplace_back(gradients[i]->Detach());
+      info.grad_req = static_cast<OpReqType>(grad_reqs[i]);
+      info.ctx = variables[i]->ctx();
+    }
+  }
+}
+
+// Unmark the variables to free the memory.
+void Imperative::DropGrads(const std::vector<NDArray*>& variables) {
+  for (auto variable : variables) {
+    if (variable->autograd_entry_.node) {
+      AGInfo& info = AGInfo::Get(variable->autograd_entry_.node);
+      CHECK_NE(info.out_grads.size(), 0)
+          << "The node has empty out_grads already. Cannot DropGrads again.";
+      for (auto grad : info.out_grads) {
+        grad.ReInit();
+      }
+      info.out_grads.clear();
+      info.grad_req = kNullOp;
+    }
   }
 }
 
 void Imperative::GetBackwardDependency(const nnvm::ObjectPtr& node,
                                        uint32_t num_inputs,
                                        uint32_t num_outputs,
-                                       std::vector<bool>* p_save_inputs,
-                                       std::vector<bool>* p_save_outputs) {
+                                       std::vector<bool> *p_save_inputs,
+                                       std::vector<bool> *p_save_outputs) {
   static auto& fgradient = nnvm::Op::GetAttr<nnvm::FGradient>("FGradient");
   std::vector<bool>& save_inputs = *p_save_inputs;
   std::vector<bool>& save_outputs = *p_save_outputs;
@@ -488,6 +513,12 @@ std::vector<NDArray*> Imperative::Backward(const std::vector<NDArray*>& outputs,
     }
     CHECK_GT(xs.size(), 0) << "There are no inputs in computation graph that require gradients.";
   }
 
+  std::vector<ObjectPtr> nleaf_vars = ListNonleafVariables(sym);
+  std::vector<NodeEntry> us;
+  us.reserve(nleaf_vars.size());
+  for (const auto& i : nleaf_vars) {
+    us.emplace_back(NodeEntry{i, 0, 0});
+  }
   Graph g_graph = pass::MXGradient(graph,
                                    graph.outputs,
@@ -496,7 +527,10 @@ std::vector<NDArray*> Imperative::Backward(const std::vector<NDArray*>& outputs,
                                    mxnet::AggregateGradient,
                                    nullptr,
                                    zero_ops,
-                                   "_copy");
+                                   "_copy",
+                                   ShapeVector(),
+                                   DTypeVector(),
+                                   us);
   CHECK_EQ(g_graph.outputs.size(), xs.size());
   for (const auto& e : g_graph.outputs) {
     if (e.node->op() == nullptr) {
@@ -575,6 +609,20 @@ std::vector<NDArray*> Imperative::Backward(const std::vector<NDArray*>& outputs,
     arrays[eid] = x_grads[i - num_forward_outputs];
     ref_count[eid] = 1;
   }
+  const std::vector<NodeEntry>& us_grads =
+      g_graph.GetAttr<std::vector<NodeEntry>>("nleaf_grads");
+  CHECK_EQ(us_grads.size(), us.size())
+      << "Size of queried nleaf_vars and size of their gradients don't match.";
+  for (size_t i = 0; i < us_grads.size(); i++) {
+    size_t eid = idx.entry_id(us_grads[i]);
+    AGInfo& info = AGInfo::Get(us[i].node);
+    if (arrays[eid]->dtype_ == -1) {
+      arrays[eid] = &info.out_grads[0];
+    } else {
+      info.out_grads[0] = *arrays[eid];
+    }
+    ref_count[eid] = 1;
+  }
 
   // Assign context
   auto vctx = PlaceDevice(idx);
@@ -627,6 +675,11 @@ std::vector<NDArray*> Imperative::Backward(const std::vector<NDArray*>& outputs,
     size_t eid = idx.entry_id(idx.outputs()[i]);
     array_reqs[eid] = x_reqs[i - num_forward_outputs];
   }
+  for (size_t i = 0; i < us_grads.size(); i++) {
+    size_t eid = idx.entry_id(us_grads[i]);
+    AGInfo& info = AGInfo::Get(us[i].node);
+    array_reqs[eid] = info.grad_req;
+  }
 
   const auto& shapes = graph.GetAttr<mxnet::ShapeVector>("shape");
   const auto& dtypes = graph.GetAttr<DTypeVector>("dtype");
@@ -766,4 +819,16 @@ void Imperative::DCInfo::Compute(const NDArray& arr) {
   info.outputs_.clear();
 }
 
+std::vector<nnvm::ObjectPtr> Imperative::ListNonleafVariables(const nnvm::Symbol& sym) const {
+  using namespace nnvm;
+  std::vector<ObjectPtr> ret;
+  DFSVisit(sym.outputs, [&ret](const ObjectPtr& node) {
+    AGInfo& info = AGInfo::Get(node);
+    if (info.out_grads.size() > 0 && !node->is_variable()) {
+      ret.push_back(node);
+    }
+  });
+  return ret;
+}
+
 }  // namespace mxnet
diff --git a/src/nnvm/gradient.cc b/src/nnvm/gradient.cc
index 2609f78d7579..625eaa33e23b 100644
--- a/src/nnvm/gradient.cc
+++ b/src/nnvm/gradient.cc
@@ -62,7 +62,8 @@ Graph BuildGradientGraph(const Graph& src,
                          const std::vector<ObjectPtr>& topo_order,
                          std::unordered_map<const Node*, std::vector<GradEntry> > output_grads,
                          std::function<int(const Node&)> mirror_fun,
-                         const std::unordered_map<const Node*, ObjectPtr>& mirror_map);
+                         const std::unordered_map<const Node*, ObjectPtr>& mirror_map,
+                         const std::vector<NodeEntry>& us = std::vector<NodeEntry>());
 
 /*!
  * \brief Auxiliary function that maps the forward node of the source graph to
@@ -88,6 +89,8 @@ Graph Gradient(Graph src) {
   const std::vector<NodeEntry>& ys_out_grad =
       src.GetAttr<std::vector<NodeEntry> >("grad_ys_out_grad");
   CHECK_EQ(ys.size(), ys_out_grad.size());
+  const std::vector<NodeEntry>& us =
+      src.GetAttr<std::vector<NodeEntry> >("grad_us");
 
   // initialize a topological order of the graph nodes and `output_grads`
   // that maps every operator node to its gradient entries
@@ -120,7 +123,7 @@ Graph Gradient(Graph src) {
   std::unordered_map<const Node*, ObjectPtr> mirror_map;
 
   // complete the backward graph of the src, but without backward mirroring
-  nnvm::Graph gsrc = BuildGradientGraph(src, xs, topo_order, output_grads, nullptr, mirror_map);
+  nnvm::Graph gsrc = BuildGradientGraph(src, xs, topo_order, output_grads, nullptr, mirror_map, us);
   if (mirror_fun == nullptr) {
     return gsrc;  // Gradient pass without mirroring ends here.
   }
@@ -504,12 +507,14 @@ inline bool CheckGradAllZero(const std::vector<NodeEntry>& grads,
   return true;
 }
 
+
 Graph BuildGradientGraph(const Graph& src,
                          const std::vector<NodeEntry>& xs,
                          const std::vector<ObjectPtr>& topo_order,
                          std::unordered_map<const Node*, std::vector<GradEntry> > output_grads,
                          std::function<int(const Node&)> mirror_fun,
-                         const std::unordered_map<const Node*, ObjectPtr>& mirror_map) {
+                         const std::unordered_map<const Node*, ObjectPtr>& mirror_map,
+                         const std::vector<NodeEntry>& us) {
   static auto& grad_fun_map = Op::GetAttr<nnvm::FGradient>("FGradient");
 
   // gradient aggregation function
@@ -617,7 +622,7 @@ Graph BuildGradientGraph(const Graph& src,
       CHECK(src_fwd_node->inputs.size() <= input_grads.size());
       for (auto input_iter = src_fwd_node->inputs.begin(); input_iter != src_fwd_node->inputs.end();
            ++input_iter, ++input_grad_iter) {
-        // propagate the input gradients to the output gradients of the input nodes
+        // propagate the input_grads to the corresponding GradEntries mapped by output_grads
         output_grads[input_iter->node.get()][input_iter->index].grads.emplace_back(
             std::move(*input_grad_iter));
       }
@@ -661,6 +666,20 @@ Graph BuildGradientGraph(const Graph& src,
       ret.outputs[kv.second.second] = kv.first;
     }
   }
+
+  // Take the us' grad NodeEntry and store them in graph.attrs
+  std::vector<NodeEntry> nleaf_grads;
+  nleaf_grads.reserve(us.size());
+  for (const NodeEntry& e : us) {
+    GradEntry& entry = output_grads[e.node.get()][e.index];
+    // aggregate sum if it hasn't been
+    if (entry.sum.node.get() == nullptr) {
+      entry.sum = agg_fun(std::move(entry.grads));
+    }
+    nleaf_grads.push_back(entry.sum);
+  }
+  ret.attrs["nleaf_grads"] = std::make_shared<any>(std::move(nleaf_grads));
+
   return ret;
 }
 
@@ -673,7 +692,8 @@ NNVM_REGISTER_PASS(MXGradient)
     .depend_graph_attr("grad_xs")
    .depend_graph_attr("in_arg_shapes")
    .depend_graph_attr("in_arg_dtypes")
-    .depend_graph_attr("grad_ys_out_grad");
+    .depend_graph_attr("grad_ys_out_grad")
+    .depend_graph_attr("grad_us");
 
 }  // namespace
diff --git a/tests/python/unittest/test_autograd.py b/tests/python/unittest/test_autograd.py
index 554d830512b6..c48d20479f15 100644
--- a/tests/python/unittest/test_autograd.py
+++ b/tests/python/unittest/test_autograd.py
@@ -243,7 +243,7 @@ def test_detach_updated_grad():
     assert x._fresh_grad == False
 
 
-def test_retain_grad():
+def test_retain_graph():
     x = mx.nd.ones((2, 2))
     dx = mx.nd.zeros((2, 2))
     mark_variables([x], [dx], grad_reqs='add')
@@ -519,3 +519,68 @@ def test_gradient():
     dx.backward()
     assert abs(x.grad.asscalar() - 2.71828175) < 1e-7
 
+def test_retain_grad_drop_grad():
+    x = nd.array([1,2,3,4])
+    x.attach_grad()
+    y = nd.array([5,6,7,8])
+    y.attach_grad()
+
+    with mx.autograd.record():
+        u = x * y
+        z = u * x
+
+    u.attach_grad()
+    z.attach_grad()
+    out_grad = nd.array([10, 10, 10, 10])
+    z.backward(out_grad, retain_graph=True)
+
+    assert (u.grad == out_grad * x).asnumpy().all()
+    assert (z.grad == out_grad).asnumpy().all()
+    assert (x.grad == out_grad * 2 * x * y).asnumpy().all()
+    assert (y.grad == out_grad * x*x).asnumpy().all()
+
+    u.drop_grad()
+    z.drop_grad()
+    y.drop_grad()
+    out_grad = nd.array([0.1, 0.1, 0.1, 0.1])
+    z.backward(out_grad)
+
+    assert u.grad is None and z.grad is None and y.grad is None
+    assert (x.grad == out_grad * 2 * x * y).asnumpy().all()
+
+def test_retain_grad_drop_grad_gluon():
+    class CompBlock(mx.gluon.HybridBlock):
+        def __init__(self):
+            super().__init__()
+            self.marked_var = None
+        def forward(self, a, b):
+            out1 = a*b
+            out2 = out1 * a
+            self.marked_var = out1
+            return out2
+    x = mx.np.array([1,2,3,4])
+    y = mx.np.array([5,6,7,8])
+    x.attach_grad()
+    y.attach_grad()
+    block2 = CompBlock()
+    block2.initialize()
+    # block2.hybridize()
+    with mx.autograd.record():
+        z = block2(x, y)
+    u = block2.marked_var
+    u.attach_grad()
+    z.attach_grad()
+    z.backward(retain_graph=True)
+
+    assert (u.grad == x).all()
+    assert (z.grad == mx.np.array([1,1,1,1])).all()
+    assert (x.grad == 2 * x * y).all()
+    assert (y.grad == x*x).all()
+
+    u.drop_grad()
+    z.drop_grad()
+    y.drop_grad()
+    z.backward()
+
+    assert u.grad is None and z.grad is None and y.grad is None
+    assert (x.grad == 2 * x * y).all()