Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add unit test for record false
Browse files Browse the repository at this point in the history
  • Loading branch information
azai91 committed Aug 30, 2018
1 parent 9ac7c35 commit 1ccef95
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions tests/python/unittest/test_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -1172,7 +1172,7 @@ def test_zero_grad():
grad = net.collect_params()['test_zero_grad_weight'].grad()
assert_almost_equal(grad.asnumpy(), grad.asnumpy() * 0)

def check_hybrid_static_memory(**kwargs):
def check_hybrid_static_memory():
x = mx.nd.random.uniform(shape=(2, 3, 32, 32))
x.attach_grad()

Expand All @@ -1184,24 +1184,25 @@ def check_hybrid_static_memory(**kwargs):
net1(x)
net2(x)

def test(net, x):
with mx.autograd.record():
def test(net, x, record=True):
with mx.autograd.record(record):
y = net(x) + net(x)
y.backward()

grads = {k: v.grad() for k, v in net.collect_params().items() if v.grad_req != 'null'}

return y, grads

y1, grads1 = test(net1, x)
y2, grads2 = test(net2, x)
for record in (True, False):
y1, grads1 = test(net1, x, record)
y2, grads2 = test(net2, x, record)

assert_almost_equal(y1.asnumpy(), y2.asnumpy(), rtol=1e-3, atol=1e-5)
for key in grads1:
assert_almost_equal(grads1[key].asnumpy(), grads2[key].asnumpy(), rtol=1e-3, atol=1e-5)
assert_almost_equal(y1.asnumpy(), y2.asnumpy(), rtol=1e-3, atol=1e-5)
for key in grads1:
assert_almost_equal(grads1[key].asnumpy(), grads2[key].asnumpy(), rtol=1e-3, atol=1e-5)

@with_seed()
def test_hybrid_static_memory():
def test_hybrid_static_memory(**kwargs):
check_hybrid_static_memory()
check_hybrid_static_memory(static_alloc=True)
check_hybrid_static_memory(static_alloc=True, static_shape=True)
Expand Down

0 comments on commit 1ccef95

Please sign in to comment.