diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index 47053fa8e392..b98ebd905e6f 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -442,6 +442,10 @@ def __setitem__(self, key, value): array([[ 6., 5., 5.], [ 6., 0., 4.]], dtype=float32) """ + if self.ndim == 0 and key == (): + _internal._full(shape=self.shape, value=float(value), ctx=self.context, + dtype=self.dtype, out=self) + return key = _indexing_key_expand_implicit_axes(key, self.shape) slc_key = tuple(idx for idx in key if idx is not None) @@ -602,6 +606,8 @@ def __getitem__(self, key): array([[[4., 5.], [6., 7.]]], dtype=float32) """ + if self.ndim == 0 and key == (): + return self key = _indexing_key_expand_implicit_axes(key, self.shape) if len(key) == 0: raise ValueError('indexing key cannot be an empty tuple') @@ -2741,6 +2747,8 @@ def _get_dim_size(start, stop, step): """Given start, stop, and stop, calculate the number of elements of this slice.""" assert step != 0 + if stop == start: + return 0 if step > 0: assert start < stop dim_size = (stop - start - 1) // step + 1 diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index 0e260ceb7676..75ecfd1b1ede 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -1062,7 +1062,11 @@ def check_symbolic_forward(sym, location, expected, rtol=1E-4, atol=None, executor = sym.bind(ctx=ctx, args=location, args_grad=args_grad_data, aux_states=aux_states) for g in executor.grad_arrays: - g[:] = 0 + print(g.shape) + if g.ndim == 0: + g[()] = 0 + else: + g[:] = 0 executor.forward(is_train=False) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 632c60fdc3ce..13ab03b94de9 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -8862,7 +8862,7 @@ def test_index_array_default(): @mx.use_np_shape def test_index_array_default_zero_dim(): - data = mx.symbol.Variable("data") + data = mx.symbol.Variable("data") index_array = mx.sym.contrib.index_array(data) input_array = np.ones(())