Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add inplace option
Browse files Browse the repository at this point in the history
  • Loading branch information
szha committed May 30, 2019
1 parent 192f74f commit b7c70cb
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 33 deletions.
94 changes: 61 additions & 33 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1259,8 +1259,14 @@ def sign(self, *args, **kwargs):
"""
return op.sign(self, *args, **kwargs)

def flatten(self):
"""Returns a flattened **view** of this array without altering any data.
def flatten(self, inplace=False):
"""Flatten this array without altering any data.
Parameters
----------
inplace : bool, default False
If True, this method returns a **view** of this array
that shares data with this array. Otherwise, a copy is returned.
Returns
-------
Expand All @@ -1271,7 +1277,8 @@ def flatten(self):
Examples
--------
>>> x = mx.nd.arange(30).reshape(5,2,3)
>>> y = x.flatten()
>>> y = x.flatten(inplace=True)
>>> z = x.flatten()
>>> y.shape
(5, 6)
>>> y[0].asnumpy()
Expand All @@ -1280,8 +1287,10 @@ def flatten(self):
>>> x[0].asnumpy()
array([[-1., -1., -1.],
[-1., -1., -1.]], dtype=float32)
>>> z[0].asnumpy()
array([0., 1., 2., 3., 4., 5.], dtype=float32)
"""
return self.reshape((0, -1))
return op.flatten(self) if not inplace else self.reshape((0, -1))

def shape_array(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`shape_array`.
Expand All @@ -1299,15 +1308,18 @@ def size_array(self, *args, **kwargs):
"""
return op.size_array(self, *args, **kwargs)

def expand_dims(self, axis):
"""Returns a **view** of this array with additional dimension without altering any data.
def expand_dims(self, axis, inplace=False):
"""Adds an additional dimension to the current array without altering any data.
Parameters
----------
axis : int
Position where new axis is to be inserted.
Suppose that the input NDArray's dimension is ndim,
the range of the inserted axis is [-ndim, ndim].
inplace : bool, default False
If True, this method returns a **view** of this array
that shares data with this array. Otherwise, a copy is returned.
Returns
-------
Expand All @@ -1319,7 +1331,8 @@ def expand_dims(self, axis):
Examples
--------
>>> x = mx.nd.arange(6).reshape(2,3)
>>> y = x.expand_dims(1)
>>> y = x.expand_dims(1, inplace=True)
>>> z = x.expand_dims(1)
>>> y.shape
(2, 1, 3)
>>> y[0].asnumpy()
Expand All @@ -1328,14 +1341,19 @@ def expand_dims(self, axis):
>>> x.asnumpy()
array([[-1., -1., -1.],
[-1., -1., -1.]], dtype=float32)
>>> z[0].asnumpy()
array([[0., 1., 2.]], dtype=float32)
"""
new_shape = list(self.shape)
assert -len(new_shape) <= axis <= len(new_shape), \
"axis {} is out of range for {}d array".format(axis, len(new_shape))
if axis < 0:
axis += len(new_shape)
new_shape.insert(axis, 1)
return self.reshape(new_shape)
if not inplace:
return op.expand_dims(self, axis=axis)
else:
new_shape = list(self.shape)
assert -len(new_shape)-1 <= axis <= len(new_shape), \
"axis {} is out of range for {}d array".format(axis, len(new_shape))
if axis < 0:
axis += len(new_shape) + 1
new_shape.insert(axis, 1)
return self.reshape(new_shape)

def tile(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`tile`.
Expand Down Expand Up @@ -1745,35 +1763,45 @@ def softmin(self, *args, **kwargs):
"""
return op.softmin(self, *args, **kwargs)

def squeeze(self, axis=None):
"""Returns a **view** of this array with squeezed shape without altering any data.
def squeeze(self, axis=None, inplace=False):
"""Remove dimensions with size 1 from this array without altering any data.
Parameters
----------
axis : int, tuple of int, or None
Selects a subset of the single-dimensional entries in the shape.
If an axis is selected with shape entry greater than one, an error is raised.
inplace : bool, default False
If True, this method returns a **view** of this array
that shares data with this array. Otherwise, a copy is returned.
"""
new_shape = list(self.shape)
if isinstance(axis, int):
axis = [axis]
if axis:
assert len(axis) == len(set(axis)), \
"axis {} contains duplicate which is not allowed.".format(axis)
for i in axis:
assert 0 <= i < len(new_shape), "axis {} is out of range.".format(i)
assert new_shape[i] == 1, \
"Squeeze target axis {} must be size 1, got {}.".format(i, new_shape[i])
for i in sorted(axis, reverse=True):
del new_shape[i]
if not inplace:
return op.squeeze(self, axis=axis)
else:
for i in reversed(range(len(new_shape))):
if new_shape[i] == 1:
new_shape = list(self.shape)
axes = axis # rename variable for readability
if isinstance(axes, int):
axes = [axes]
if axes:
assert len(axes) == len(set(axes)), \
"axis {} contains duplicate which is not allowed.".format(axes)
resolved_axes = [i if i >= 0 else i+len(self.shape) for i in axes]
for arg_axis, actual_axis in zip(axes, resolved_axes):
assert -len(new_shape) <= arg_axis < len(new_shape), \
"axis {} is out of range for {}d array".format(arg_axis, len(new_shape))
axis_size = new_shape[actual_axis]
assert axis_size == 1, \
"Squeeze target axis {} must be size 1, got {}.".format(arg_axis, axis_size)
for i in sorted(resolved_axes, reverse=True):
del new_shape[i]
if not new_shape:
new_shape.append(1)
else:
for i in reversed(range(len(new_shape))):
if new_shape[i] == 1:
del new_shape[i]
if not new_shape:
new_shape.append(1)

return self.reshape(new_shape)
return self.reshape(new_shape)

# pylint: disable= undefined-variable
def broadcast_to(self, shape):
Expand Down
59 changes: 59 additions & 0 deletions tests/python/unittest/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,65 @@ def test_ndarray_reshape():
assert same(tensor.reshape(-1, 0, reverse=True).asnumpy(), true_res.reshape(6, 5).asnumpy())


@with_seed()
def test_ndarray_flatten():
tensor = (mx.nd.arange(30) + 1).reshape(2, 3, 5)
copy = tensor.flatten()
ref = tensor.flatten(inplace=True)
assert same(copy.asnumpy(), tensor.reshape(2, 15).asnumpy())
assert same(ref.asnumpy(), tensor.reshape(2, 15).asnumpy())

tensor[0] = -1
assert not same(copy.asnumpy(), tensor.reshape(2, 15).asnumpy())
assert same(ref.asnumpy(), tensor.reshape(2, 15).asnumpy())


@with_seed()
def test_ndarray_squeeze():
def check_squeeze(shape, axis=None):
data = mx.random.uniform(low=-10.0, high=10.0, shape=shape)
copy = data.squeeze(axis=axis)
ref = data.squeeze(axis=axis, inplace=True)
out_expected = np.squeeze(data.asnumpy(), axis=axis)
if copy.shape == (1,): # as an exception (1, 1, 1) will be squeezed to (1,)
out_expected = np.squeeze(data.asnumpy(), axis=tuple([i for i in range(1, len(shape))]))
assert same(copy.asnumpy(), out_expected)
assert same(ref.asnumpy(), out_expected)
data[0][0] = -1
assert same(copy.asnumpy(), out_expected)
assert not same(ref.asnumpy(), out_expected)

# check forward
check_squeeze((1, 5, 1, 3, 1), 0)
check_squeeze((1, 5, 1, 3, 1), 2)
check_squeeze((1, 5, 1, 3, 1), 4)
check_squeeze((1, 5, 1, 3, 1), (0, 4))
check_squeeze((1, 5, 1, 3, 1), (0, 2, 4))
check_squeeze((1, 5, 1, 3, 1), -5)
check_squeeze((1, 5, 1, 3, 1), -3)
check_squeeze((1, 5, 1, 3, 1), -1)
check_squeeze((1, 5, 1, 3, 1), (0, 4))
check_squeeze((1, 5, 1, 3, 1), (0, 2, 4))
check_squeeze((1, 5, 1, 3, 1))
check_squeeze((1, 1, 1, 1))


@with_seed()
def test_ndarray_expand_dims():
for ndim in range(1, 6):
for axis in range(-ndim-1, ndim+1):
shape = list(np.random.randint(1, 10, size=ndim))
data = mx.random.normal(shape=shape)
copy = data.expand_dims(axis=axis)
ref = data.expand_dims(axis=axis, inplace=True)
out_expected = np.expand_dims(data.asnumpy(), axis=axis)
assert same(copy.asnumpy(), out_expected)
assert same(ref.asnumpy(), out_expected), (shape, axis, ref.asnumpy().shape, out_expected.shape)
data[0] = -1
assert same(copy.asnumpy(), out_expected)
assert not same(ref.asnumpy(), out_expected)


@with_seed()
def test_ndarray_choose():
shape = (100, 20)
Expand Down

0 comments on commit b7c70cb

Please sign in to comment.