Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
in-place reshape ops
Browse files Browse the repository at this point in the history
  • Loading branch information
szha committed Feb 19, 2019
1 parent 3a6fe22 commit 6cc0b12
Showing 1 changed file with 83 additions and 15 deletions.
98 changes: 83 additions & 15 deletions python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1261,13 +1261,29 @@ def sign(self, *args, **kwargs):
"""
return op.sign(self, *args, **kwargs)

def flatten(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`flatten`.
def flatten(self):
"""Returns a flattened **view** of this array without altering any data.
The arguments are the same as for :py:func:`flatten`, with
this array as data.
Returns
-------
NDArray
An array with flattened shape `(d1, d2*...*dk)` that shares data with
this array with shape `(d1, d2, ..., dk)`.
Examples
--------
>>> x = mx.nd.arange(30).reshape(5,2,3)
>>> y = x.flatten()
>>> y.shape
(5, 6)
>>> y[0].asnumpy()
array([0., 1., 2., 3., 4., 5.], dtype=float32)
>>> y[:] = -1
>>> x[0].asnumpy()
array([[-1., -1., -1.],
[-1., -1., -1.]], dtype=float32)
"""
return op.flatten(self, *args, **kwargs)
return self.reshape((0, -1))

def shape_array(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`shape_array`.
Expand All @@ -1285,13 +1301,43 @@ def size_array(self, *args, **kwargs):
"""
return op.size_array(self, *args, **kwargs)

def expand_dims(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`expand_dims`.
def expand_dims(self, axis):
"""Returns a **view** of this array with additional dimension without altering any data.
The arguments are the same as for :py:func:`expand_dims`, with
this array as data.
Parameters
----------
axis : int
Position where new axis is to be inserted.
Suppose that the input NDArray's dimension is ndim,
the range of the inserted axis is [-ndim, ndim].
Returns
-------
NDArray
An array with expanded shape `(d1, d2, ..., 1, di, ..., dk)`
that shares data with this array with shape `(d1, d2, ..., dk)`,
given input axis `i`.
Examples
--------
>>> x = mx.nd.arange(6).reshape(2,3)
>>> y = x.expand_dims(1)
>>> y.shape
(2, 1, 3)
>>> y[0].asnumpy()
array([[0., 1., 2.]], dtype=float32)
>>> y[:] = -1
>>> x.asnumpy()
array([[-1., -1., -1.],
[-1., -1., -1.]], dtype=float32)
"""
return op.expand_dims(self, *args, **kwargs)
new_shape = list(self.shape)
assert -len(new_shape) <= axis <= len(new_shape), \
"axis {} is out of range for {}d array".format(axis, len(new_shape))
if axis < 0:
axis += len(new_shape)
new_shape.insert(axis, 1)
return self.reshape(new_shape)

def tile(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`tile`.
Expand Down Expand Up @@ -1701,13 +1747,35 @@ def softmin(self, *args, **kwargs):
"""
return op.softmin(self, *args, **kwargs)

def squeeze(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`squeeze`.
def squeeze(self, axis=None):
"""Returns a **view** of this array with squeezed shape without altering any data.
The arguments are the same as for :py:func:`squeeze`, with
this array as data.
Parameters
----------
axis : int, tuple of int, or None
Selects a subset of the single-dimensional entries in the shape.
If an axis is selected with shape entry greater than one, an error is raised.
"""
return op.squeeze(self, *args, **kwargs)
new_shape = list(self.shape)
if isinstance(axis, int):
axis = [axis]
if axis:
assert len(axis) == len(set(axis)), \
"axis {} contains duplicate which is not allowed.".format(axis)
for i in axis:
assert 0 <= i < len(new_shape), "axis {} is out of range.".format(i)
assert new_shape[i] == 1, \
"Squeeze target axis {} must be size 1, got {}.".format(i, new_shape[i])
for i in sorted(axis, reverse=True):
del new_shape[i]
else:
for i in reversed(range(len(new_shape))):
if new_shape[i] == 1:
del new_shape[i]
if not new_shape:
new_shape.append(1)

return self.reshape(new_shape)

# pylint: disable= undefined-variable
def broadcast_to(self, shape):
Expand Down

0 comments on commit 6cc0b12

Please sign in to comment.