Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add unit test for record false
Browse files Browse the repository at this point in the history
  • Loading branch information
azai91 committed Nov 1, 2018
1 parent 7bcc962 commit 64fe9c4
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions tests/python/unittest/test_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -1220,21 +1220,22 @@ def check_hybrid_static_memory(**kwargs):
net1(x)
net2(x)

def test(net, x):
with mx.autograd.record():
def test(net, x, record=True):
with mx.autograd.record(record):
y = net(x) + net(x)
y.backward()

grads = {k: v.grad() for k, v in net.collect_params().items() if v.grad_req != 'null'}

return y, grads

y1, grads1 = test(net1, x)
y2, grads2 = test(net2, x)
for record in (True, False):
y1, grads1 = test(net1, x, record)
y2, grads2 = test(net2, x, record)

assert_almost_equal(y1.asnumpy(), y2.asnumpy(), rtol=1e-3, atol=1e-5)
for key in grads1:
assert_almost_equal(grads1[key].asnumpy(), grads2[key].asnumpy(), rtol=1e-3, atol=1e-5)
assert_almost_equal(y1.asnumpy(), y2.asnumpy(), rtol=1e-3, atol=1e-5)
for key in grads1:
assert_almost_equal(grads1[key].asnumpy(), grads2[key].asnumpy(), rtol=1e-3, atol=1e-5)

@with_seed()
def test_hybrid_static_memory():
Expand Down

0 comments on commit 64fe9c4

Please sign in to comment.