Skip to content

Commit

Permalink
fix typo
Browse files Browse the repository at this point in the history
  • Loading branch information
azai91 committed Nov 29, 2018
1 parent 62dd2d0 commit e5c5976
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions tests/python/unittest/test_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -1221,29 +1221,29 @@ def check_hybrid_static_memory(train_modes, **kwargs):
net1(x)
net2(x)

def test(net, x, record=True):
with mx.autograd.record(record):
def test(net, x, train_mode=True):
with mx.autograd.record(train_mode=train_mode):
y = net(x) + net(x)
y.backward()
y.backward(train_mode=train_mode)

grads = {k: v.grad() for k, v in net.collect_params().items() if v.grad_req != 'null'}

return y, grads

for record in train_modes:
y1, grads1 = test(net1, x, record)
y2, grads2 = test(net2, x, record)
for train_mode in train_modes:
y1, grads1 = test(net1, x, train_mode)
y2, grads2 = test(net2, x, train_mode)

assert_almost_equal(y1.asnumpy(), y2.asnumpy(), rtol=1e-3, atol=1e-5)
for key in grads1:
assert_almost_equal(grads1[key].asnumpy(), grads2[key].asnumpy(), rtol=1e-3, atol=1e-5)

@with_seed()
def test_hybrid_static_memory():
check_hybrid_static_memory(train_mode=[True, False])
check_hybrid_static_memory(train_mode=[True, False], static_alloc=True)
check_hybrid_static_memory(train_modes=[True, False])
check_hybrid_static_memory(train_modes=[True, False], static_alloc=True)
# TODO: MKLDNN (issue #13445) does not work with static_shape backwards
check_hybrid_static_memory(train_mode=[True], static_alloc=True, static_shape=True)
check_hybrid_static_memory(train_modes=[True], static_alloc=True, static_shape=True)

def check_hybrid_static_memory_switching(**kwargs):
net = gluon.model_zoo.vision.get_resnet(
Expand Down

0 comments on commit e5c5976

Please sign in to comment.