Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add forward tests
Browse files Browse the repository at this point in the history
  • Loading branch information
AntiZpvoh committed Jun 10, 2020
1 parent afe4107 commit 4b07eee
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 2 deletions.
8 changes: 6 additions & 2 deletions python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def __getitem__(self, key): # pylint: disable = too-many-return-statements, inco
end = []
step = []
new_shape = ()
result = None
result = self
is_symbol_tuple = False
if len(key) == 0:
return self
Expand All @@ -139,8 +139,12 @@ def __getitem__(self, key): # pylint: disable = too-many-return-statements, inco
step.append(-1)
new_shape += (-3,)
elif isinstance(index, Symbol):
if new_shape != ():
new_shape += (-4,)
sliced = _npi.slice(result, begin, end, step)
result = _npi.reshape(sliced, new_shape)
if not is_symbol_tuple:
result = _npi.advanced_indexing(self, index)
result = _npi.advanced_indexing(result, index)
is_symbol_tuple = True
else:
result = _npi.advanced_indexing_multiple(result, index)
Expand Down
130 changes: 130 additions & 0 deletions tests/python/unittest/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1417,6 +1417,76 @@ def test_getitem(np_array, index, is_scalar=False):

assert same(np_indexed_array, mx_indexed_array), 'Failed with index = {}'.format(index)


def test_getitem_symbol(np_array, index):
np_index = index
np_array_sym = mx.sym.Variable('np_array').as_np_ndarray()
if type(index)==int or type(index)==np.int32 or type(index)==np.int64 or type(index)==slice:
np_index = index
np_index_sym = mx.sym.Variable('np_index').as_np_ndarray().astype(np.int32)
np_indexed_array_sym = np_array_sym[index].astype(np.int32)
np_model = np_indexed_array_sym.bind(ctx=mx.cpu(), args={'np_array':mx.nd.array(np_array)})
np_indexed_array = np_model.forward()

if type(index) == list:
index = np.array(index)
if type(index) == mx.nd.NDArray:
np_index = index
np_index_sym = mx.sym.Variable('np_index').as_np_ndarray().astype(np.int32)
np_indexed_array_sym = np_array_sym[np_index_sym].astype(np.int32)
np_model = np_indexed_array_sym.bind(ctx=mx.cpu(), args={'np_array':mx.nd.array(np_array),
'np_index':np_index})
np_indexed_array = np_model.forward()
np_index = np_index.asnumpy()
if isinstance(index, np.ndarray):
np_index = index
np_index_sym = mx.sym.Variable('np_index').as_np_ndarray().astype(np.int32)
np_indexed_array_sym = np_array_sym[np_index_sym].astype(np.int32)
np_model = np_indexed_array_sym.bind(ctx=mx.cpu(), args={'np_array':mx.nd.array(np_array),
'np_index':mx.nd.array(np_index)})
np_indexed_array = np_model.forward()
if isinstance(index, tuple):
np_index = tuple([
idx.asnumpy() if isinstance(idx, mx.nd.NDArray) else idx
for idx in index]
)
num = 0
sym_args = {'np_array':mx.nd.array(np_array)}
np_index_list = []
for np_index_ele in np_index:
if isinstance(np_index_ele, mx.nd.NDArray):
np_index_ele = np_index_ele.asnumpy()
elif type(np_index_ele) == list:
np_index_ele = np.array(np_index_ele)
if type(np_index_ele)==int or type(np_index_ele)==np.int32 or \
type(np_index_ele)==np.int64 or type(np_index_ele)==slice:
np_index_list.append(np_index_ele)
elif isinstance(np_index_ele, np.ndarray):
tmp_sym = mx.sym.Variable('np_index_'+str(num)).as_np_ndarray().astype(np.int32)
np_index_list.append(tmp_sym)
sym_args['np_index_'+str(num)] = mx.nd.array(np_index_ele)
num+=1
np_index_list = tuple(np_index_list)
np_indexed_array_sym = np_array_sym[np_index_list].astype(np.int32)
np_model = np_indexed_array_sym.bind(ctx=mx.cpu(), args=sym_args)
np_indexed_array = np_model.forward()

if(np_indexed_array != None):
np_indexed_array = np_indexed_array[0].asnumpy()
mx_np_array = np.array(np_array, dtype=np_array.dtype)
for autograd in [True, False]:
try:
if autograd:
with mx.autograd.record():
mx_indexed_array = mx_np_array[np_index]
else:
mx_indexed_array = mx_np_array[np_index]
except Exception as e:
print('Failed with index = {}'.format(np_index))
raise e
assert same(np_indexed_array, mx_indexed_array), \
'Failed with index = {}, as type = {}'.format(index, type(index))

def test_setitem(np_array, index, is_scalar):
def assert_same(np_array, np_index, mx_array, mx_index, mx_value, np_value=None):
if np_value is not None:
Expand Down Expand Up @@ -1661,6 +1731,66 @@ def convert(num):
test_getitem_autograd(np_array, index[0])
test_setitem_autograd(np_array, index[0])

index_list_symbol = [# Basic indexing
# Single int as index
(0, False), (np.int32(0), False), (np.int64(0), False),
(5, False), (np.int32(5), False), (np.int64(5), False),
(-1, False), (np.int32(-1), False), (np.int64(-1), False),
# Slicing as index
(slice(5), False), (np_int(slice(5), np.int32), False), (np_int(slice(5), np.int64), False),
(slice(1, 5), False), (np_int(slice(1, 5), np.int32), False), (np_int(slice(1, 5), np.int64), False),
(slice(1, 5, 2), False), (np_int(slice(1, 5, 2), np.int32), False),
(np_int(slice(1, 5, 2), np.int64), False),
(slice(7, 0, -1), False), (np_int(slice(7, 0, -1)), False),
(np_int(slice(7, 0, -1), np.int64), False),
(slice(None, 6), False), (np_int(slice(None, 6)), False),
(np_int(slice(None, 6), np.int64), False),
(slice(None, 6, 3), False), (np_int(slice(None, 6, 3)), False),
(np_int(slice(None, 6, 3), np.int64), False),
(slice(1, None), False), (np_int(slice(1, None)), False),
(np_int(slice(1, None), np.int64), False),
(slice(1, None, 3), False), (np_int(slice(1, None, 3)), False),
(np_int(slice(1, None, 3), np.int64), False),
(slice(None, None, 2), False), (np_int(slice(None, None, 2)), False),
(np_int(slice(None, None, 2), np.int64), False),
(slice(None, None, -1), False),
(np_int(slice(None, None, -1)), False), (np_int(slice(None, None, -1), np.int64), False),
(slice(None, None, -2), False),
(np_int(slice(None, None, -2), np.int32), False), (np_int(slice(None, None, -2), np.int64), False),
# Multiple ints as indices
((1, 2, 3), False),
(np_int((1, 2, 3)), False),
(np_int((1, 2, 3), np.int64), False),
((-1, -2, -3), False),
(np_int((-1, -2, -3)), False),
(np_int((-1, -2, -3), np.int64), False),
((1, 2, 3, 4), True),
(np_int((1, 2, 3, 4)), True),
(np_int((1, 2, 3, 4), np.int64), True),
((-4, -3, -2, -1), True),
(np_int((-4, -3, -2, -1)), True),
(np_int((-4, -3, -2, -1), np.int64), True),
# Advanced indexing
([1], False), ([1, 2], False), ([2, 1, 3], False), ([7, 5, 0, 3, 6, 2, 1], False),
(np.array([6, 3], dtype=np.int32), False),
(np.array([[3, 4], [0, 6]], dtype=np.int32), False),
(np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int32), False),
(np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int64), False),
(np.array([[2], [0], [1]], dtype=np.int32), False),
(np.array([[2], [0], [1]], dtype=np.int64), False),
(mx.nd.array([4, 7], dtype=np.int32), False),
(mx.nd.array([4, 7], dtype=np.int64), False),
(mx.nd.array([[3, 6], [2, 1]], dtype=np.int32), False),
(mx.nd.array([[3, 6], [2, 1]], dtype=np.int64), False),
(mx.nd.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int32), False),
(mx.nd.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int64), False),
((1, [2, 3]), False),
([0], False), ([0, 1], False), ([1, 2, 3], False), ([2, 0, 5, 6], False),
(([1, 1], [2, 3]), False),
]

for index in index_list_symbol:
test_getitem_symbol(np_array, index[0])

def test_assign_float_value_to_ndarray():
"""Test case from https://github.com/apache/incubator-mxnet/issues/8668"""
Expand Down

0 comments on commit 4b07eee

Please sign in to comment.