Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[BUGFIX] Fix test_zero_sized_dim save/restore of np_shape state (#20365)
Browse files Browse the repository at this point in the history
* Fix test_zero_sized_dim save/restore of np_shape state

* Trigger CI
  • Loading branch information
DickJC123 committed Jun 25, 2021
1 parent 7b4d61d commit dc69b04
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 19 deletions.
40 changes: 23 additions & 17 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5333,19 +5333,21 @@ def test_boolean_mask():
assert same(data.grad.asnumpy(), expected_grad)

# test 0-size output
mx.set_np_shape(True)
data = mx.nd.array([[1, 2, 3],[4, 5, 6],[7, 8, 9]])
index = mx.nd.array([0, 0, 0])
data.attach_grad()
with mx.autograd.record():
out = mx.nd.contrib.boolean_mask(data, index)
out.backward()
data.grad.wait_to_read()
expected = np.zeros((0, 3))
expected_grad = np.array([[0, 0, 0], [0, 0, 0], [0, 0, 0]])
assert same(out.asnumpy(), expected)
assert same(data.grad.asnumpy(), expected_grad)
mx.set_np_shape(False)
prev_np_shape = mx.set_np_shape(True)
try:
data = mx.nd.array([[1, 2, 3],[4, 5, 6],[7, 8, 9]])
index = mx.nd.array([0, 0, 0])
data.attach_grad()
with mx.autograd.record():
out = mx.nd.contrib.boolean_mask(data, index)
out.backward()
data.grad.wait_to_read()
expected = np.zeros((0, 3))
expected_grad = np.array([[0, 0, 0], [0, 0, 0], [0, 0, 0]])
assert same(out.asnumpy(), expected)
assert same(data.grad.asnumpy(), expected_grad)
finally:
mx.set_np_shape(prev_np_shape)

# test gradient
shape = (100, 30)
Expand Down Expand Up @@ -9463,7 +9465,8 @@ def test_sldwin_atten_op_impl(batch_size, seq_length, num_heads,

def test_zero_sized_dim():

mx.util.set_np_shape(True) # Must be done to prevent zero-sized dimension conversion to 'unknown'
# Must be done to prevent zero-sized dimension conversion to 'unknown'
prev_np_shape = mx.util.set_np_shape(True)

def seq_last():
"""Test for issue: https://github.com/apache/incubator-mxnet/issues/18938"""
Expand All @@ -9483,9 +9486,12 @@ def seq_reverse():
res = mx.nd.op.SequenceReverse(data)
assert data.shape == res.shape

seq_last()
seq_reverse()
seq_mask()
try:
seq_last()
seq_reverse()
seq_mask()
finally:
mx.util.set_np_shape(prev_np_shape)

@mx.util.use_np
def test_take_grads():
Expand Down
4 changes: 2 additions & 2 deletions tests/python/unittest/test_thread_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def g():


def test_np_global_shape():
set_np_shape(2)
prev_np_shape = set_np_shape(2)
data = []

def f():
Expand All @@ -229,4 +229,4 @@ def f():
assert_almost_equal(data[0].asnumpy(), np.ones(shape=()))
assert_almost_equal(data[1].asnumpy(), np.ones(shape=(0, 1, 2)))
finally:
set_np_shape(0)
set_np_shape(prev_np_shape)

0 comments on commit dc69b04

Please sign in to comment.