Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

in-place reshape ops #14053

Merged
merged 3 commits into from
Jul 28, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 111 additions & 15 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1259,13 +1259,38 @@ def sign(self, *args, **kwargs):
"""
return op.sign(self, *args, **kwargs)

def flatten(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`flatten`.
def flatten(self, inplace=False):
"""Flatten this array without altering any data.

The arguments are the same as for :py:func:`flatten`, with
this array as data.
Parameters
----------
inplace : bool, default False
If True, this method returns a **view** of this array
that shares data with this array. Otherwise, a copy is returned.

Returns
-------
NDArray
An array with flattened shape `(d1, d2*...*dk)` that shares data with
this array with shape `(d1, d2, ..., dk)`.

Examples
--------
>>> x = mx.nd.arange(30).reshape(5,2,3)
>>> y = x.flatten(inplace=True)
>>> z = x.flatten()
>>> y.shape
(5, 6)
>>> y[0].asnumpy()
array([0., 1., 2., 3., 4., 5.], dtype=float32)
>>> y[:] = -1
>>> x[0].asnumpy()
array([[-1., -1., -1.],
[-1., -1., -1.]], dtype=float32)
>>> z[0].asnumpy()
array([0., 1., 2., 3., 4., 5.], dtype=float32)
"""
return op.flatten(self, *args, **kwargs)
return op.flatten(self) if not inplace else self.reshape((0, -1))

def shape_array(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`shape_array`.
Expand All @@ -1283,13 +1308,52 @@ def size_array(self, *args, **kwargs):
"""
return op.size_array(self, *args, **kwargs)

def expand_dims(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`expand_dims`.
def expand_dims(self, axis, inplace=False):
"""Adds an additional dimension to the current array without altering any data.

The arguments are the same as for :py:func:`expand_dims`, with
this array as data.
Parameters
----------
axis : int
Position where new axis is to be inserted.
Suppose that the input NDArray's dimension is ndim,
the range of the inserted axis is [-ndim, ndim].
inplace : bool, default False
If True, this method returns a **view** of this array
that shares data with this array. Otherwise, a copy is returned.

Returns
-------
NDArray
An array with expanded shape `(d1, d2, ..., 1, di, ..., dk)`
that shares data with this array with shape `(d1, d2, ..., dk)`,
given input axis `i`.

Examples
--------
>>> x = mx.nd.arange(6).reshape(2,3)
>>> y = x.expand_dims(1, inplace=True)
>>> z = x.expand_dims(1)
>>> y.shape
(2, 1, 3)
>>> y[0].asnumpy()
array([[0., 1., 2.]], dtype=float32)
>>> y[:] = -1
>>> x.asnumpy()
array([[-1., -1., -1.],
[-1., -1., -1.]], dtype=float32)
>>> z[0].asnumpy()
array([[0., 1., 2.]], dtype=float32)
"""
return op.expand_dims(self, *args, **kwargs)
if not inplace:
return op.expand_dims(self, axis=axis)
else:
new_shape = list(self.shape)
assert -len(new_shape)-1 <= axis <= len(new_shape), \
"axis {} is out of range for {}d array".format(axis, len(new_shape))
if axis < 0:
axis += len(new_shape) + 1
new_shape.insert(axis, 1)
return self.reshape(new_shape)

def tile(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`tile`.
Expand Down Expand Up @@ -1699,13 +1763,45 @@ def softmin(self, *args, **kwargs):
"""
return op.softmin(self, *args, **kwargs)

def squeeze(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`squeeze`.
def squeeze(self, axis=None, inplace=False):
"""Remove dimensions with size 1 from this array without altering any data.

The arguments are the same as for :py:func:`squeeze`, with
this array as data.
Parameters
----------
axis : int, tuple of int, or None
Selects a subset of the single-dimensional entries in the shape.
If an axis is selected with shape entry greater than one, an error is raised.
inplace : bool, default False
If True, this method returns a **view** of this array
that shares data with this array. Otherwise, a copy is returned.
"""
return op.squeeze(self, *args, **kwargs)
if not inplace:
return op.squeeze(self, axis=axis)
else:
new_shape = list(self.shape)
axes = axis # rename variable for readability
if isinstance(axes, int):
axes = [axes]
if axes:
assert len(axes) == len(set(axes)), \
"axis {} contains duplicate which is not allowed.".format(axes)
resolved_axes = [i if i >= 0 else i+len(self.shape) for i in axes]
for arg_axis, actual_axis in zip(axes, resolved_axes):
assert -len(new_shape) <= arg_axis < len(new_shape), \
"axis {} is out of range for {}d array".format(arg_axis, len(new_shape))
axis_size = new_shape[actual_axis]
assert axis_size == 1, \
"Squeeze target axis {} must be size 1, got {}.".format(arg_axis, axis_size)
for i in sorted(resolved_axes, reverse=True):
del new_shape[i]
else:
for i in reversed(range(len(new_shape))):
if new_shape[i] == 1:
del new_shape[i]
if not new_shape:
new_shape.append(1)

return self.reshape(new_shape)

# pylint: disable= undefined-variable
def broadcast_to(self, shape):
Expand Down
12 changes: 6 additions & 6 deletions python/mxnet/symbol/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -2083,13 +2083,13 @@ def sign(self, *args, **kwargs):
"""
return op.sign(self, *args, **kwargs)

def flatten(self, *args, **kwargs):
def flatten(self, inplace=False, **kwargs): # pylint: disable=unused-argument
"""Convenience fluent method for :py:func:`flatten`.

The arguments are the same as for :py:func:`flatten`, with
this array as data.
"""
return op.flatten(self, *args, **kwargs)
return op.flatten(self, **kwargs)

def shape_array(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`shape_array`.
Expand All @@ -2107,13 +2107,13 @@ def size_array(self, *args, **kwargs):
"""
return op.size_array(self, *args, **kwargs)

def expand_dims(self, *args, **kwargs):
def expand_dims(self, axis, inplace=False, **kwargs): # pylint: disable=unused-argument
"""Convenience fluent method for :py:func:`expand_dims`.

The arguments are the same as for :py:func:`expand_dims`, with
this array as data.
"""
return op.expand_dims(self, *args, **kwargs)
return op.expand_dims(self, axis=axis, **kwargs)

def broadcast_to(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`broadcast_to`.
Expand Down Expand Up @@ -2539,13 +2539,13 @@ def softmin(self, *args, **kwargs):
"""
return op.softmin(self, *args, **kwargs)

def squeeze(self, *args, **kwargs):
def squeeze(self, axis, inplace=False, **kwargs): # pylint: disable=unused-argument
"""Convenience fluent method for :py:func:`squeeze`.

The arguments are the same as for :py:func:`squeeze`, with
this array as data.
"""
return op.squeeze(self, *args, **kwargs)
return op.squeeze(self, axis=axis, **kwargs)

def get_backend_symbol(self, backend):
"""Return symbol for target backend.
Expand Down
59 changes: 59 additions & 0 deletions tests/python/unittest/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,65 @@ def test_ndarray_reshape():
assert same(tensor.reshape(-1, 0, reverse=True).asnumpy(), true_res.reshape(6, 5).asnumpy())


@with_seed()
def test_ndarray_flatten():
tensor = (mx.nd.arange(30) + 1).reshape(2, 3, 5)
copy = tensor.flatten()
ref = tensor.flatten(inplace=True)
assert same(copy.asnumpy(), tensor.reshape(2, 15).asnumpy())
assert same(ref.asnumpy(), tensor.reshape(2, 15).asnumpy())

tensor[0] = -1
assert not same(copy.asnumpy(), tensor.reshape(2, 15).asnumpy())
assert same(ref.asnumpy(), tensor.reshape(2, 15).asnumpy())


@with_seed()
def test_ndarray_squeeze():
def check_squeeze(shape, axis=None):
data = mx.random.uniform(low=-10.0, high=10.0, shape=shape)
copy = data.squeeze(axis=axis)
ref = data.squeeze(axis=axis, inplace=True)
out_expected = np.squeeze(data.asnumpy(), axis=axis)
if copy.shape == (1,): # as an exception (1, 1, 1) will be squeezed to (1,)
out_expected = np.squeeze(data.asnumpy(), axis=tuple([i for i in range(1, len(shape))]))
assert same(copy.asnumpy(), out_expected)
assert same(ref.asnumpy(), out_expected)
data[0][0] = -1
assert same(copy.asnumpy(), out_expected)
assert not same(ref.asnumpy(), out_expected)

# check forward
check_squeeze((1, 5, 1, 3, 1), 0)
check_squeeze((1, 5, 1, 3, 1), 2)
check_squeeze((1, 5, 1, 3, 1), 4)
check_squeeze((1, 5, 1, 3, 1), (0, 4))
check_squeeze((1, 5, 1, 3, 1), (0, 2, 4))
check_squeeze((1, 5, 1, 3, 1), -5)
check_squeeze((1, 5, 1, 3, 1), -3)
check_squeeze((1, 5, 1, 3, 1), -1)
check_squeeze((1, 5, 1, 3, 1), (0, 4))
check_squeeze((1, 5, 1, 3, 1), (0, 2, 4))
check_squeeze((1, 5, 1, 3, 1))
check_squeeze((1, 1, 1, 1))


@with_seed()
def test_ndarray_expand_dims():
for ndim in range(1, 6):
for axis in range(-ndim-1, ndim+1):
shape = list(np.random.randint(1, 10, size=ndim))
data = mx.random.normal(shape=shape)
copy = data.expand_dims(axis=axis)
ref = data.expand_dims(axis=axis, inplace=True)
out_expected = np.expand_dims(data.asnumpy(), axis=axis)
assert same(copy.asnumpy(), out_expected)
assert same(ref.asnumpy(), out_expected), (shape, axis, ref.asnumpy().shape, out_expected.shape)
data[0] = -1
assert same(copy.asnumpy(), out_expected)
assert not same(ref.asnumpy(), out_expected)


@with_seed()
def test_ndarray_choose():
shape = (100, 20)
Expand Down