Skip to content

Commit

Permalink
Test showcasing an issue fixed in PR apache#16553
Browse files Browse the repository at this point in the history
  • Loading branch information
ptrendx committed Oct 22, 2019
1 parent 80e36ba commit 36e5ce8
Showing 1 changed file with 41 additions and 0 deletions.
41 changes: 41 additions & 0 deletions tests/python/unittest/test_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -3079,6 +3079,47 @@ def forward(self, x):
shape = (np.random.randint(1, 10), np.random.randint(1, 10), 1)
block(mx.nd.ones(shape))

@with_seed()
def test_reqs_switching_training_inference():
class Foo(gluon.HybridBlock):
def __init__(self, **kwargs):
super(Foo, self).__init__(**kwargs)

def hybrid_forward(self, F, x):
y = 2 * x
return F.sqrt(x) + F.sqrt(y)

f = Foo()
f.hybridize(static_alloc=True)
x = mx.nd.ones(shape=(10,10))
x.attach_grad()
x2 = mx.nd.ones(shape=x.shape) * 2
x2.attach_grad()

# Call first in training mode
with mx.autograd.record():
y = f(x)
y.backward()

grad1 = x.grad.asnumpy()

# Compute the gradient with some other input
with mx.autograd.record():
y = f(x2)
y.backward()

# Call inference mode
y = f(x)

# Call training mode again
with mx.autograd.record():
y = f(x)
y.backward()

grad2 = x.grad.asnumpy()

mx.test_utils.assert_almost_equal(grad1, grad2)

if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit 36e5ce8

Please sign in to comment.