Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
in-place reshape ops (#14053)
Browse files Browse the repository at this point in the history
* in-place reshape ops

* add inplace option

* add dummy arguments to symbol
  • Loading branch information
szha authored Jul 28, 2019
1 parent 08fd98d commit 3ececb3
Show file tree
Hide file tree
Showing 3 changed files with 176 additions and 21 deletions.
126 changes: 111 additions & 15 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1259,13 +1259,38 @@ def sign(self, *args, **kwargs):
"""
return op.sign(self, *args, **kwargs)

def flatten(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`flatten`.
def flatten(self, inplace=False):
"""Flatten this array without altering any data.
The arguments are the same as for :py:func:`flatten`, with
this array as data.
Parameters
----------
inplace : bool, default False
If True, this method returns a **view** of this array
that shares data with this array. Otherwise, a copy is returned.
Returns
-------
NDArray
An array with flattened shape `(d1, d2*...*dk)` that shares data with
this array with shape `(d1, d2, ..., dk)`.
Examples
--------
>>> x = mx.nd.arange(30).reshape(5,2,3)
>>> y = x.flatten(inplace=True)
>>> z = x.flatten()
>>> y.shape
(5, 6)
>>> y[0].asnumpy()
array([0., 1., 2., 3., 4., 5.], dtype=float32)
>>> y[:] = -1
>>> x[0].asnumpy()
array([[-1., -1., -1.],
[-1., -1., -1.]], dtype=float32)
>>> z[0].asnumpy()
array([0., 1., 2., 3., 4., 5.], dtype=float32)
"""
return op.flatten(self, *args, **kwargs)
return op.flatten(self) if not inplace else self.reshape((0, -1))

def shape_array(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`shape_array`.
Expand All @@ -1283,13 +1308,52 @@ def size_array(self, *args, **kwargs):
"""
return op.size_array(self, *args, **kwargs)

def expand_dims(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`expand_dims`.
def expand_dims(self, axis, inplace=False):
"""Adds an additional dimension to the current array without altering any data.
The arguments are the same as for :py:func:`expand_dims`, with
this array as data.
Parameters
----------
axis : int
Position where new axis is to be inserted.
Suppose that the input NDArray's dimension is ndim,
the range of the inserted axis is [-ndim, ndim].
inplace : bool, default False
If True, this method returns a **view** of this array
that shares data with this array. Otherwise, a copy is returned.
Returns
-------
NDArray
An array with expanded shape `(d1, d2, ..., 1, di, ..., dk)`
that shares data with this array with shape `(d1, d2, ..., dk)`,
given input axis `i`.
Examples
--------
>>> x = mx.nd.arange(6).reshape(2,3)
>>> y = x.expand_dims(1, inplace=True)
>>> z = x.expand_dims(1)
>>> y.shape
(2, 1, 3)
>>> y[0].asnumpy()
array([[0., 1., 2.]], dtype=float32)
>>> y[:] = -1
>>> x.asnumpy()
array([[-1., -1., -1.],
[-1., -1., -1.]], dtype=float32)
>>> z[0].asnumpy()
array([[0., 1., 2.]], dtype=float32)
"""
return op.expand_dims(self, *args, **kwargs)
if not inplace:
return op.expand_dims(self, axis=axis)
else:
new_shape = list(self.shape)
assert -len(new_shape)-1 <= axis <= len(new_shape), \
"axis {} is out of range for {}d array".format(axis, len(new_shape))
if axis < 0:
axis += len(new_shape) + 1
new_shape.insert(axis, 1)
return self.reshape(new_shape)

def tile(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`tile`.
Expand Down Expand Up @@ -1699,13 +1763,45 @@ def softmin(self, *args, **kwargs):
"""
return op.softmin(self, *args, **kwargs)

def squeeze(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`squeeze`.
def squeeze(self, axis=None, inplace=False):
"""Remove dimensions with size 1 from this array without altering any data.
The arguments are the same as for :py:func:`squeeze`, with
this array as data.
Parameters
----------
axis : int, tuple of int, or None
Selects a subset of the single-dimensional entries in the shape.
If an axis is selected with shape entry greater than one, an error is raised.
inplace : bool, default False
If True, this method returns a **view** of this array
that shares data with this array. Otherwise, a copy is returned.
"""
return op.squeeze(self, *args, **kwargs)
if not inplace:
return op.squeeze(self, axis=axis)
else:
new_shape = list(self.shape)
axes = axis # rename variable for readability
if isinstance(axes, int):
axes = [axes]
if axes:
assert len(axes) == len(set(axes)), \
"axis {} contains duplicate which is not allowed.".format(axes)
resolved_axes = [i if i >= 0 else i+len(self.shape) for i in axes]
for arg_axis, actual_axis in zip(axes, resolved_axes):
assert -len(new_shape) <= arg_axis < len(new_shape), \
"axis {} is out of range for {}d array".format(arg_axis, len(new_shape))
axis_size = new_shape[actual_axis]
assert axis_size == 1, \
"Squeeze target axis {} must be size 1, got {}.".format(arg_axis, axis_size)
for i in sorted(resolved_axes, reverse=True):
del new_shape[i]
else:
for i in reversed(range(len(new_shape))):
if new_shape[i] == 1:
del new_shape[i]
if not new_shape:
new_shape.append(1)

return self.reshape(new_shape)

# pylint: disable= undefined-variable
def broadcast_to(self, shape):
Expand Down
12 changes: 6 additions & 6 deletions python/mxnet/symbol/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -2083,13 +2083,13 @@ def sign(self, *args, **kwargs):
"""
return op.sign(self, *args, **kwargs)

def flatten(self, *args, **kwargs):
def flatten(self, inplace=False, **kwargs): # pylint: disable=unused-argument
"""Convenience fluent method for :py:func:`flatten`.
The arguments are the same as for :py:func:`flatten`, with
this array as data.
"""
return op.flatten(self, *args, **kwargs)
return op.flatten(self, **kwargs)

def shape_array(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`shape_array`.
Expand All @@ -2107,13 +2107,13 @@ def size_array(self, *args, **kwargs):
"""
return op.size_array(self, *args, **kwargs)

def expand_dims(self, *args, **kwargs):
def expand_dims(self, axis, inplace=False, **kwargs): # pylint: disable=unused-argument
"""Convenience fluent method for :py:func:`expand_dims`.
The arguments are the same as for :py:func:`expand_dims`, with
this array as data.
"""
return op.expand_dims(self, *args, **kwargs)
return op.expand_dims(self, axis=axis, **kwargs)

def broadcast_to(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`broadcast_to`.
Expand Down Expand Up @@ -2539,13 +2539,13 @@ def softmin(self, *args, **kwargs):
"""
return op.softmin(self, *args, **kwargs)

def squeeze(self, *args, **kwargs):
def squeeze(self, axis, inplace=False, **kwargs): # pylint: disable=unused-argument
"""Convenience fluent method for :py:func:`squeeze`.
The arguments are the same as for :py:func:`squeeze`, with
this array as data.
"""
return op.squeeze(self, *args, **kwargs)
return op.squeeze(self, axis=axis, **kwargs)

def get_backend_symbol(self, backend):
"""Return symbol for target backend.
Expand Down
59 changes: 59 additions & 0 deletions tests/python/unittest/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,65 @@ def test_ndarray_reshape():
assert same(tensor.reshape(-1, 0, reverse=True).asnumpy(), true_res.reshape(6, 5).asnumpy())


@with_seed()
def test_ndarray_flatten():
tensor = (mx.nd.arange(30) + 1).reshape(2, 3, 5)
copy = tensor.flatten()
ref = tensor.flatten(inplace=True)
assert same(copy.asnumpy(), tensor.reshape(2, 15).asnumpy())
assert same(ref.asnumpy(), tensor.reshape(2, 15).asnumpy())

tensor[0] = -1
assert not same(copy.asnumpy(), tensor.reshape(2, 15).asnumpy())
assert same(ref.asnumpy(), tensor.reshape(2, 15).asnumpy())


@with_seed()
def test_ndarray_squeeze():
def check_squeeze(shape, axis=None):
data = mx.random.uniform(low=-10.0, high=10.0, shape=shape)
copy = data.squeeze(axis=axis)
ref = data.squeeze(axis=axis, inplace=True)
out_expected = np.squeeze(data.asnumpy(), axis=axis)
if copy.shape == (1,): # as an exception (1, 1, 1) will be squeezed to (1,)
out_expected = np.squeeze(data.asnumpy(), axis=tuple([i for i in range(1, len(shape))]))
assert same(copy.asnumpy(), out_expected)
assert same(ref.asnumpy(), out_expected)
data[0][0] = -1
assert same(copy.asnumpy(), out_expected)
assert not same(ref.asnumpy(), out_expected)

# check forward
check_squeeze((1, 5, 1, 3, 1), 0)
check_squeeze((1, 5, 1, 3, 1), 2)
check_squeeze((1, 5, 1, 3, 1), 4)
check_squeeze((1, 5, 1, 3, 1), (0, 4))
check_squeeze((1, 5, 1, 3, 1), (0, 2, 4))
check_squeeze((1, 5, 1, 3, 1), -5)
check_squeeze((1, 5, 1, 3, 1), -3)
check_squeeze((1, 5, 1, 3, 1), -1)
check_squeeze((1, 5, 1, 3, 1), (0, 4))
check_squeeze((1, 5, 1, 3, 1), (0, 2, 4))
check_squeeze((1, 5, 1, 3, 1))
check_squeeze((1, 1, 1, 1))


@with_seed()
def test_ndarray_expand_dims():
for ndim in range(1, 6):
for axis in range(-ndim-1, ndim+1):
shape = list(np.random.randint(1, 10, size=ndim))
data = mx.random.normal(shape=shape)
copy = data.expand_dims(axis=axis)
ref = data.expand_dims(axis=axis, inplace=True)
out_expected = np.expand_dims(data.asnumpy(), axis=axis)
assert same(copy.asnumpy(), out_expected)
assert same(ref.asnumpy(), out_expected), (shape, axis, ref.asnumpy().shape, out_expected.shape)
data[0] = -1
assert same(copy.asnumpy(), out_expected)
assert not same(ref.asnumpy(), out_expected)


@with_seed()
def test_ndarray_choose():
shape = (100, 20)
Expand Down

0 comments on commit 3ececb3

Please sign in to comment.