This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[NumPy] add support for np.ndarray in autograd.function #18790

Merged 1 commit on Jul 25, 2020
14 changes: 10 additions & 4 deletions python/mxnet/autograd.py
@@ -28,6 +28,7 @@
 from .ndarray import NDArray, _ndarray_cls
 from .ndarray import _GRAD_REQ_MAP
 from .symbol import Symbol
+from .util import is_np_array


 def set_recording(is_recording): #pylint: disable=redefined-outer-name
@@ -448,25 +449,30 @@ def __call__(self, *inputs):
             outputs = (outputs,)

         key = Function._registry.inc()
+        if is_np_array():
+            from .numpy import ndarray
+            array_cls = ndarray
+        else:
+            array_cls = NDArray

         def backward_entry(num_ograds, num_igrads, ptrs, reqs, is_train, _):
             """entry point for backward."""
             # pylint: disable=W0613
             try:
-                output_grads = [NDArray(ctypes.cast(i, NDArrayHandle), writable=False) \
+                output_grads = [array_cls(ctypes.cast(i, NDArrayHandle), writable=False) \
                                     for i in ptrs[:num_ograds]]
-                input_grads = [NDArray(ctypes.cast(i, NDArrayHandle), writable=True) \
+                input_grads = [array_cls(ctypes.cast(i, NDArrayHandle), writable=True) \
                                    for i in ptrs[num_ograds:num_ograds+num_igrads]]
                 reqs = [reqs[i] for i in range(num_igrads)]
                 rets = self.backward(*output_grads)
-                if isinstance(rets, NDArray):
+                if isinstance(rets, array_cls):
                     rets = (rets,)
                 assert len(rets) == len(input_grads), \
                     "%s.backward must return exactly the same number " \
                     "of NDArrays as the number of NDArrays arguments to forward." \
                     "Expecting %d got %d"%(self.__class__.name, len(input_grads), len(rets))
                 for igrad, ret, req in zip(input_grads, rets, reqs):
-                    assert isinstance(ret, NDArray), \
+                    assert isinstance(ret, array_cls), \
                         "autograd.Function.backward must return NDArrays, not %s"%type(ret)
                     if req == 0:  # null
                         return True
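For reference, a minimal sketch of what this change enables once NumPy-array semantics are active: after npx.set_np(), is_np_array() returns True, so the gradients passed to a custom Function's backward are mx.np.ndarray instances rather than NDArray. The Scale class, shapes, and values below are illustrative and not part of this PR.

```python
import mxnet as mx
from mxnet import autograd, npx

npx.set_np()  # enable NumPy-array semantics; autograd.Function now dispatches to mx.np.ndarray

class Scale(autograd.Function):
    """Illustrative custom function computing y = 2 * x."""
    def forward(self, x):
        self.save_for_backward(x)
        return 2 * x

    def backward(self, dy):
        # dy arrives as an mx.np.ndarray via the array_cls dispatch added above
        return 2 * dy

x = mx.np.ones((3,))
x.attach_grad()
with autograd.record():
    y = Scale()(x)
y.backward()
print(x.grad)  # expected: [2. 2. 2.]
```

The tests added below exercise the same path with mx.np.random arrays under the @use_np decorator.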
61 changes: 61 additions & 0 deletions tests/python/unittest/test_autograd.py
@@ -404,6 +404,67 @@ def backward(self, dY):
         X.wait_to_read()


+@with_seed()
+@pytest.mark.garbage_expected
+@use_np
+def test_np_function():
+    class func(Function):
+        def forward(self, x, y):
+            m = x / y
+            n = x * y
+            self.save_for_backward(x, y)
+            return m, n
+
+        def backward(self, dm, dn):
+            x, y = self.saved_tensors
+            dx = dm/y + dn*y
+            dy = dn*x - dm * x / y / y
+            return dx, dy
+
+    f = func()
+    x = mx.np.random.uniform(size=(10,))
+    x.attach_grad()
+    y = mx.np.random.uniform(size=(10,))
+    y.attach_grad()
+    with record():
+        m, n = f(x, y)
+        backward([m, n])
+
+    dx1 = x.grad.asnumpy()
+    dy1 = y.grad.asnumpy()
+
+    with record():
+        backward([x/y, x*y])
+
+    # Non-zero atol required, as exposed by seed 630179191
+    atol = 1e-6
+    assert_almost_equal(x.grad.asnumpy(), dx1, atol=atol)
+    assert_almost_equal(y.grad.asnumpy(), dy1, atol=atol)
+
+
+@with_seed()
+@pytest.mark.garbage_expected
+@use_np
+def test_np_function1():
+    class Foo(mx.autograd.Function):
+        def __init__(self):
+            super(Foo, self).__init__()
+
+        def forward(self, X):
+            return X + 1;
+
+        def backward(self, dY):
+            return dY
+
+    with mx.autograd.record():
+        X = mx.np.zeros((3, 4))
+        #X.attach_grad() # uncommenting this line works
+        for i in range(5):
+            f = Foo()
+            X = f(X)
+        X.wait_to_read()
+
+
 @with_seed()
 @pytest.mark.garbage_expected
 def test_get_symbol():