diff --git a/python/mxnet/autograd.py b/python/mxnet/autograd.py
index f968275a1390..aac7cbc21a17 100644
--- a/python/mxnet/autograd.py
+++ b/python/mxnet/autograd.py
@@ -28,6 +28,7 @@
 from .ndarray import NDArray, _ndarray_cls
 from .ndarray import _GRAD_REQ_MAP
 from .symbol import Symbol
+from .util import is_np_array
 
 
 def set_recording(is_recording): #pylint: disable=redefined-outer-name
@@ -448,25 +449,30 @@ def __call__(self, *inputs):
             outputs = (outputs,)
 
         key = Function._registry.inc()
+        if is_np_array():
+            from .numpy import ndarray
+            array_cls = ndarray
+        else:
+            array_cls = NDArray
 
         def backward_entry(num_ograds, num_igrads, ptrs, reqs, is_train, _):
             """entry point for backward."""
             # pylint: disable=W0613
             try:
-                output_grads = [NDArray(ctypes.cast(i, NDArrayHandle), writable=False) \
+                output_grads = [array_cls(ctypes.cast(i, NDArrayHandle), writable=False) \
                                 for i in ptrs[:num_ograds]]
-                input_grads = [NDArray(ctypes.cast(i, NDArrayHandle), writable=True) \
+                input_grads = [array_cls(ctypes.cast(i, NDArrayHandle), writable=True) \
                                for i in ptrs[num_ograds:num_ograds+num_igrads]]
                 reqs = [reqs[i] for i in range(num_igrads)]
                 rets = self.backward(*output_grads)
-                if isinstance(rets, NDArray):
+                if isinstance(rets, array_cls):
                     rets = (rets,)
                 assert len(rets) == len(input_grads), \
                     "%s.backward must return exactly the same number " \
                     "of NDArrays as the number of NDArrays arguments to forward." \
                     "Expecting %d got %d"%(self.__class__.name, len(input_grads), len(rets))
                 for igrad, ret, req in zip(input_grads, rets, reqs):
-                    assert isinstance(ret, NDArray), \
+                    assert isinstance(ret, array_cls), \
                         "autograd.Function.backward must return NDArrays, not %s"%type(ret)
                     if req == 0:  # null
                         return True
diff --git a/tests/python/unittest/test_autograd.py b/tests/python/unittest/test_autograd.py
index 6a75eed7d0bb..f9a7eccfa0a5 100644
--- a/tests/python/unittest/test_autograd.py
+++ b/tests/python/unittest/test_autograd.py
@@ -405,6 +405,67 @@ def backward(self, dY):
     X.wait_to_read()
 
 
+@with_seed()
+@pytest.mark.garbage_expected
+@use_np
+def test_np_function():
+    class func(Function):
+        def forward(self, x, y):
+            m = x / y
+            n = x * y
+            self.save_for_backward(x, y)
+            return m, n
+
+        def backward(self, dm, dn):
+            x, y = self.saved_tensors
+            dx = dm/y + dn*y
+            dy = dn*x - dm * x / y / y
+            return dx, dy
+
+    f = func()
+    x = mx.np.random.uniform(size=(10,))
+    x.attach_grad()
+    y = mx.np.random.uniform(size=(10,))
+    y.attach_grad()
+    with record():
+        m, n = f(x, y)
+        backward([m, n])
+
+    dx1 = x.grad.asnumpy()
+    dy1 = y.grad.asnumpy()
+
+    with record():
+        backward([x/y, x*y])
+
+    # Non-zero atol required, as exposed by seed 630179191
+    atol = 1e-6
+    assert_almost_equal(x.grad.asnumpy(), dx1, atol=atol)
+    assert_almost_equal(y.grad.asnumpy(), dy1, atol=atol)
+
+
+@with_seed()
+@pytest.mark.garbage_expected
+@use_np
+def test_np_function1():
+    class Foo(mx.autograd.Function):
+        def __init__(self):
+            super(Foo, self).__init__()
+
+        def forward(self, X):
+            return X + 1
+
+        def backward(self, dY):
+            return dY
+
+    with mx.autograd.record():
+        X = mx.np.zeros((3, 4))
+        #X.attach_grad()  # uncommenting this line works
+        for i in range(5):
+            f = Foo()
+            X = f(X)
+    X.wait_to_read()
+
+
 @with_seed()
 @pytest.mark.garbage_expected
 def test_get_symbol():
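
Note (reviewer sketch, not part of the patch): with this change, backward_entry wraps the gradient handles in mx.numpy.ndarray instead of the legacy NDArray whenever numpy semantics are active, so a custom autograd.Function composes with mx.np arrays end to end. A minimal usage sketch, assuming npx.set_np() is used to flip is_np_array() to True; the AddOne class is a hypothetical example, not part of the patch:

import mxnet as mx
from mxnet import autograd, npx

npx.set_np()  # activate numpy semantics; is_np_array() now returns True

class AddOne(autograd.Function):
    # Identity-plus-one; backward passes the incoming gradient through unchanged.
    def forward(self, x):
        return x + 1

    def backward(self, dy):
        return dy

x = mx.np.ones((3, 4))
x.attach_grad()
with autograd.record():
    y = AddOne()(x)
y.backward()
print(x.grad)  # all ones, since d(x + 1)/dx == 1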