diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index fb329f1865a9..c180ad317e6f 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -1261,13 +1261,29 @@ def sign(self, *args, **kwargs): """ return op.sign(self, *args, **kwargs) - def flatten(self, *args, **kwargs): - """Convenience fluent method for :py:func:`flatten`. + def flatten(self): + """Returns a flattened **view** of this array without altering any data. - The arguments are the same as for :py:func:`flatten`, with - this array as data. + Returns + ------- + NDArray + An array with flattened shape `(d1, d2*...*dk)` that shares data with + this array with shape `(d1, d2, ..., dk)`. + + Examples + -------- + >>> x = mx.nd.arange(30).reshape(5,2,3) + >>> y = x.flatten() + >>> y.shape + (5, 6) + >>> y[0].asnumpy() + array([0., 1., 2., 3., 4., 5.], dtype=float32) + >>> y[:] = -1 + >>> x[0].asnumpy() + array([[-1., -1., -1.], + [-1., -1., -1.]], dtype=float32) """ - return op.flatten(self, *args, **kwargs) + return self.reshape((0, -1)) def shape_array(self, *args, **kwargs): """Convenience fluent method for :py:func:`shape_array`. @@ -1285,13 +1301,43 @@ def size_array(self, *args, **kwargs): """ return op.size_array(self, *args, **kwargs) - def expand_dims(self, *args, **kwargs): - """Convenience fluent method for :py:func:`expand_dims`. + def expand_dims(self, axis): + """Returns a **view** of this array with additional dimension without altering any data. - The arguments are the same as for :py:func:`expand_dims`, with - this array as data. + Parameters + ---------- + axis : int + Position where new axis is to be inserted. + Suppose that the input NDArray's dimension is ndim, + the range of the inserted axis is [-ndim, ndim]. + + Returns + ------- + NDArray + An array with expanded shape `(d1, d2, ..., 1, di, ..., dk)` + that shares data with this array with shape `(d1, d2, ..., dk)`, + given input axis `i`. + + Examples + -------- + >>> x = mx.nd.arange(6).reshape(2,3) + >>> y = x.expand_dims(1) + >>> y.shape + (2, 1, 3) + >>> y[0].asnumpy() + array([[0., 1., 2.]], dtype=float32) + >>> y[:] = -1 + >>> x.asnumpy() + array([[-1., -1., -1.], + [-1., -1., -1.]], dtype=float32) """ - return op.expand_dims(self, *args, **kwargs) + new_shape = list(self.shape) + assert -len(new_shape) <= axis <= len(new_shape), \ + "axis {} is out of range for {}d array".format(axis, len(new_shape)) + if axis < 0: + axis += len(new_shape) + new_shape.insert(axis, 1) + return self.reshape(new_shape) def tile(self, *args, **kwargs): """Convenience fluent method for :py:func:`tile`. @@ -1701,13 +1747,35 @@ def softmin(self, *args, **kwargs): """ return op.softmin(self, *args, **kwargs) - def squeeze(self, *args, **kwargs): - """Convenience fluent method for :py:func:`squeeze`. + def squeeze(self, axis=None): + """Returns a **view** of this array with squeezed shape without altering any data. - The arguments are the same as for :py:func:`squeeze`, with - this array as data. + Parameters + ---------- + axis : int, tuple of int, or None + Selects a subset of the single-dimensional entries in the shape. + If an axis is selected with shape entry greater than one, an error is raised. """ - return op.squeeze(self, *args, **kwargs) + new_shape = list(self.shape) + if isinstance(axis, int): + axis = [axis] + if axis: + assert len(axis) == len(set(axis)), \ + "axis {} contains duplicate which is not allowed.".format(axis) + for i in axis: + assert 0 <= i < len(new_shape), "axis {} is out of range.".format(i) + assert new_shape[i] == 1, \ + "Squeeze target axis {} must be size 1, got {}.".format(i, new_shape[i]) + for i in sorted(axis, reverse=True): + del new_shape[i] + else: + for i in reversed(range(len(new_shape))): + if new_shape[i] == 1: + del new_shape[i] + if not new_shape: + new_shape.append(1) + + return self.reshape(new_shape) # pylint: disable= undefined-variable def broadcast_to(self, shape):