From b2963d0a9de40fb3ae4f8c33f6ed187e5b7bf74f Mon Sep 17 00:00:00 2001 From: Sheng Zha Date: Sun, 17 Feb 2019 22:29:16 -0800 Subject: [PATCH] add inplace option --- python/mxnet/ndarray/ndarray.py | 86 +++++++++++++++++++++------------ 1 file changed, 55 insertions(+), 31 deletions(-) diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index c180ad317e6f..790e5fd9ce20 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -1261,8 +1261,14 @@ def sign(self, *args, **kwargs): """ return op.sign(self, *args, **kwargs) - def flatten(self): - """Returns a flattened **view** of this array without altering any data. + def flatten(self, inplace=False): + """Flatten this array without altering any data. + + Parameters + ---------- + inplace : bool, default False + If True, this method returns a **view** of this array + that shares data with this array. Otherwise, a copy is returned. Returns ------- @@ -1273,7 +1279,8 @@ def flatten(self): Examples -------- >>> x = mx.nd.arange(30).reshape(5,2,3) - >>> y = x.flatten() + >>> y = x.flatten(inplace=True) + >>> z = x.flatten() >>> y.shape (5, 6) >>> y[0].asnumpy() @@ -1282,8 +1289,10 @@ def flatten(self): >>> x[0].asnumpy() array([[-1., -1., -1.], [-1., -1., -1.]], dtype=float32) + >>> z[0].asnumpy() + array([0., 1., 2., 3., 4., 5.], dtype=float32) """ - return self.reshape((0, -1)) + return op.flatten(self) if not inplace else self.reshape((0, -1)) def shape_array(self, *args, **kwargs): """Convenience fluent method for :py:func:`shape_array`. @@ -1301,7 +1310,7 @@ def size_array(self, *args, **kwargs): """ return op.size_array(self, *args, **kwargs) - def expand_dims(self, axis): + def expand_dims(self, axis, inplace=False): """Returns a **view** of this array with additional dimension without altering any data. Parameters @@ -1310,6 +1319,9 @@ def expand_dims(self, axis): Position where new axis is to be inserted. Suppose that the input NDArray's dimension is ndim, the range of the inserted axis is [-ndim, ndim]. + inplace : bool, default False + If True, this method returns a **view** of this array + that shares data with this array. Otherwise, a copy is returned. Returns ------- @@ -1321,7 +1333,8 @@ def expand_dims(self, axis): Examples -------- >>> x = mx.nd.arange(6).reshape(2,3) - >>> y = x.expand_dims(1) + >>> y = x.expand_dims(1, inplace=True) + >>> z = x.expand_dims(1) >>> y.shape (2, 1, 3) >>> y[0].asnumpy() @@ -1330,14 +1343,19 @@ def expand_dims(self, axis): >>> x.asnumpy() array([[-1., -1., -1.], [-1., -1., -1.]], dtype=float32) + >>> z[0].asnumpy() + array([[0., 1., 2.]], dtype=float32) """ - new_shape = list(self.shape) - assert -len(new_shape) <= axis <= len(new_shape), \ - "axis {} is out of range for {}d array".format(axis, len(new_shape)) - if axis < 0: - axis += len(new_shape) - new_shape.insert(axis, 1) - return self.reshape(new_shape) + if not inplace: + op.expand_dims(self, axis=axis) + else: + new_shape = list(self.shape) + assert -len(new_shape) <= axis <= len(new_shape), \ + "axis {} is out of range for {}d array".format(axis, len(new_shape)) + if axis < 0: + axis += len(new_shape) + new_shape.insert(axis, 1) + return self.reshape(new_shape) def tile(self, *args, **kwargs): """Convenience fluent method for :py:func:`tile`. @@ -1747,7 +1765,7 @@ def softmin(self, *args, **kwargs): """ return op.softmin(self, *args, **kwargs) - def squeeze(self, axis=None): + def squeeze(self, axis=None, inplace=False): """Returns a **view** of this array with squeezed shape without altering any data. Parameters @@ -1755,27 +1773,33 @@ def squeeze(self, axis=None): axis : int, tuple of int, or None Selects a subset of the single-dimensional entries in the shape. If an axis is selected with shape entry greater than one, an error is raised. + inplace : bool, default False + If True, this method returns a **view** of this array + that shares data with this array. Otherwise, a copy is returned. """ - new_shape = list(self.shape) - if isinstance(axis, int): - axis = [axis] - if axis: - assert len(axis) == len(set(axis)), \ - "axis {} contains duplicate which is not allowed.".format(axis) - for i in axis: - assert 0 <= i < len(new_shape), "axis {} is out of range.".format(i) - assert new_shape[i] == 1, \ - "Squeeze target axis {} must be size 1, got {}.".format(i, new_shape[i]) - for i in sorted(axis, reverse=True): - del new_shape[i] + if not inplace: + op.squeeze(self, axis=axis) else: - for i in reversed(range(len(new_shape))): - if new_shape[i] == 1: + new_shape = list(self.shape) + if isinstance(axis, int): + axis = [axis] + if axis: + assert len(axis) == len(set(axis)), \ + "axis {} contains duplicate which is not allowed.".format(axis) + for i in axis: + assert 0 <= i < len(new_shape), "axis {} is out of range.".format(i) + assert new_shape[i] == 1, \ + "Squeeze target axis {} must be size 1, got {}.".format(i, new_shape[i]) + for i in sorted(axis, reverse=True): del new_shape[i] - if not new_shape: - new_shape.append(1) + else: + for i in reversed(range(len(new_shape))): + if new_shape[i] == 1: + del new_shape[i] + if not new_shape: + new_shape.append(1) - return self.reshape(new_shape) + return self.reshape(new_shape) # pylint: disable= undefined-variable def broadcast_to(self, shape):