Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
adding tests for completeness
Browse files Browse the repository at this point in the history
  • Loading branch information
mseth10 committed Mar 12, 2019
1 parent 44bcf43 commit 56c49f5
Showing 1 changed file with 25 additions and 0 deletions.
25 changes: 25 additions & 0 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6606,6 +6606,31 @@ def test_slice_forward_backward(a, index):
for index in index_list:
test_slice_forward_backward(arr, index)

def test_begin_equals_end():
shape = (2, 4)
begin = (None, 2)
end = (None, 2)
step = (1, -1)
in_arr = mx.nd.arange(np.prod(shape)).reshape(shape=shape)
out_arr = mx.nd.slice(in_arr, begin=begin, end=end, step=step)

assert_raises(MXNetError, test_begin_equals_end)

# check for end=-1 and negative step
shape = (4, 3)
begin = (1, 2)
end = (-1, -1)
step = (-1, -2)
in_arr = mx.nd.arange(np.prod(shape)).reshape(shape=shape)
out_arr = mx.nd.slice(in_arr, begin=begin, end=end, step=step)
out_exp = np.array([[5., 3.],[2., 0.]])
assert same(out_arr.asnumpy(), out_exp)
data = mx.sym.Variable('data')
slice_sym = mx.sym.slice(data, begin=begin, end=end, step=step)
grad_exp = np.zeros(shape)
grad_exp[(slice(None, 2, 1),slice(None, None, 2))] = np.array([[0., 2.],[3., 5.]])
check_symbolic_backward(slice_sym, [in_arr.asnumpy()], [out_exp], [grad_exp])

# check numeric gradient
in_data = np.arange(36).reshape(2, 2, 3, 3)
data = mx.sym.Variable('data')
Expand Down

0 comments on commit 56c49f5

Please sign in to comment.