Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[BUGFIX] Fix test_zero_sized_dim save/restore of np_shape state #20365

Merged
merged 2 commits into from
Jun 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 23 additions & 17 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5333,19 +5333,21 @@ def test_boolean_mask():
assert same(data.grad.asnumpy(), expected_grad)

# test 0-size output
mx.set_np_shape(True)
data = mx.nd.array([[1, 2, 3],[4, 5, 6],[7, 8, 9]])
index = mx.nd.array([0, 0, 0])
data.attach_grad()
with mx.autograd.record():
out = mx.nd.contrib.boolean_mask(data, index)
out.backward()
data.grad.wait_to_read()
expected = np.zeros((0, 3))
expected_grad = np.array([[0, 0, 0], [0, 0, 0], [0, 0, 0]])
assert same(out.asnumpy(), expected)
assert same(data.grad.asnumpy(), expected_grad)
mx.set_np_shape(False)
prev_np_shape = mx.set_np_shape(True)
try:
data = mx.nd.array([[1, 2, 3],[4, 5, 6],[7, 8, 9]])
index = mx.nd.array([0, 0, 0])
data.attach_grad()
with mx.autograd.record():
out = mx.nd.contrib.boolean_mask(data, index)
out.backward()
data.grad.wait_to_read()
expected = np.zeros((0, 3))
expected_grad = np.array([[0, 0, 0], [0, 0, 0], [0, 0, 0]])
assert same(out.asnumpy(), expected)
assert same(data.grad.asnumpy(), expected_grad)
finally:
mx.set_np_shape(prev_np_shape)

# test gradient
shape = (100, 30)
Expand Down Expand Up @@ -9463,7 +9465,8 @@ def test_sldwin_atten_op_impl(batch_size, seq_length, num_heads,

def test_zero_sized_dim():

mx.util.set_np_shape(True) # Must be done to prevent zero-sized dimension conversion to 'unknown'
# Must be done to prevent zero-sized dimension conversion to 'unknown'
prev_np_shape = mx.util.set_np_shape(True)

def seq_last():
"""Test for issue: https://github.com/apache/incubator-mxnet/issues/18938"""
Expand All @@ -9483,9 +9486,12 @@ def seq_reverse():
res = mx.nd.op.SequenceReverse(data)
assert data.shape == res.shape

seq_last()
seq_reverse()
seq_mask()
try:
seq_last()
seq_reverse()
seq_mask()
finally:
mx.util.set_np_shape(prev_np_shape)

def test_take_grads():
# Test for https://github.com/apache/incubator-mxnet/issues/19817
Expand Down
4 changes: 2 additions & 2 deletions tests/python/unittest/test_thread_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def g():


def test_np_global_shape():
set_np_shape(2)
prev_np_shape = set_np_shape(2)
data = []

def f():
Expand All @@ -229,4 +229,4 @@ def f():
assert_almost_equal(data[0].asnumpy(), np.ones(shape=()))
assert_almost_equal(data[1].asnumpy(), np.ones(shape=(0, 1, 2)))
finally:
set_np_shape(0)
set_np_shape(prev_np_shape)