Skip to content

Commit

Permalink
Fix tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
zheng-da committed Jun 28, 2018
1 parent 1eb3abf commit e2f814d
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5955,10 +5955,14 @@ def verify_foreach(step, in_syms, state_syms, free_syms,
res2 = _as_list(res)
for i in range(len(res2)):
res2[i] = res2[i] * 2
outs = []
outs[:] = res2[:]
if isinstance(states, list):
outs.extend(states)
states = [mx.nd.expand_dims(s, 0) for s in states]
res2.extend(states)
else:
outs.append(states)
states = mx.nd.expand_dims(states, 0)
res2.append(states)
res = mx.nd.concat(*res2, dim=0)
Expand All @@ -5968,8 +5972,9 @@ def verify_foreach(step, in_syms, state_syms, free_syms,
tmp_grads.extend(tmp_grads1)
if (is_train):
res.backward(mx.nd.concat(*tmp_grads, dim=0))
for i in range(len(res2)):
assert_almost_equal(e.outputs[i].asnumpy(), res2[i].asnumpy(),
for i in range(len(outs)):
assert e.outputs[i].shape == outs[i].shape
assert_almost_equal(e.outputs[i].asnumpy(), outs[i].asnumpy(),
rtol=0.001, atol=0.0001)
if (is_train):
all_ins = _as_list(in_arrs)[:]
Expand Down

0 comments on commit e2f814d

Please sign in to comment.