Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[BUGFIX] Fix a bug in Auto Function. (#15184)
Browse files Browse the repository at this point in the history
* add test.

* fix.

* update test.
  • Loading branch information
zheng-da authored and szha committed Jun 8, 2019
1 parent 745a41c commit c5874dd
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 2 deletions.
3 changes: 1 addition & 2 deletions python/mxnet/autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,15 +493,14 @@ def delete_entry(_):
POINTER(CFUNCTYPE(c_int))),
cast(c_array(c_void_p, [None]*len(callbacks)),
POINTER(c_void_p)))
Function._registry.ref_holder[key] = context
check_call(_LIB.MXCustomFunctionRecord(
c_int(len(inputs)),
c_handle_array(inputs),
c_int(len(outputs)),
c_handle_array(outputs),
ctypes.byref(context)))

Function._registry.ref_holder[key] = context

return ret_outputs

def forward(self, *inputs):
Expand Down
21 changes: 21 additions & 0 deletions tests/python/unittest/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,27 @@ def backward(self, dm, dn):
assert_almost_equal(y.grad.asnumpy(), dy1, atol=atol)


@with_seed()
def test_function1():
class Foo(mx.autograd.Function):
def __init__(self):
super(Foo, self).__init__()

def forward(self, X):
return X + 1;

def backward(self, dY):
return dY

with mx.autograd.record():
X = mx.nd.zeros((3, 4))
#X.attach_grad() # uncommenting this line works
for i in range(5):
f = Foo()
X = f(X)
X.wait_to_read()


@with_seed()
def test_get_symbol():
x = mx.nd.ones((1,))
Expand Down

0 comments on commit c5874dd

Please sign in to comment.