Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
numpy operator ravel, derive from reshape
Browse files Browse the repository at this point in the history
* it is the same as reshape(x, -1)

* register reshape with prefix _npi_

* fix format error

* edit examples in doc

* fix error in review

* add out in wrapper
  • Loading branch information
Ying committed Sep 12, 2019
1 parent 287e3b5 commit 076a24e
Show file tree
Hide file tree
Showing 5 changed files with 213 additions and 3 deletions.
54 changes: 53 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'mean',
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices']
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'ravel']


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -2432,3 +2432,55 @@ def indices(dimensions, dtype=_np.int32, ctx=None):
else:
raise ValueError("The dimensions must be sequence of ints")
# pylint: enable=redefined-outer-name


@set_module('mxnet.ndarray.numpy')
def ravel(x, out=None):
r"""
ravel(x, out=None)
Return a contiguous flattened array.
A 1-D array, containing the elements of the input, is returned. A copy is
made only if needed.
Parameters
----------
x : ndarray
Input array. The elements in `x` are read in row-major, C-style order and
packed as a 1-D array.
out : ndarray or None, optional
A location into which the result is stored. If not provided or `None`,
a freshly-allocated array is returned.
Returns
-------
y : ndarray
y is an array of the same subtype as `x`, with shape ``(x.size,)``.
Note that matrices are special cased for backward compatibility, if `x`
is a matrix, then y is a 1-D ndarray.
Notes
-----
This function differs from the original numpy.arange in the following aspects:
- Only support row-major, C-style order.
Examples
--------
It is equivalent to ``reshape(x, -1)``.
>>> x = np.array([[1, 2, 3], [4, 5, 6]])
>>> print(np.ravel(x))
[1. 2. 3. 4. 5. 6.]
>>> print(x.reshape(-1))
[1. 2. 3. 4. 5. 6.]
>>> print(np.ravel(x.T))
[1. 4. 2. 5. 3. 6.]
"""
if isinstance(x, numeric_types):
return _np.reshape(x, -1)
elif isinstance(x, NDArray):
return _npi.reshape(x, -1, out=out)
else:
raise TypeError('type {} not supported'.format(str(type(x))))
49 changes: 48 additions & 1 deletion python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative',
'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh',
'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate',
'stack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices']
'stack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'ravel']

# Return code for dispatching indexing function call
_NDARRAY_UNSUPPORTED_INDEXING = -1
Expand Down Expand Up @@ -3873,3 +3873,50 @@ def indices(dimensions, dtype=_np.int32, ctx=None):
"""
return _mx_nd_np.indices(dimensions=dimensions, dtype=dtype, ctx=ctx)
# pylint: enable=redefined-outer-name


@set_module('mxnet.numpy')
def ravel(x, out=None):
r"""
ravel(x, out=None)
Return a contiguous flattened array.
A 1-D array, containing the elements of the input, is returned. A copy is
made only if needed.
Parameters
----------
x : ndarray
Input array. The elements in `x` are read in row-major, C-style order and
packed as a 1-D array.
out : ndarray or None, optional
A location into which the result is stored. If not provided or `None`,
a freshly-allocated array is returned.
Returns
-------
y : ndarray
y is an array of the same subtype as `x`, with shape ``(x.size,)``.
Note that matrices are special cased for backward compatibility, if `x`
is a matrix, then y is a 1-D ndarray.
Notes
-----
This function differs from the original numpy.arange in the following aspects:
- Only support row-major, C-style order.
Examples
--------
It is equivalent to ``reshape(x, -1)``.
>>> x = np.array([[1, 2, 3], [4, 5, 6]])
>>> print(np.ravel(x))
[1. 2. 3. 4. 5. 6.]
>>> print(x.reshape(-1))
[1. 2. 3. 4. 5. 6.]
>>> print(np.ravel(x.T))
[1. 4. 2. 5. 3. 6.]
"""
return _mx_nd_np.ravel(x, out=out)
40 changes: 39 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'mean',
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices']
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'ravel']


def _num_outputs(sym):
Expand Down Expand Up @@ -2748,4 +2748,42 @@ def indices(dimensions, dtype=_np.int32, ctx=None):
# pylint: enable=redefined-outer-name


@set_module('mxnet.symbol.numpy')
def ravel(x, out=None):
r"""
ravel(x, out=None)
Return a contiguous flattened array.
A 1-D array, containing the elements of the input, is returned. A copy is
made only if needed.
Parameters
----------
x : ndarray
Input array. The elements in `x` are read in row-major, C-style order and
packed as a 1-D array.
out : ndarray or None, optional
A location into which the result is stored. If not provided or `None`,
a freshly-allocated array is returned.
Returns
-------
y : ndarray
y is an array of the same subtype as `x`, with shape ``(x.size,)``.
Note that matrices are special cased for backward compatibility, if `x`
is a matrix, then y is a 1-D ndarray.
Notes
-----
This function differs from the original numpy.arange in the following aspects:
- Only support row-major, C-style order.
"""
if isinstance(x, numeric_types):
return _np.reshape(x, -1)
elif isinstance(x, _Symbol):
return _npi.reshape(x, -1, out=out)
else:
raise TypeError('type {} not supported'.format(str(type(x))))


_set_np_symbol_class(_Symbol)
1 change: 1 addition & 0 deletions src/operator/numpy/np_matrix_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ bool NumpyReshapeShape(const nnvm::NodeAttrs& attrs,

NNVM_REGISTER_OP(_np_reshape)
.describe(R"code()code" ADD_FILELINE)
.add_alias("_npi_reshape")
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser(ParamParser<NumpyReshapeParam>)
Expand Down
72 changes: 72 additions & 0 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1324,6 +1324,39 @@ def hybrid_forward(self, F, a, *args):
assert same(mx_out.asnumpy(), np_out)


@with_seed()
@use_np
def test_np_ravel():
class TestRavel(HybridBlock):
def __init__(self):
super(TestRavel, self).__init__()

def hybrid_forward(self, F, a):
return F.np.ravel(a)

types = ['float64', 'float32', 'float16', 'int64', 'int32', 'int8']
for oneType in types:
for hybridize in [True, False]:
for shape in [(), (2,), (2, 2), (1, 2, 3), (3, 0), (1, 0, 2)]:
test_ravel = TestRavel()
if hybridize:
test_ravel.hybridize()
x = rand_ndarray(shape, dtype=oneType).as_np_ndarray()
x.attach_grad()
np_out = _np.ravel(x.asnumpy())
with mx.autograd.record():
mx_out = test_ravel(x)
assert mx_out.shape == np_out.shape
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
mx_out.backward()
np_backward = _np.ones(shape)
assert_almost_equal(x.grad.asnumpy(), np_backward, rtol=1e-3, atol=1e-5)

mx_out = np.ravel(x)
np_out = _np.ravel(x.asnumpy())
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)


@with_seed()
@use_np
def test_np_randint():
Expand Down Expand Up @@ -1790,6 +1823,45 @@ def hybrid_forward(self, F, x):
assert mx_out.shape == np_out.shape


@with_seed()
@use_np
def test_np_ravel():
class TestRavel(HybridBlock):
def __init__(self):
super(TestRavel, self).__init__()

def hybrid_forward(self, F, a):
return F.np.ravel(a)

types = ['float64', 'float32', 'float16', 'int64', 'int32', 'int8']
for oneType in types:
for hybridize in [True, False]:
for shape in [(),
(2,),
(2, 2),
(1, 2, 3),
(3, 0),
(1, 0, 2)
]:
test_ravel = TestRavel()
if hybridize:
test_ravel.hybridize()
x = rand_ndarray(shape, dtype=oneType).as_np_ndarray()
x.attach_grad()
np_out = _np.ravel(x.asnumpy())
with mx.autograd.record():
mx_out = test_ravel(x)
assert mx_out.shape == np_out.shape
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
mx_out.backward()
np_backward = _np.ones(shape)
assert_almost_equal(x.grad.asnumpy(), np_backward, rtol=1e-3, atol=1e-5)

mx_out = np.ravel(x)
np_out = _np.ravel(x.asnumpy())
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)


if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit 076a24e

Please sign in to comment.