diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index f1d0cc7ac274..d2fe8ceb00e7 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -3079,6 +3079,47 @@ def forward(self, x):
         shape = (np.random.randint(1, 10), np.random.randint(1, 10), 1)
         block(mx.nd.ones(shape))
 
+@with_seed()
+def test_reqs_switching_training_inference():
+    class Foo(gluon.HybridBlock):
+        def __init__(self, **kwargs):
+            super(Foo, self).__init__(**kwargs)
+
+        def hybrid_forward(self, F, x):
+            y = 2 * x
+            return F.sqrt(x) + F.sqrt(y)
+
+    f = Foo()
+    f.hybridize(static_alloc=True)  # NOTE(review): per the test name, presumably exercises grad-req switching with a statically-allocated cached graph
+    x = mx.nd.ones(shape=(10,10))
+    x.attach_grad()
+    x2 = mx.nd.ones(shape=x.shape) * 2
+    x2.attach_grad()
+
+    # Call first in training mode
+    with mx.autograd.record():
+        y = f(x)
+    y.backward()
+
+    grad1 = x.grad.asnumpy()  # reference gradient from the first training pass
+
+    # Compute the gradient with some other input
+    with mx.autograd.record():
+        y = f(x2)
+    y.backward()
+
+    # Call inference mode
+    y = f(x)
+
+    # Call training mode again
+    with mx.autograd.record():
+        y = f(x)
+    y.backward()
+
+    grad2 = x.grad.asnumpy()  # gradient after inference ran in between
+
+    mx.test_utils.assert_almost_equal(grad1, grad2)  # same input must yield the same gradient despite the intervening inference call
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()