Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix test.
Browse files Browse the repository at this point in the history
  • Loading branch information
zheng-da committed Aug 28, 2018
1 parent f4683cc commit cea32a3
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tests/python/unittest/test_subgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def make_subgraph4(stype):
make_subgraphs = [make_subgraph1,
lambda stype: make_subgraph2(stype, False),
lambda stype: make_subgraph2(stype, True),
make_subgraph3]
make_subgraph3, make_subgraph4]
stypes = ['default', 'row_sparse']
for make_subg in make_subgraphs:
for stype in stypes:
Expand All @@ -129,8 +129,8 @@ def make_subgraph4(stype):
args_grad = {key : mx.nd.empty(shape=all_inputs[key].shape) for key in all_inputs.keys()}
e2 = subg.bind(ctx=default_context(), args=all_inputs, args_grad=args_grad,
aux_states=all_inputs)
e1.forward()
e2.forward()
e1.forward(is_train=True)
e2.forward(is_train=True)
for i in range(len(e1.outputs)):
assert_almost_equal(e1.outputs[i].asnumpy(), e2.outputs[i].asnumpy(),
rtol=0.001, atol=0.0001)
Expand Down

0 comments on commit cea32a3

Please sign in to comment.