Skip to content

Commit 2a25b1d

Browse files
Tommliuanirudh2290
authored andcommitted
ffi for roll/rot90 (apache#17861)
1 parent 80c29dc commit 2a25b1d

File tree

9 files changed

+250
-71
lines changed

9 files changed

+250
-71
lines changed

benchmark/python/ffi/benchmark_ffi.py

+2
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ def prepare_workloads():
7878
OpArgMngr.add_workload("random.uniform", low=0, high=1, size=1)
7979
OpArgMngr.add_workload("where", pool['2x3'], pool['2x3'], pool['2x1'])
8080
OpArgMngr.add_workload("may_share_memory", pool['2x3'][:0], pool['2x3'][:1])
81+
OpArgMngr.add_workload("roll", pool["2x2"], 1, axis=0)
82+
OpArgMngr.add_workload("rot90", pool["2x2"], 2)
8183

8284

8385
def benchmark_helper(f, *args, **kwargs):

python/mxnet/_numpy_op_doc.py

-64
Original file line numberDiff line numberDiff line change
@@ -538,70 +538,6 @@ def _np_reshape(a, newshape, order='C', out=None):
538538
"""
539539

540540

541-
def _np_roll(a, shift, axis=None):
542-
"""
543-
Roll array elements along a given axis.
544-
545-
Elements that roll beyond the last position are re-introduced at
546-
the first.
547-
548-
Parameters
549-
----------
550-
a : ndarray
551-
Input array.
552-
shift : int or tuple of ints
553-
The number of places by which elements are shifted. If a tuple,
554-
then `axis` must be a tuple of the same size, and each of the
555-
given axes is shifted by the corresponding number. If an int
556-
while `axis` is a tuple of ints, then the same value is used for
557-
all given axes.
558-
axis : int or tuple of ints, optional
559-
Axis or axes along which elements are shifted. By default, the
560-
array is flattened before shifting, after which the original
561-
shape is restored.
562-
563-
Returns
564-
-------
565-
res : ndarray
566-
Output array, with the same shape as `a`.
567-
568-
Notes
569-
-----
570-
Supports rolling over multiple dimensions simultaneously.
571-
572-
Examples
573-
--------
574-
>>> x = np.arange(10)
575-
>>> np.roll(x, 2)
576-
array([8., 9., 0., 1., 2., 3., 4., 5., 6., 7.])
577-
>>> np.roll(x, -2)
578-
array([2., 3., 4., 5., 6., 7., 8., 9., 0., 1.])
579-
580-
>>> x2 = np.reshape(x, (2,5))
581-
>>> x2
582-
array([[0., 1., 2., 3., 4.],
583-
[5., 6., 7., 8., 9.]])
584-
>>> np.roll(x2, 1)
585-
array([[9., 0., 1., 2., 3.],
586-
[4., 5., 6., 7., 8.]])
587-
>>> np.roll(x2, -1)
588-
array([[1., 2., 3., 4., 5.],
589-
[6., 7., 8., 9., 0.]])
590-
>>> np.roll(x2, 1, axis=0)
591-
array([[5., 6., 7., 8., 9.],
592-
[0., 1., 2., 3., 4.]])
593-
>>> np.roll(x2, -1, axis=0)
594-
array([[5., 6., 7., 8., 9.],
595-
[0., 1., 2., 3., 4.]])
596-
>>> np.roll(x2, 1, axis=1)
597-
array([[4., 0., 1., 2., 3.],
598-
[9., 5., 6., 7., 8.]])
599-
>>> np.roll(x2, -1, axis=1)
600-
array([[1., 2., 3., 4., 0.],
601-
[6., 7., 8., 9., 5.]])
602-
"""
603-
604-
605541
def _np_trace(a, offset=0, axis1=0, axis2=1, out=None):
606542
"""
607543
Return the sum along diagonals of the array.

python/mxnet/ndarray/numpy/_op.py

+68-2
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
'diag_indices_from', 'hanning', 'hamming', 'blackman', 'flip', 'flipud', 'fliplr',
4444
'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm',
4545
'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
46-
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'rot90', 'einsum',
46+
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'roll', 'rot90', 'einsum',
4747
'true_divide', 'nonzero', 'quantile', 'percentile', 'shares_memory', 'may_share_memory',
4848
'diff', 'ediff1d', 'resize', 'polyval', 'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite',
4949
'where', 'bincount', 'pad', 'cumsum']
@@ -6307,6 +6307,72 @@ def less_equal(x1, x2, out=None):
63076307
_npi.greater_equal_scalar, out)
63086308

63096309

6310+
@set_module('mxnet.ndarray.numpy')
6311+
def roll(a, shift, axis=None):
6312+
"""
6313+
Roll array elements along a given axis.
6314+
6315+
Elements that roll beyond the last position are re-introduced at
6316+
the first.
6317+
6318+
Parameters
6319+
----------
6320+
a : ndarray
6321+
Input array.
6322+
shift : int or tuple of ints
6323+
The number of places by which elements are shifted. If a tuple,
6324+
then `axis` must be a tuple of the same size, and each of the
6325+
given axes is shifted by the corresponding number. If an int
6326+
while `axis` is a tuple of ints, then the same value is used for
6327+
all given axes.
6328+
axis : int or tuple of ints, optional
6329+
Axis or axes along which elements are shifted. By default, the
6330+
array is flattened before shifting, after which the original
6331+
shape is restored.
6332+
6333+
Returns
6334+
-------
6335+
res : ndarray
6336+
Output array, with the same shape as `a`.
6337+
6338+
Notes
6339+
-----
6340+
Supports rolling over multiple dimensions simultaneously.
6341+
6342+
Examples
6343+
--------
6344+
>>> x = np.arange(10)
6345+
>>> np.roll(x, 2)
6346+
array([8., 9., 0., 1., 2., 3., 4., 5., 6., 7.])
6347+
>>> np.roll(x, -2)
6348+
array([2., 3., 4., 5., 6., 7., 8., 9., 0., 1.])
6349+
6350+
>>> x2 = np.reshape(x, (2,5))
6351+
>>> x2
6352+
array([[0., 1., 2., 3., 4.],
6353+
[5., 6., 7., 8., 9.]])
6354+
>>> np.roll(x2, 1)
6355+
array([[9., 0., 1., 2., 3.],
6356+
[4., 5., 6., 7., 8.]])
6357+
>>> np.roll(x2, -1)
6358+
array([[1., 2., 3., 4., 5.],
6359+
[6., 7., 8., 9., 0.]])
6360+
>>> np.roll(x2, 1, axis=0)
6361+
array([[5., 6., 7., 8., 9.],
6362+
[0., 1., 2., 3., 4.]])
6363+
>>> np.roll(x2, -1, axis=0)
6364+
array([[5., 6., 7., 8., 9.],
6365+
[0., 1., 2., 3., 4.]])
6366+
>>> np.roll(x2, 1, axis=1)
6367+
array([[4., 0., 1., 2., 3.],
6368+
[9., 5., 6., 7., 8.]])
6369+
>>> np.roll(x2, -1, axis=1)
6370+
array([[1., 2., 3., 4., 0.],
6371+
[6., 7., 8., 9., 5.]])
6372+
"""
6373+
return _api_internal.roll(a, shift, axis)
6374+
6375+
63106376
@set_module('mxnet.ndarray.numpy')
63116377
def rot90(m, k=1, axes=(0, 1)):
63126378
"""
@@ -6350,7 +6416,7 @@ def rot90(m, k=1, axes=(0, 1)):
63506416
[[5., 7.],
63516417
[4., 6.]]])
63526418
"""
6353-
return _npi.rot90(m, k=k, axes=axes)
6419+
return _api_internal.rot90(m, k, axes)
63546420

63556421

63566422
@set_module('mxnet.ndarray.numpy')

python/mxnet/numpy/multiarray.py

+67-1
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@
6666
'flip', 'flipud', 'fliplr', 'around', 'round', 'round_', 'arctan2', 'hypot',
6767
'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad',
6868
'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal',
69-
'greater', 'less', 'greater_equal', 'less_equal', 'rot90', 'einsum', 'true_divide', 'nonzero',
69+
'greater', 'less', 'greater_equal', 'less_equal', 'roll', 'rot90', 'einsum', 'true_divide', 'nonzero',
7070
'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff', 'ediff1d', 'resize', 'matmul',
7171
'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite', 'polyval', 'where', 'bincount',
7272
'pad', 'cumsum']
@@ -8170,6 +8170,72 @@ def less_equal(x1, x2, out=None):
81708170
return _mx_nd_np.less_equal(x1, x2, out)
81718171

81728172

8173+
@set_module('mxnet.numpy')
8174+
def roll(a, shift, axis=None):
8175+
"""
8176+
Roll array elements along a given axis.
8177+
8178+
Elements that roll beyond the last position are re-introduced at
8179+
the first.
8180+
8181+
Parameters
8182+
----------
8183+
a : ndarray
8184+
Input array.
8185+
shift : int or tuple of ints
8186+
The number of places by which elements are shifted. If a tuple,
8187+
then `axis` must be a tuple of the same size, and each of the
8188+
given axes is shifted by the corresponding number. If an int
8189+
while `axis` is a tuple of ints, then the same value is used for
8190+
all given axes.
8191+
axis : int or tuple of ints, optional
8192+
Axis or axes along which elements are shifted. By default, the
8193+
array is flattened before shifting, after which the original
8194+
shape is restored.
8195+
8196+
Returns
8197+
-------
8198+
res : ndarray
8199+
Output array, with the same shape as `a`.
8200+
8201+
Notes
8202+
-----
8203+
Supports rolling over multiple dimensions simultaneously.
8204+
8205+
Examples
8206+
--------
8207+
>>> x = np.arange(10)
8208+
>>> np.roll(x, 2)
8209+
array([8., 9., 0., 1., 2., 3., 4., 5., 6., 7.])
8210+
>>> np.roll(x, -2)
8211+
array([2., 3., 4., 5., 6., 7., 8., 9., 0., 1.])
8212+
8213+
>>> x2 = np.reshape(x, (2,5))
8214+
>>> x2
8215+
array([[0., 1., 2., 3., 4.],
8216+
[5., 6., 7., 8., 9.]])
8217+
>>> np.roll(x2, 1)
8218+
array([[9., 0., 1., 2., 3.],
8219+
[4., 5., 6., 7., 8.]])
8220+
>>> np.roll(x2, -1)
8221+
array([[1., 2., 3., 4., 5.],
8222+
[6., 7., 8., 9., 0.]])
8223+
>>> np.roll(x2, 1, axis=0)
8224+
array([[5., 6., 7., 8., 9.],
8225+
[0., 1., 2., 3., 4.]])
8226+
>>> np.roll(x2, -1, axis=0)
8227+
array([[5., 6., 7., 8., 9.],
8228+
[0., 1., 2., 3., 4.]])
8229+
>>> np.roll(x2, 1, axis=1)
8230+
array([[4., 0., 1., 2., 3.],
8231+
[9., 5., 6., 7., 8.]])
8232+
>>> np.roll(x2, -1, axis=1)
8233+
array([[1., 2., 3., 4., 0.],
8234+
[6., 7., 8., 9., 5.]])
8235+
"""
8236+
return _mx_nd_np.roll(a, shift, axis=axis)
8237+
8238+
81738239
@set_module('mxnet.numpy')
81748240
def rot90(m, k=1, axes=(0, 1)):
81758241
"""

python/mxnet/symbol/numpy/_symbol.py

+36-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
'diag_indices_from', 'hanning', 'hamming', 'blackman', 'flip', 'flipud', 'fliplr',
4949
'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm',
5050
'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
51-
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'rot90', 'einsum',
51+
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'roll', 'rot90', 'einsum',
5252
'true_divide', 'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff', 'ediff1d',
5353
'resize', 'polyval', 'nan_to_num', 'isnan', 'isinf', 'isposinf', 'isneginf', 'isfinite',
5454
'where', 'bincount', 'pad', 'cumsum']
@@ -5841,6 +5841,41 @@ def less_equal(x1, x2, out=None):
58415841
_npi.greater_equal_scalar, out)
58425842

58435843

5844+
@set_module('mxnet.symbol.numpy')
5845+
def roll(a, shift, axis=None):
5846+
"""
5847+
Roll array elements along a given axis.
5848+
5849+
Elements that roll beyond the last position are re-introduced at
5850+
the first.
5851+
5852+
Parameters
5853+
----------
5854+
a : _Symbol
5855+
Input array.
5856+
shift : int or tuple of ints
5857+
The number of places by which elements are shifted. If a tuple,
5858+
then `axis` must be a tuple of the same size, and each of the
5859+
given axes is shifted by the corresponding number. If an int
5860+
while `axis` is a tuple of ints, then the same value is used for
5861+
all given axes.
5862+
axis : int or tuple of ints, optional
5863+
Axis or axes along which elements are shifted. By default, the
5864+
array is flattened before shifting, after which the original
5865+
shape is restored.
5866+
5867+
Returns
5868+
-------
5869+
res : _Symbol
5870+
Output array, with the same shape as `a`.
5871+
5872+
Notes
5873+
-----
5874+
Supports rolling over multiple dimensions simultaneously.
5875+
"""
5876+
return _npi.roll(a, shift, axis=axis)
5877+
5878+
58445879
@set_module('mxnet.symbol.numpy')
58455880
def rot90(m, k=1, axes=(0, 1)):
58465881
"""

src/api/operator/numpy/np_matrix_op.cc

+56
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,10 @@
2222
* \brief Implementation of the API of functions in src/operator/tensor/matrix_op.cc
2323
*/
2424
#include <mxnet/api_registry.h>
25+
#include <mxnet/runtime/packed_func.h>
2526
#include "../utils.h"
2627
#include "../../../operator/tensor/matrix_op-inl.h"
28+
#include "../../../operator/numpy/np_matrix_op-inl.h"
2729

2830
namespace mxnet {
2931

@@ -86,4 +88,58 @@ MXNET_REGISTER_API("_npi.split")
8688
*ret = ADT(0, ndarray_handles.begin(), ndarray_handles.end());
8789
});
8890

91+
MXNET_REGISTER_API("_npi.roll")
92+
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
93+
using namespace runtime;
94+
static const nnvm::Op* op = Op::Get("_npi_roll");
95+
nnvm::NodeAttrs attrs;
96+
op::NumpyRollParam param;
97+
if (args[1].type_code() == kNull) {
98+
param.shift = dmlc::nullopt;
99+
} else if (args[1].type_code() == kDLInt) {
100+
param.shift = TShape(1, args[1].operator int64_t());
101+
} else {
102+
param.shift = TShape(args[1].operator ObjectRef());
103+
}
104+
if (args[2].type_code() == kNull) {
105+
param.axis = dmlc::nullopt;
106+
} else if (args[2].type_code() == kDLInt) {
107+
param.axis = TShape(1, args[2].operator int64_t());
108+
} else {
109+
param.axis = TShape(args[2].operator ObjectRef());
110+
}
111+
attrs.parsed = std::move(param);
112+
attrs.op = op;
113+
SetAttrDict<op::NumpyRollParam>(&attrs);
114+
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
115+
int num_inputs = 1;
116+
int num_outputs = 0;
117+
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
118+
*ret = ndoutputs[0];
119+
});
120+
121+
MXNET_REGISTER_API("_npi.rot90")
122+
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
123+
using namespace runtime;
124+
static const nnvm::Op* op = Op::Get("_npi_rot90");
125+
nnvm::NodeAttrs attrs;
126+
op::NumpyRot90Param param;
127+
param.k = args[1].operator int();
128+
if (args[2].type_code() == kNull) {
129+
param.axes = dmlc::nullopt;
130+
} else if (args[2].type_code() == kDLInt) {
131+
param.axes = TShape(1, args[2].operator int64_t());
132+
} else {
133+
param.axes = TShape(args[2].operator ObjectRef());
134+
}
135+
attrs.parsed = std::move(param);
136+
attrs.op = op;
137+
SetAttrDict<op::NumpyRot90Param>(&attrs);
138+
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
139+
int num_inputs = 1;
140+
int num_outputs = 0;
141+
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
142+
*ret = ndoutputs[0];
143+
});
144+
89145
} // namespace mxnet

0 commit comments

Comments
 (0)