Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix unit test failure
Browse files Browse the repository at this point in the history
  • Loading branch information
reminisce committed Aug 7, 2019
1 parent 4234412 commit fe6336d
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 2 deletions.
8 changes: 8 additions & 0 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,10 @@ def __setitem__(self, key, value):
array([[ 6., 5., 5.],
[ 6., 0., 4.]], dtype=float32)
"""
if self.ndim == 0 and key == ():
_internal._full(shape=self.shape, value=float(value), ctx=self.context,
dtype=self.dtype, out=self)
return
key = _indexing_key_expand_implicit_axes(key, self.shape)
slc_key = tuple(idx for idx in key if idx is not None)

Expand Down Expand Up @@ -602,6 +606,8 @@ def __getitem__(self, key):
array([[[4., 5.],
[6., 7.]]], dtype=float32)
"""
if self.ndim == 0 and key == ():
return self
key = _indexing_key_expand_implicit_axes(key, self.shape)
if len(key) == 0:
raise ValueError('indexing key cannot be an empty tuple')
Expand Down Expand Up @@ -2741,6 +2747,8 @@ def _get_dim_size(start, stop, step):
"""Given start, stop, and stop, calculate the number of elements
of this slice."""
assert step != 0
if stop == start:
return 0
if step > 0:
assert start < stop
dim_size = (stop - start - 1) // step + 1
Expand Down
6 changes: 5 additions & 1 deletion python/mxnet/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1062,7 +1062,11 @@ def check_symbolic_forward(sym, location, expected, rtol=1E-4, atol=None,

executor = sym.bind(ctx=ctx, args=location, args_grad=args_grad_data, aux_states=aux_states)
for g in executor.grad_arrays:
g[:] = 0
print(g.shape)
if g.ndim == 0:
g[()] = 0
else:
g[:] = 0

executor.forward(is_train=False)

Expand Down
2 changes: 1 addition & 1 deletion tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8862,7 +8862,7 @@ def test_index_array_default():

@mx.use_np_shape
def test_index_array_default_zero_dim():
data = mx.symbol.Variable("data")
data = mx.symbol.Variable("data")
index_array = mx.sym.contrib.index_array(data)

input_array = np.ones(())
Expand Down

0 comments on commit fe6336d

Please sign in to comment.