Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
reminisce committed Oct 20, 2019
1 parent f724092 commit 8c11ef9
Showing 1 changed file with 11 additions and 13 deletions.
24 changes: 11 additions & 13 deletions tests/python/unittest/test_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,29 +427,27 @@ def get_net():
weight = mx.sym.var('weight', shape=(num_hdidden, 0))
return mx.sym.FullyConnected(data, weight, num_hidden=num_hdidden)

with TemporaryDirectory() as work_dir:
for is_old_format in [True, False]:
for flag1 in [False, True]:
with np_shape(flag1):
net_json_str = get_net().tojson()
net_data = json.loads(net_json_str)
assert "attrs" in net_data
assert "is_np_shape" in net_data["attrs"]
np_shape_flag = net_data["attrs"]["is_np_shape"]

assert len(np_shape_flag) == 2
assert np_shape_flag[0] == 'int'
assert np_shape_flag[1] == 0

if is_old_format:
net_data["attrs"].pop("is_np_shape") # delete is_np_shape key-value to simulate 1.5.0 format
if flag1:
assert "is_np_shape" in net_data["attrs"]
else:
assert "is_np_shape" not in net_data["attrs"]

with TemporaryDirectory() as work_dir:
fname = os.path.join(work_dir, 'test_sym.json')
with open(fname, 'w') as fp:
json.dump(net_data, fp)

# test loading 1.5.0 symbol file since 1.6.0
# w/ or w/o np_shape semantics
for flag in [False, True]:
with np_shape(flag):
for flag2 in [False, True]:
if flag1: # Do not need to test this case since 0 indicates zero-size dim
continue
with np_shape(flag2):
net = mx.sym.load(fname)
arg_shapes, out_shapes, aux_shapes = net.infer_shape(data=(batch_size, num_features))
assert arg_shapes[0] == (batch_size, num_features) # data
Expand Down

0 comments on commit 8c11ef9

Please sign in to comment.