[FEATURE] Add feature of retain_grad (#20500)
* Replace "CloneGradient" with "ElemwiseGradUseNone"

* fix issue elemwise_add

* fix elemwise_add issue with `ElemwiseGradUseNone`

* reverse_to_CloneGradient

* add_retain_grad

* unit_test

* tidy_up

* tidy_up

* sanity

* const_reference

* const_ref

* merge_rg_to_ag

* sanity

* sanity

* add_drop_grad

* sanity_check

* sanity_check

* sanity_check

* build_err

* build_err

* skip_remark_variable

* repetitive_mark

* ReInit_in_dropgrad

* ReInit_in_dropgrad

* sanity_check

* add drop and tests to gluon

* sanity

* update exec_pass.h

Co-authored-by: Zhenghui Jin <[email protected]>
KexinFeng and barry-jin authored Sep 23, 2021
1 parent c1e06aa commit 05d1b26
Showing 9 changed files with 211 additions and 25 deletions.
8 changes: 8 additions & 0 deletions include/mxnet/c_api.h
@@ -1274,6 +1274,14 @@ MXNET_DLL int MXAutogradMarkVariables(uint32_t num_var,
NDArrayHandle *var_handles,
uint32_t *reqs_array,
NDArrayHandle *grad_handles);
/*!
* \brief unmark nonleaf NDArrays to free the memory
* \param num_var number of variable NDArrays
* \param var_handles variable NDArrays
* \return 0 when success, -1 when failure happens
*/
MXNET_DLL int MXAutogradDropGrads(uint32_t num_var,
NDArrayHandle *var_handles);
/*!
* \brief compute the gradient of outputs w.r.t. variables
* \param num_output number of output NDArray
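The Python frontends later in this diff wrap MXAutogradDropGrads through ctypes, one handle at a time. As a point of reference only (not part of this commit), a batched wrapper could look like the sketch below; drop_grads is a hypothetical helper name, and the sketch assumes the usual mxnet.base utilities (_LIB, check_call, NDArrayHandle).

import ctypes
from mxnet.base import _LIB, check_call, NDArrayHandle

def drop_grads(arrays):
    """Unmark several NDArrays in one C call, freeing their gradient buffers."""
    # Pack the NDArray handles into the C array expected by the C API.
    handles = (NDArrayHandle * len(arrays))(*[arr.handle for arr in arrays])
    check_call(_LIB.MXAutogradDropGrads(len(arrays), handles))

Calling drop_grads([u, z]) would then release, in one round trip, the gradient buffers that attach_grad reserved for the intermediates u and z.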
4 changes: 4 additions & 0 deletions include/mxnet/imperative.h
@@ -272,12 +272,16 @@ class Imperative {
void MarkVariables(const std::vector<NDArray*>& variables,
const std::vector<uint32_t>& grad_reqs,
const std::vector<NDArray*>& gradients);
/*! \brief unmark nonleaf variables to free the memory. */
void DropGrads(const std::vector<NDArray*>& variables);
/*! \brief compute the gradient of outputs w.r.t variables. */
std::vector<NDArray*> Backward(const std::vector<NDArray*>& outputs,
const std::vector<NDArray*>& ograds,
const std::vector<NDArray*>& variables,
bool is_train, bool retain_graph,
bool create_graph);
/*! \brief Return the marked nonleaf nodes. */
std::vector<nnvm::ObjectPtr> ListNonleafVariables(const nnvm::Symbol& sym) const;
/*! \return AutogradRuntime singleton */
static Imperative* Get();
/*! \brief Should op execution bulking be employed during inference. */
5 changes: 5 additions & 0 deletions python/mxnet/ndarray/ndarray.py
@@ -2885,6 +2885,11 @@ def attach_grad(self, grad_req='write', stype=None):
ctypes.pointer(mx_uint(grad_req)),
ctypes.pointer(grad.handle)))

def drop_grad(self):
"""Free the memory of the marked ndarray."""
check_call(_LIB.MXAutogradDropGrads(
1, ctypes.pointer(self.handle)))

@property
def grad(self):
"""Returns gradient buffer attached to this NDArray."""
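A minimal usage sketch of the new drop_grad method, distilled from test_retain_grad_drop_grad at the end of this diff (the values are illustrative only):

from mxnet import autograd, nd

x = nd.array([1, 2, 3, 4])
y = nd.array([5, 6, 7, 8])
x.attach_grad()
y.attach_grad()
with autograd.record():
    u = x * y                    # intermediate (nonleaf) result
    z = u * x
u.attach_grad()                  # retain the intermediate's gradient
z.backward(retain_graph=True)
print(u.grad)                    # dz/du == x
u.drop_grad()                    # unmark u and free its gradient buffer
z.backward()                     # second pass no longer keeps u's gradient
assert u.grad is None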
5 changes: 5 additions & 0 deletions python/mxnet/numpy/multiarray.py
@@ -1410,6 +1410,11 @@ def attach_grad(self, grad_req='write'): # pylint: disable=arguments-differ
ctypes.pointer(mx_uint(grad_req)),
ctypes.pointer(grad.handle)))

def drop_grad(self):
"""Free the memory of the marked ndarray."""
check_call(_LIB.MXAutogradDropGrads(
1, ctypes.pointer(self.handle)))

@property
def grad(self):
"""Returns gradient buffer attached to this ndarray."""
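The same workflow applies to mxnet.numpy arrays created inside a Gluon block. A sketch mirroring test_retain_grad_drop_grad_gluon below; the Product class and its marked attribute are illustrative names, used here as one way to reach a nonleaf array from outside the block:

import mxnet as mx

class Product(mx.gluon.HybridBlock):
    def __init__(self):
        super().__init__()
        self.marked = None
    def forward(self, a, b):
        out1 = a * b
        self.marked = out1        # expose the intermediate for retain_grad
        return out1 * a

x = mx.np.array([1, 2, 3, 4])
y = mx.np.array([5, 6, 7, 8])
x.attach_grad()
y.attach_grad()
block = Product()
block.initialize()
with mx.autograd.record():
    z = block(x, y)
u = block.marked
u.attach_grad()                   # retain the intermediate's gradient
z.backward()
assert (u.grad == x).all()        # dz/du == a == x
u.drop_grad()                     # free the retained gradient buffer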
12 changes: 12 additions & 0 deletions src/c_api/c_api_ndarray.cc
@@ -335,6 +335,18 @@ int MXAutogradMarkVariables(uint32_t num_var,
API_END();
}

int MXAutogradDropGrads(uint32_t num_var,
NDArrayHandle *var_handles) {
API_BEGIN();
std::vector<NDArray*> variables;
variables.reserve(num_var);
for (uint32_t i = 0; i < num_var; ++i) {
variables.emplace_back(static_cast<NDArray*>(var_handles[i]));
}
Imperative::Get()->DropGrads(variables);
API_END();
}

int MXAutogradComputeGradient(uint32_t num_output, NDArrayHandle* output_handles) {
return MXAutogradBackward(num_output, output_handles, nullptr, 0);
}
4 changes: 3 additions & 1 deletion src/imperative/exec_pass.h
@@ -287,12 +287,14 @@ inline Graph MXGradient(
std::vector<const Op*> zero_ops = std::vector<const Op*>(),
std::string copy_op_str = std::string(),
mxnet::ShapeVector in_arg_shapes = mxnet::ShapeVector(),
DTypeVector in_arg_dtypes = DTypeVector()) {
DTypeVector in_arg_dtypes = DTypeVector(),
std::vector<NodeEntry> us = std::vector<NodeEntry>()) {
graph.attrs["grad_ys"] = std::make_shared<any>(std::move(ys));
graph.attrs["grad_xs"] = std::make_shared<any>(std::move(xs));
graph.attrs["grad_ys_out_grad"] = std::make_shared<any>(std::move(ys_out_grad));
graph.attrs["in_arg_shapes"] = std::make_shared<any>(std::move(in_arg_shapes));
graph.attrs["in_arg_dtypes"] = std::make_shared<any>(std::move(in_arg_dtypes));
graph.attrs["grad_us"] = std::make_shared<any>(std::move(us));

if (aggregate_fun != nullptr) {
graph.attrs["grad_aggregate_fun"] = std::make_shared<any>(aggregate_fun);
101 changes: 83 additions & 18 deletions src/imperative/imperative.cc
@@ -142,29 +142,54 @@ void Imperative::MarkVariables(const std::vector<NDArray*>& variables,
const std::vector<uint32_t>& grad_reqs,
const std::vector<NDArray*>& gradients) {
for (uint32_t i = 0; i < variables.size(); ++i) {
std::string str_c(std::to_string(variable_count_++));

variables[i]->autograd_entry_ =
nnvm::NodeEntry{nnvm::Symbol::CreateVariable("var" + str_c).outputs[0].node, 0, 0};
AGInfo& info = AGInfo::Create(variables[i]->autograd_entry_.node);
info.outputs.emplace_back(variables[i]->Detach());
info.out_grads.emplace_back(gradients[i]->Detach());
info.grad_req = static_cast<OpReqType>(grad_reqs[i]);
info.ctx = variables[i]->ctx();

gradients[i]->autograd_entry_ =
nnvm::NodeEntry{nnvm::Symbol::CreateVariable("grad" + str_c).outputs[0].node, 0, 0};
AGInfo& grad_info = AGInfo::Create(gradients[i]->autograd_entry_.node);
grad_info.outputs.emplace_back(gradients[i]->Detach());
grad_info.ctx = gradients[i]->ctx();
// Unmarked leaf nodes have null autograd_entry_, while marked nonleaf nodes don't.
if (!variables[i]->autograd_entry_.node || variables[i]->autograd_entry_.node->is_variable()) {
std::string str_c(std::to_string(variable_count_++));
variables[i]->autograd_entry_ =
nnvm::NodeEntry{nnvm::Symbol::CreateVariable("var" + str_c).outputs[0].node, 0, 0};
AGInfo& info = AGInfo::Create(variables[i]->autograd_entry_.node);
info.outputs.emplace_back(variables[i]->Detach());
info.out_grads.emplace_back(gradients[i]->Detach());
info.grad_req = static_cast<OpReqType>(grad_reqs[i]);
info.ctx = variables[i]->ctx();

gradients[i]->autograd_entry_ =
nnvm::NodeEntry{nnvm::Symbol::CreateVariable("grad" + str_c).outputs[0].node, 0, 0};
AGInfo& grad_info = AGInfo::Create(gradients[i]->autograd_entry_.node);
grad_info.outputs.emplace_back(gradients[i]->Detach());
grad_info.ctx = gradients[i]->ctx();
} else {
AGInfo& info = AGInfo::Get(variables[i]->autograd_entry_.node);
CHECK_EQ(info.out_grads.size(), 0)
<<"The node has already been marked. Cannot mark it again.";
info.out_grads.emplace_back(gradients[i]->Detach());
info.grad_req = static_cast<OpReqType>(grad_reqs[i]);
info.ctx = variables[i]->ctx();
}
}
}

// Unmark the variables to free the memory.
void Imperative::DropGrads(const std::vector<NDArray*>& variables) {
for (auto variable : variables) {
if (variable->autograd_entry_.node) {
AGInfo& info = AGInfo::Get(variable->autograd_entry_.node);
CHECK_NE(info.out_grads.size(), 0)
<<"The node has empty out_grads already. Cannot DropGrads again.";
for (auto grad : info.out_grads) {
grad.ReInit();
}
info.out_grads.clear();
info.grad_req = kNullOp;
}
}
}

void Imperative::GetBackwardDependency(const nnvm::ObjectPtr& node,
uint32_t num_inputs,
uint32_t num_outputs,
std::vector<bool>* p_save_inputs,
std::vector<bool>* p_save_outputs) {
std::vector<bool> *p_save_inputs,
std::vector<bool> *p_save_outputs) {
static auto& fgradient = nnvm::Op::GetAttr<nnvm::FGradient>("FGradient");
std::vector<bool>& save_inputs = *p_save_inputs;
std::vector<bool>& save_outputs = *p_save_outputs;
@@ -488,6 +513,12 @@ std::vector<NDArray*> Imperative::Backward(const std::vector<NDArray*>& outputs,
}
CHECK_GT(xs.size(), 0) << "There are no inputs in computation graph that require gradients.";
}
std::vector<ObjectPtr> nleaf_vars = ListNonleafVariables(sym);
std::vector<NodeEntry> us;
us.reserve(nleaf_vars.size());
for (const auto& i : nleaf_vars) {
us.emplace_back(NodeEntry{i, 0, 0});
}

Graph g_graph = pass::MXGradient(graph,
graph.outputs,
@@ -496,7 +527,10 @@ std::vector<NDArray*> Imperative::Backward(const std::vector<NDArray*>& outputs,
mxnet::AggregateGradient,
nullptr,
zero_ops,
"_copy");
"_copy",
ShapeVector(),
DTypeVector(),
us);
CHECK_EQ(g_graph.outputs.size(), xs.size());
for (const auto& e : g_graph.outputs) {
if (e.node->op() == nullptr) {
@@ -575,6 +609,20 @@ std::vector<NDArray*> Imperative::Backward(const std::vector<NDArray*>& outputs,
arrays[eid] = x_grads[i - num_forward_outputs];
ref_count[eid] = 1;
}
const std::vector<NodeEntry>& us_grads =
g_graph.GetAttr<std::vector<NodeEntry>>("nleaf_grads");
CHECK_EQ(us_grads.size(), us.size())
<< "Size of queried nleaf_vars and size of their gradients don't match.";
for (size_t i = 0; i < us_grads.size(); i++) {
size_t eid = idx.entry_id(us_grads[i]);
AGInfo& info = AGInfo::Get(us[i].node);
if (arrays[eid]->dtype_ == -1) {
arrays[eid] = &info.out_grads[0];
} else {
info.out_grads[0] = *arrays[eid];
}
ref_count[eid] = 1;
}

// Assign context
auto vctx = PlaceDevice(idx);
@@ -627,6 +675,11 @@ std::vector<NDArray*> Imperative::Backward(const std::vector<NDArray*>& outputs,
size_t eid = idx.entry_id(idx.outputs()[i]);
array_reqs[eid] = x_reqs[i - num_forward_outputs];
}
for (size_t i = 0; i < us_grads.size(); i++) {
size_t eid = idx.entry_id(us_grads[i]);
AGInfo& info = AGInfo::Get(us[i].node);
array_reqs[eid] = info.grad_req;
}

const auto& shapes = graph.GetAttr<mxnet::ShapeVector>("shape");
const auto& dtypes = graph.GetAttr<DTypeVector>("dtype");
@@ -766,4 +819,16 @@ void Imperative::DCInfo::Compute(const NDArray& arr) {
info.outputs_.clear();
}

std::vector<nnvm::ObjectPtr> Imperative::ListNonleafVariables(const nnvm::Symbol& sym) const {
using namespace nnvm;
std::vector<ObjectPtr> ret;
DFSVisit(sym.outputs, [&ret](const ObjectPtr& node) {
AGInfo& info = AGInfo::Get(node);
if (info.out_grads.size() > 0 && !node->is_variable()) {
ret.push_back(node);
}
});
return ret;
}

} // namespace mxnet
30 changes: 25 additions & 5 deletions src/nnvm/gradient.cc
@@ -62,7 +62,8 @@ Graph BuildGradientGraph(const Graph& src,
const std::vector<ObjectPtr>& topo_order,
std::unordered_map<const Node*, std::vector<GradEntry> > output_grads,
std::function<int(const Node&)> mirror_fun,
const std::unordered_map<const Node*, ObjectPtr>& mirror_map);
const std::unordered_map<const Node*, ObjectPtr>& mirror_map,
const std::vector<NodeEntry>& us = std::vector<NodeEntry>());

/*!
* \brief Auxiliary function that maps the forward node of the source graph to
@@ -88,6 +89,8 @@ Graph Gradient(Graph src) {
const std::vector<NodeEntry>& ys_out_grad =
src.GetAttr<std::vector<NodeEntry> >("grad_ys_out_grad");
CHECK_EQ(ys.size(), ys_out_grad.size());
const std::vector<NodeEntry>& us =
src.GetAttr<std::vector<NodeEntry> >("grad_us");

// initialize a topological order of the graph nodes and `output_grads`
// that maps every operator node to its gradient entries
@@ -120,7 +123,7 @@ Graph Gradient(Graph src) {
std::unordered_map<const Node*, ObjectPtr> mirror_map;

// complete the backward graph of the src, but without backward mirroring
nnvm::Graph gsrc = BuildGradientGraph(src, xs, topo_order, output_grads, nullptr, mirror_map);
nnvm::Graph gsrc = BuildGradientGraph(src, xs, topo_order, output_grads, nullptr, mirror_map, us);
if (mirror_fun == nullptr) {
return gsrc; // Gradient pass without mirroring ends here.
}
@@ -504,12 +507,14 @@ inline bool CheckGradAllZero(const std::vector<NodeEntry>& grads,
return true;
}


Graph BuildGradientGraph(const Graph& src,
const std::vector<NodeEntry>& xs,
const std::vector<ObjectPtr>& topo_order,
std::unordered_map<const Node*, std::vector<GradEntry> > output_grads,
std::function<int(const Node&)> mirror_fun,
const std::unordered_map<const Node*, ObjectPtr>& mirror_map) {
const std::unordered_map<const Node*, ObjectPtr>& mirror_map,
const std::vector<NodeEntry>& us) {
static auto& grad_fun_map = Op::GetAttr<nnvm::FGradient>("FGradient");

// gradient aggregation function
@@ -617,7 +622,7 @@ Graph BuildGradientGraph(const Graph& src,
CHECK(src_fwd_node->inputs.size() <= input_grads.size());
for (auto input_iter = src_fwd_node->inputs.begin(); input_iter != src_fwd_node->inputs.end();
++input_iter, ++input_grad_iter) {
// propagate the input gradients to the output gradients of the input nodes
// propagate the input_grads to the corresponding GradEntries mapped by output_grads
output_grads[input_iter->node.get()][input_iter->index].grads.emplace_back(
std::move(*input_grad_iter));
}
@@ -661,6 +666,20 @@ Graph BuildGradientGraph(const Graph& src,
ret.outputs[kv.second.second] = kv.first;
}
}

// Collect the gradient NodeEntry of each entry in us and store them in graph.attrs
std::vector<NodeEntry> nleaf_grads;
nleaf_grads.reserve(us.size());
for (const NodeEntry& e : us) {
GradEntry& entry = output_grads[e.node.get()][e.index];
// aggregate the gradient sum if it has not been computed yet
if (entry.sum.node.get() == nullptr) {
entry.sum = agg_fun(std::move(entry.grads));
}
nleaf_grads.push_back(entry.sum);
}
ret.attrs["nleaf_grads"] = std::make_shared<any>(std::move(nleaf_grads));

return ret;
}

@@ -673,7 +692,8 @@ NNVM_REGISTER_PASS(MXGradient)
.depend_graph_attr("grad_xs")
.depend_graph_attr("in_arg_shapes")
.depend_graph_attr("in_arg_dtypes")
.depend_graph_attr("grad_ys_out_grad");
.depend_graph_attr("grad_ys_out_grad")
.depend_graph_attr("grad_us");

} // namespace

67 changes: 66 additions & 1 deletion tests/python/unittest/test_autograd.py
@@ -243,7 +243,7 @@ def test_detach_updated_grad():
assert x._fresh_grad == False


def test_retain_grad():
def test_retain_graph():
x = mx.nd.ones((2, 2))
dx = mx.nd.zeros((2, 2))
mark_variables([x], [dx], grad_reqs='add')
@@ -519,3 +519,68 @@ def test_gradient():
dx.backward()
assert abs(x.grad.asscalar() - 2.71828175) < 1e-7

def test_retain_grad_drop_grad():
x = nd.array([1,2,3,4])
x.attach_grad()
y = nd.array([5,6,7,8])
y.attach_grad()

with mx.autograd.record():
u = x * y
z = u * x

u.attach_grad()
z.attach_grad()
out_grad = nd.array([10, 10, 10, 10])
z.backward(out_grad, retain_graph=True)

assert (u.grad == out_grad * x).asnumpy().all()
assert (z.grad == out_grad).asnumpy().all()
assert (x.grad == out_grad * 2 * x * y).asnumpy().all()
assert (y.grad == out_grad * x*x).asnumpy().all()

u.drop_grad()
z.drop_grad()
y.drop_grad()
out_grad = nd.array([0.1, 0.1, 0.1, 0.1])
z.backward(out_grad)

assert u.grad is None and z.grad is None and y.grad is None
assert (x.grad == out_grad * 2 * x * y).asnumpy().all()

def test_retain_grad_drop_grad_gluon():
class CompBlock(mx.gluon.HybridBlock):
def __init__(self):
super().__init__()
self.marked_var = None
def forward(self, a, b):
out1 = a*b
out2 = out1 * a
self.marked_var = out1
return out2
x = mx.np.array([1,2,3,4])
y = mx.np.array([5,6,7,8])
x.attach_grad()
y.attach_grad()
block2 = CompBlock()
block2.initialize()
# block2.hybridize()
with mx.autograd.record():
z = block2(x, y)
u = block2.marked_var
u.attach_grad()
z.attach_grad()
z.backward(retain_graph=True)

assert (u.grad == x).all()
assert (z.grad == mx.np.array([1,1,1,1])).all()
assert (x.grad == 2 * x * y).all()
assert (y.grad == x*x).all()

u.drop_grad()
z.drop_grad()
y.drop_grad()
z.backward()

assert u.grad is None and z.grad is None and y.grad is None
assert (x.grad == 2 * x * y).all()
