Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add inplace option
Browse files Browse the repository at this point in the history
  • Loading branch information
szha committed Feb 19, 2019
1 parent 6cc0b12 commit 81536a5
Showing 1 changed file with 55 additions and 31 deletions.
86 changes: 55 additions & 31 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1261,8 +1261,14 @@ def sign(self, *args, **kwargs):
"""
return op.sign(self, *args, **kwargs)

def flatten(self):
"""Returns a flattened **view** of this array without altering any data.
def flatten(self, inplace=False):
"""Flatten this array without altering any data.
Parameters
----------
inplace : bool, default False
If True, this method returns a **view** of this array
that shares data with this array. Otherwise, a copy is returned.
Returns
-------
Expand All @@ -1273,7 +1279,8 @@ def flatten(self):
Examples
--------
>>> x = mx.nd.arange(30).reshape(5,2,3)
>>> y = x.flatten()
>>> y = x.flatten(inplace=True)
>>> z = x.flatten()
>>> y.shape
(5, 6)
>>> y[0].asnumpy()
Expand All @@ -1282,8 +1289,10 @@ def flatten(self):
>>> x[0].asnumpy()
array([[-1., -1., -1.],
[-1., -1., -1.]], dtype=float32)
>>> z[0].asnumpy()
array([0., 1., 2., 3., 4., 5.], dtype=float32)
"""
return self.reshape((0, -1))
return op.flatten(self) if not inplace else self.reshape((0, -1))

def shape_array(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`shape_array`.
Expand All @@ -1301,7 +1310,7 @@ def size_array(self, *args, **kwargs):
"""
return op.size_array(self, *args, **kwargs)

def expand_dims(self, axis):
def expand_dims(self, axis, inplace=False):
"""Returns a **view** of this array with additional dimension without altering any data.
Parameters
Expand All @@ -1310,6 +1319,9 @@ def expand_dims(self, axis):
Position where new axis is to be inserted.
Suppose that the input NDArray's dimension is ndim,
the range of the inserted axis is [-ndim, ndim].
inplace : bool, default False
If True, this method returns a **view** of this array
that shares data with this array. Otherwise, a copy is returned.
Returns
-------
Expand All @@ -1321,7 +1333,8 @@ def expand_dims(self, axis):
Examples
--------
>>> x = mx.nd.arange(6).reshape(2,3)
>>> y = x.expand_dims(1)
>>> y = x.expand_dims(1, inplace=True)
>>> z = x.expand_dims(1)
>>> y.shape
(2, 1, 3)
>>> y[0].asnumpy()
Expand All @@ -1330,14 +1343,19 @@ def expand_dims(self, axis):
>>> x.asnumpy()
array([[-1., -1., -1.],
[-1., -1., -1.]], dtype=float32)
>>> z[0].asnumpy()
array([[0., 1., 2.]], dtype=float32)
"""
new_shape = list(self.shape)
assert -len(new_shape) <= axis <= len(new_shape), \
"axis {} is out of range for {}d array".format(axis, len(new_shape))
if axis < 0:
axis += len(new_shape)
new_shape.insert(axis, 1)
return self.reshape(new_shape)
if not inplace:
return op.expand_dims(self, axis=axis)
else:
new_shape = list(self.shape)
assert -len(new_shape) <= axis <= len(new_shape), \
"axis {} is out of range for {}d array".format(axis, len(new_shape))
if axis < 0:
axis += len(new_shape)
new_shape.insert(axis, 1)
return self.reshape(new_shape)

def tile(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`tile`.
Expand Down Expand Up @@ -1747,35 +1765,41 @@ def softmin(self, *args, **kwargs):
"""
return op.softmin(self, *args, **kwargs)

def squeeze(self, axis=None):
def squeeze(self, axis=None, inplace=False):
"""Returns a **view** of this array with squeezed shape without altering any data.
Parameters
----------
axis : int, tuple of int, or None
Selects a subset of the single-dimensional entries in the shape.
If an axis is selected with shape entry greater than one, an error is raised.
inplace : bool, default False
If True, this method returns a **view** of this array
that shares data with this array. Otherwise, a copy is returned.
"""
new_shape = list(self.shape)
if isinstance(axis, int):
axis = [axis]
if axis:
assert len(axis) == len(set(axis)), \
"axis {} contains duplicate which is not allowed.".format(axis)
for i in axis:
assert 0 <= i < len(new_shape), "axis {} is out of range.".format(i)
assert new_shape[i] == 1, \
"Squeeze target axis {} must be size 1, got {}.".format(i, new_shape[i])
for i in sorted(axis, reverse=True):
del new_shape[i]
if not inplace:
return op.squeeze(self, axis=axis)
else:
for i in reversed(range(len(new_shape))):
if new_shape[i] == 1:
new_shape = list(self.shape)
if isinstance(axis, int):
axis = [axis]
if axis:
assert len(axis) == len(set(axis)), \
"axis {} contains duplicate which is not allowed.".format(axis)
for i in axis:
assert 0 <= i < len(new_shape), "axis {} is out of range.".format(i)
assert new_shape[i] == 1, \
"Squeeze target axis {} must be size 1, got {}.".format(i, new_shape[i])
for i in sorted(axis, reverse=True):
del new_shape[i]
if not new_shape:
new_shape.append(1)
else:
for i in reversed(range(len(new_shape))):
if new_shape[i] == 1:
del new_shape[i]
if not new_shape:
new_shape.append(1)

return self.reshape(new_shape)
return self.reshape(new_shape)

# pylint: disable= undefined-variable
def broadcast_to(self, shape):
Expand Down

0 comments on commit 81536a5

Please sign in to comment.