This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[BUGFIX] Fix #20293 (#20462)
* fix 20293

* avoid state.array_reqs being overridden by reqs

* update

* fix AddTo grad_req in StaticBackward

* fix lint

* fix executor
barry-jin authored Oct 27, 2021
1 parent 64999b4 commit 026dbf8
Showing 5 changed files with 49 additions and 9 deletions.
5 changes: 3 additions & 2 deletions python/mxnet/executor.py
@@ -34,14 +34,15 @@ class Executor:
     >>> c = 2 * a + b
     >>> texec = c._bind(mx.cpu(), {'a': mx.nd.array([1,2]), 'b':mx.nd.array([2,3])})
     """
-    def __init__(self, sym, ctx, args, args_grad, grad_req, aux_states):
+    def __init__(self, sym, ctx, args, args_grad, grad_req, aux_states, static_alloc=False):
         self.outputs = None
         self._input_names = sym.list_inputs()
         self._aux_names = sym.list_auxiliary_states()
         self._arg_names = sym.list_arguments()
         self._output_names = sym.list_outputs()
         self._ctx = ctx
         self._grad_req = grad_req
+        self.static_alloc = static_alloc
         # grad_req
         self._requires_grad = False
         if isinstance(grad_req, dict):
@@ -121,7 +122,7 @@ def __init__(self, sym, ctx, args, args_grad, grad_req, aux_states):
             with self._ctx:
                 self._args[i].attach_grad(req, stype=g.stype)
                 self._args[i].grad[:] = g
-        self._cached_op = ndarray.CachedOp(sym)
+        self._cached_op = ndarray.CachedOp(sym, flags=[("static_alloc", self.static_alloc)])

     def get_optimized_symbol(self):
         """Get an optimized version of the symbol from the executor.
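
For context (an illustration, not part of the diff): static_alloc is an existing CachedOp flag, so the executor change above amounts to forwarding the caller's choice into the cached op it builds. A minimal sketch, assuming an MXNet 2.x build where mx.nd.CachedOp is exposed:

import mxnet as mx

# Same toy graph as in the Executor docstring example above.
a = mx.sym.Variable('a')
b = mx.sym.Variable('b')
c = 2 * a + b

# Roughly what Executor now does internally when static_alloc=True is requested.
op = mx.nd.CachedOp(c, flags=[("static_alloc", True)])
out = op(mx.nd.array([1, 2]), mx.nd.array([2, 3]))
print(out)
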
7 changes: 5 additions & 2 deletions python/mxnet/symbol/symbol.py
@@ -1793,7 +1793,7 @@ def _simple_bind(self, ctx, grad_req='write', type_dict=None, stype_dict=None,
         return Executor(self, ctx, args, args_grad, grad_req, aux_states)

     def _bind(self, ctx, args, args_grad=None, grad_req='write',
-              aux_states=None):
+              aux_states=None, static_alloc=False):
         """Binds the current symbol to an executor and returns it.

         We first declare the computation and then bind to the data to run.
@@ -1856,6 +1856,9 @@ def _bind(self, ctx, args, args_grad=None, grad_req='write',
             `auxiliary_states` to the corresponding `NDArray`,
             - In either case, all the auxiliary states need to be provided.

+        static_alloc : bool, default False
+            Statically allocate memory to improve speed. Memory usage may increase.
+
         Returns
         -------
         executor : Executor
@@ -1874,7 +1877,7 @@ def _bind(self, ctx, args, args_grad=None, grad_req='write',
            gradient they interested in.
         """
         assert isinstance(grad_req, (str, dict))
-        return Executor(self, ctx, args, args_grad, grad_req, aux_states)
+        return Executor(self, ctx, args, args_grad, grad_req, aux_states, static_alloc)

     def gradient(self, wrt):
         """Gets the autodiff of current symbol.
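
Usage sketch (illustrative, not part of the diff) of the extended _bind signature, reusing the toy graph from the executor.py docstring; gradient values are indicative only:

import mxnet as mx

a = mx.sym.Variable('a')
b = mx.sym.Variable('b')
c = 2 * a + b

texec = c._bind(mx.cpu(),
                {'a': mx.nd.array([1, 2]), 'b': mx.nd.array([2, 3])},
                args_grad={'a': mx.nd.zeros((2,)), 'b': mx.nd.zeros((2,))},
                grad_req='write',
                static_alloc=True)
texec.forward(is_train=True)
texec.backward(out_grads=mx.nd.ones((2,)))
# d(2*a + b)/da == 2 and d(2*a + b)/db == 1 for each element.
print(texec.outputs, texec.grad_arrays)
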
21 changes: 17 additions & 4 deletions src/imperative/cached_op.cc
@@ -308,6 +308,21 @@ bool CachedOp::SetBackwardGraph(GraphInfo* info,
     g.attrs[AddPrefix(BACKWARD, REF_COUNT)] = std::make_shared<dmlc::any>(std::move(ref_count));
   }

+  // Set AddTo Entry based on the req that users provide
+  if (detect_inplace_addto) {
+    std::vector<int> addto_entry(idx.num_node_entries(), 0);
+    for (size_t i = 0; i < info->grad_graph.outputs.size(); ++i) {
+      if (reqs[i] == kAddTo) {
+        auto entry = info->grad_graph.outputs[i];
+        if (!idx.exist(entry.node.get()))
+          continue;
+        auto eid = idx.entry_id(entry);
+        addto_entry[eid] = 1;
+      }
+    }
+    g.attrs["addto_entry"] = std::make_shared<nnvm::any>(std::move(addto_entry));
+  }
+
   auto shapes = info->fwd_graph.GetAttr<mxnet::ShapeVector>("shape");
   shapes.resize(idx.num_node_entries(), mxnet::TShape());
   auto dtypes = info->fwd_graph.GetAttr<DTypeVector>("dtype");
@@ -1047,8 +1062,7 @@ void CachedOp::StaticBackward(const bool retain_graph,
       auto entry = state.info.grad_graph.outputs[iter->second];
       if (!idx.exist(entry.node.get()))
         continue;
-      auto eid = idx.entry_id(entry);
-      state.array_reqs[eid] = reqs[iter->second];
+      auto eid = idx.entry_id(entry);
       // An input and an output may share the same array.
       INIT_DETACHED(outputs[iter->second], arrays[eid]);
       arrays[eid] = outputs[iter->second];
@@ -1058,8 +1072,7 @@ void CachedOp::StaticBackward(const bool retain_graph,
       auto entry = state.info.grad_graph.outputs[i];
       if (!idx.exist(entry.node.get()))
         continue;
-      auto eid = idx.entry_id(entry);
-      state.array_reqs[eid] = reqs[i];
+      auto eid = idx.entry_id(entry);
       // An input and an output may share the same array.
       INIT_DETACHED(outputs[i], arrays[eid]);
       arrays[eid] = outputs[i];
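
The addto_entry bookkeeping added above only matters when a caller asks for accumulating gradients (kAddTo, grad_req='add' on the Python side), which is what the "fix AddTo grad_req in StaticBackward" item in the commit message refers to: StaticBackward no longer overwrites state.array_reqs with the user-supplied reqs. An illustrative sketch of a call that exercises this path (assumes elemwise_add is reachable as mx.sym.elemwise_add; expected values are indicative only):

import mxnet as mx

x = mx.sym.Variable('x')
y = mx.sym.Variable('y')
z = mx.sym.elemwise_add(x, y)

ex = z._bind(mx.cpu(),
             {'x': mx.nd.array([0.4]), 'y': mx.nd.array([0.5])},
             args_grad={'x': mx.nd.zeros((1,)), 'y': mx.nd.zeros((1,))},
             grad_req={'x': 'add', 'y': 'write'})

# With 'add', the gradient for x accumulates across backward passes (kAddTo);
# with 'write', the gradient for y is overwritten each time.
for _ in range(2):
    ex.forward(is_train=True)
    ex.backward(out_grads=mx.nd.array([1]))
print(ex.grad_arrays)  # x's grad should end up near [2.], y's near [1.]
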
7 changes: 6 additions & 1 deletion src/imperative/inplace_addto_detect_pass.cc
@@ -39,7 +39,12 @@ Graph DetectInplaceAddTo(Graph g) {
   auto& idx = g.indexed_graph();
   // reference cont.
   std::vector<int> ref_count(idx.num_node_entries(), 0);
-  std::vector<int> addto_entry(idx.num_node_entries(), 0);
+  std::vector<int> addto_entry;
+  if (g.attrs.count("addto_entry")) {
+    addto_entry = g.GetAttr<std::vector<int> >("addto_entry");
+  } else {
+    addto_entry = std::vector<int>(idx.num_node_entries(), 0);
+  }
   std::vector<int> skip_plus_node(idx.num_nodes(), 0);

   for (auto& e : idx.outputs()) {
18 changes: 18 additions & 0 deletions tests/python/unittest/test_executor.py
@@ -159,3 +159,21 @@ def check_init(static_alloc, static_shape):
     check_init(False, False)
     check_init(True, False)
     check_init(True, True)
+
+def test_elemwise_add_grad():
+    json = "{\"nodes\": [{\"op\":\"null\",\"name\":\".Inputs.Input1\",\"inputs\":[]},{\"op\":\"null\",\"name\":\".Inputs.Input2\",\"inputs\":[]},{\"op\":\"elemwise_add\",\"name\":\".$0\",\"inputs\":[[0,0,0],[1,0,0]]},{\"op\":\"_copy\",\"name\":\".Outputs.Output\",\"inputs\":[[2,0,0]]}],\"arg_nodes\":[0,1],\"heads\":[[3,0,0]]}"
+    sym = mx.symbol.fromjson(json)
+
+    ex = sym._bind(
+        mx.cpu(),
+        {'.Inputs.Input1': mx.nd.array([0.4]), '.Inputs.Input2': mx.nd.array([0.5])},
+        args_grad={
+            '.Inputs.Input1': mx.ndarray.zeros((1)),
+            '.Inputs.Input2': mx.ndarray.zeros((1))
+        },
+        grad_req={'.Inputs.Input1': 'null', '.Inputs.Input2': 'write'}
+    )
+    ex.forward(is_train=True)
+    print(ex.outputs)
+    ex.backward(out_grads=mx.nd.array([1]))
+    print(ex.grad_arrays)
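
A possible tightening of this test (hypothetical, not part of the commit): assert on values instead of printing. elemwise_add(0.4, 0.5) is 0.9, and because grad_req for .Inputs.Input2 is 'write' and d(elemwise_add)/d(input) is 1, its gradient buffer should hold [1.] after backward. Lines that could be appended inside test_elemwise_add_grad, assuming ex.grad_arrays follows sym.list_arguments() order:

    assert abs(ex.outputs[0].asscalar() - 0.9) < 1e-6
    # Map gradients back to argument names (assumed ordering).
    grads = dict(zip(sym.list_arguments(), ex.grad_arrays))
    assert abs(grads['.Inputs.Input2'].asscalar() - 1.0) < 1e-6
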
