Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
ffi_roll/rot90
Browse files Browse the repository at this point in the history
  • Loading branch information
Tommliu committed Mar 17, 2020
1 parent 1368a08 commit b89c6c2
Show file tree
Hide file tree
Showing 9 changed files with 250 additions and 71 deletions.
2 changes: 2 additions & 0 deletions benchmark/python/ffi/benchmark_ffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ def prepare_workloads():
OpArgMngr.add_workload("cumsum", pool['3x2'], axis=0, out=pool['3x2'])
OpArgMngr.add_workload("add", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("random.uniform", low=0, high=1, size=1)
OpArgMngr.add_workload("roll", pool["2x2"], 1, axis=0)
OpArgMngr.add_workload("rot90", pool["2x2"], 2)


def benchmark_helper(f, *args, **kwargs):
Expand Down
64 changes: 0 additions & 64 deletions python/mxnet/_numpy_op_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,70 +538,6 @@ def _np_reshape(a, newshape, order='C', out=None):
"""


def _np_roll(a, shift, axis=None):
"""
Roll array elements along a given axis.
Elements that roll beyond the last position are re-introduced at
the first.
Parameters
----------
a : ndarray
Input array.
shift : int or tuple of ints
The number of places by which elements are shifted. If a tuple,
then `axis` must be a tuple of the same size, and each of the
given axes is shifted by the corresponding number. If an int
while `axis` is a tuple of ints, then the same value is used for
all given axes.
axis : int or tuple of ints, optional
Axis or axes along which elements are shifted. By default, the
array is flattened before shifting, after which the original
shape is restored.
Returns
-------
res : ndarray
Output array, with the same shape as `a`.
Notes
-----
Supports rolling over multiple dimensions simultaneously.
Examples
--------
>>> x = np.arange(10)
>>> np.roll(x, 2)
array([8., 9., 0., 1., 2., 3., 4., 5., 6., 7.])
>>> np.roll(x, -2)
array([2., 3., 4., 5., 6., 7., 8., 9., 0., 1.])
>>> x2 = np.reshape(x, (2,5))
>>> x2
array([[0., 1., 2., 3., 4.],
[5., 6., 7., 8., 9.]])
>>> np.roll(x2, 1)
array([[9., 0., 1., 2., 3.],
[4., 5., 6., 7., 8.]])
>>> np.roll(x2, -1)
array([[1., 2., 3., 4., 5.],
[6., 7., 8., 9., 0.]])
>>> np.roll(x2, 1, axis=0)
array([[5., 6., 7., 8., 9.],
[0., 1., 2., 3., 4.]])
>>> np.roll(x2, -1, axis=0)
array([[5., 6., 7., 8., 9.],
[0., 1., 2., 3., 4.]])
>>> np.roll(x2, 1, axis=1)
array([[4., 0., 1., 2., 3.],
[9., 5., 6., 7., 8.]])
>>> np.roll(x2, -1, axis=1)
array([[1., 2., 3., 4., 0.],
[6., 7., 8., 9., 5.]])
"""


def _np_trace(a, offset=0, axis1=0, axis2=1, out=None):
"""
Return the sum along diagonals of the array.
Expand Down
70 changes: 68 additions & 2 deletions python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
'diag_indices_from', 'hanning', 'hamming', 'blackman', 'flip', 'flipud', 'fliplr',
'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm',
'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'rot90', 'einsum',
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'roll', 'rot90', 'einsum',
'true_divide', 'nonzero', 'quantile', 'percentile', 'shares_memory', 'may_share_memory',
'diff', 'ediff1d', 'resize', 'polyval', 'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite',
'where', 'bincount', 'pad', 'cumsum']
Expand Down Expand Up @@ -6308,6 +6308,72 @@ def less_equal(x1, x2, out=None):
_npi.greater_equal_scalar, out)


@set_module('mxnet.ndarray.numpy')
def roll(a, shift, axis=None):
"""
Roll array elements along a given axis.
Elements that roll beyond the last position are re-introduced at
the first.
Parameters
----------
a : ndarray
Input array.
shift : int or tuple of ints
The number of places by which elements are shifted. If a tuple,
then `axis` must be a tuple of the same size, and each of the
given axes is shifted by the corresponding number. If an int
while `axis` is a tuple of ints, then the same value is used for
all given axes.
axis : int or tuple of ints, optional
Axis or axes along which elements are shifted. By default, the
array is flattened before shifting, after which the original
shape is restored.
Returns
-------
res : ndarray
Output array, with the same shape as `a`.
Notes
-----
Supports rolling over multiple dimensions simultaneously.
Examples
--------
>>> x = np.arange(10)
>>> np.roll(x, 2)
array([8., 9., 0., 1., 2., 3., 4., 5., 6., 7.])
>>> np.roll(x, -2)
array([2., 3., 4., 5., 6., 7., 8., 9., 0., 1.])
>>> x2 = np.reshape(x, (2,5))
>>> x2
array([[0., 1., 2., 3., 4.],
[5., 6., 7., 8., 9.]])
>>> np.roll(x2, 1)
array([[9., 0., 1., 2., 3.],
[4., 5., 6., 7., 8.]])
>>> np.roll(x2, -1)
array([[1., 2., 3., 4., 5.],
[6., 7., 8., 9., 0.]])
>>> np.roll(x2, 1, axis=0)
array([[5., 6., 7., 8., 9.],
[0., 1., 2., 3., 4.]])
>>> np.roll(x2, -1, axis=0)
array([[5., 6., 7., 8., 9.],
[0., 1., 2., 3., 4.]])
>>> np.roll(x2, 1, axis=1)
array([[4., 0., 1., 2., 3.],
[9., 5., 6., 7., 8.]])
>>> np.roll(x2, -1, axis=1)
array([[1., 2., 3., 4., 0.],
[6., 7., 8., 9., 5.]])
"""
return _api_internal.roll(a, shift, axis)


@set_module('mxnet.ndarray.numpy')
def rot90(m, k=1, axes=(0, 1)):
"""
Expand Down Expand Up @@ -6351,7 +6417,7 @@ def rot90(m, k=1, axes=(0, 1)):
[[5., 7.],
[4., 6.]]])
"""
return _npi.rot90(m, k=k, axes=axes)
return _api_internal.rot90(m, k, axes)


@set_module('mxnet.ndarray.numpy')
Expand Down
68 changes: 67 additions & 1 deletion python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
'flip', 'flipud', 'fliplr', 'around', 'round', 'round_', 'arctan2', 'hypot',
'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad',
'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal',
'greater', 'less', 'greater_equal', 'less_equal', 'rot90', 'einsum', 'true_divide', 'nonzero',
'greater', 'less', 'greater_equal', 'less_equal', 'roll', 'rot90', 'einsum', 'true_divide', 'nonzero',
'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff', 'ediff1d', 'resize', 'matmul',
'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite', 'polyval', 'where', 'bincount',
'pad', 'cumsum']
Expand Down Expand Up @@ -8170,6 +8170,72 @@ def less_equal(x1, x2, out=None):
return _mx_nd_np.less_equal(x1, x2, out)


@set_module('mxnet.numpy')
def roll(a, shift, axis=None):
"""
Roll array elements along a given axis.
Elements that roll beyond the last position are re-introduced at
the first.
Parameters
----------
a : ndarray
Input array.
shift : int or tuple of ints
The number of places by which elements are shifted. If a tuple,
then `axis` must be a tuple of the same size, and each of the
given axes is shifted by the corresponding number. If an int
while `axis` is a tuple of ints, then the same value is used for
all given axes.
axis : int or tuple of ints, optional
Axis or axes along which elements are shifted. By default, the
array is flattened before shifting, after which the original
shape is restored.
Returns
-------
res : ndarray
Output array, with the same shape as `a`.
Notes
-----
Supports rolling over multiple dimensions simultaneously.
Examples
--------
>>> x = np.arange(10)
>>> np.roll(x, 2)
array([8., 9., 0., 1., 2., 3., 4., 5., 6., 7.])
>>> np.roll(x, -2)
array([2., 3., 4., 5., 6., 7., 8., 9., 0., 1.])
>>> x2 = np.reshape(x, (2,5))
>>> x2
array([[0., 1., 2., 3., 4.],
[5., 6., 7., 8., 9.]])
>>> np.roll(x2, 1)
array([[9., 0., 1., 2., 3.],
[4., 5., 6., 7., 8.]])
>>> np.roll(x2, -1)
array([[1., 2., 3., 4., 5.],
[6., 7., 8., 9., 0.]])
>>> np.roll(x2, 1, axis=0)
array([[5., 6., 7., 8., 9.],
[0., 1., 2., 3., 4.]])
>>> np.roll(x2, -1, axis=0)
array([[5., 6., 7., 8., 9.],
[0., 1., 2., 3., 4.]])
>>> np.roll(x2, 1, axis=1)
array([[4., 0., 1., 2., 3.],
[9., 5., 6., 7., 8.]])
>>> np.roll(x2, -1, axis=1)
array([[1., 2., 3., 4., 0.],
[6., 7., 8., 9., 5.]])
"""
return _mx_nd_np.roll(a, shift, axis=axis)


@set_module('mxnet.numpy')
def rot90(m, k=1, axes=(0, 1)):
"""
Expand Down
37 changes: 36 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
'diag_indices_from', 'hanning', 'hamming', 'blackman', 'flip', 'flipud', 'fliplr',
'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm',
'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'rot90', 'einsum',
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'roll', 'rot90', 'einsum',
'true_divide', 'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff', 'ediff1d',
'resize', 'polyval', 'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite',
'where', 'bincount', 'pad', 'cumsum']
Expand Down Expand Up @@ -5841,6 +5841,41 @@ def less_equal(x1, x2, out=None):
_npi.greater_equal_scalar, out)


@set_module('mxnet.symbol.numpy')
def roll(a, shift, axis=None):
"""
Roll array elements along a given axis.
Elements that roll beyond the last position are re-introduced at
the first.
Parameters
----------
a : _Symbol
Input array.
shift : int or tuple of ints
The number of places by which elements are shifted. If a tuple,
then `axis` must be a tuple of the same size, and each of the
given axes is shifted by the corresponding number. If an int
while `axis` is a tuple of ints, then the same value is used for
all given axes.
axis : int or tuple of ints, optional
Axis or axes along which elements are shifted. By default, the
array is flattened before shifting, after which the original
shape is restored.
Returns
-------
res : _Symbol
Output array, with the same shape as `a`.
Notes
-----
Supports rolling over multiple dimensions simultaneously.
"""
return _npi.roll(a, shift, axis=axis)


@set_module('mxnet.symbol.numpy')
def rot90(m, k=1, axes=(0, 1)):
"""
Expand Down
56 changes: 56 additions & 0 deletions src/api/operator/numpy/np_matrix_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@
* \brief Implementation of the API of functions in src/operator/tensor/matrix_op.cc
*/
#include <mxnet/api_registry.h>
#include <mxnet/runtime/packed_func.h>
#include "../utils.h"
#include "../../../operator/tensor/matrix_op-inl.h"
#include "../../../operator/numpy/np_matrix_op-inl.h"

namespace mxnet {

Expand All @@ -46,4 +48,58 @@ MXNET_REGISTER_API("_npi.expand_dims")
*ret = ndoutputs[0];
});

MXNET_REGISTER_API("_npi.roll")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
static const nnvm::Op* op = Op::Get("_npi_roll");
nnvm::NodeAttrs attrs;
op::NumpyRollParam param;
if (args[1].type_code() == kNull) {
param.shift = dmlc::nullopt;
} else if (args[1].type_code() == kDLInt) {
param.shift = TShape(1, args[1].operator int64_t());
} else {
param.shift = TShape(args[1].operator ObjectRef());
}
if (args[2].type_code() == kNull) {
param.axis = dmlc::nullopt;
} else if (args[2].type_code() == kDLInt) {
param.axis = TShape(1, args[2].operator int64_t());
} else {
param.axis = TShape(args[2].operator ObjectRef());
}
attrs.parsed = std::move(param);
attrs.op = op;
SetAttrDict<op::NumpyRollParam>(&attrs);
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
int num_inputs = 1;
int num_outputs = 0;
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
*ret = ndoutputs[0];
});

MXNET_REGISTER_API("_npi.rot90")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
static const nnvm::Op* op = Op::Get("_npi_rot90");
nnvm::NodeAttrs attrs;
op::NumpyRot90Param param;
param.k = args[1].operator int();
if (args[2].type_code() == kNull) {
param.axes = dmlc::nullopt;
} else if (args[2].type_code() == kDLInt) {
param.axes = TShape(1, args[2].operator int64_t());
} else {
param.axes = TShape(args[2].operator ObjectRef());
}
attrs.parsed = std::move(param);
attrs.op = op;
SetAttrDict<op::NumpyRot90Param>(&attrs);
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
int num_inputs = 1;
int num_outputs = 0;
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
*ret = ndoutputs[0];
});

} // namespace mxnet
Loading

0 comments on commit b89c6c2

Please sign in to comment.