From 4da14a22385622c35e9a5c9c3e8a17c07f718cad Mon Sep 17 00:00:00 2001 From: Yiyan66 <57363390+Yiyan66@users.noreply.github.com> Date: Fri, 22 Nov 2019 02:41:29 +0800 Subject: [PATCH] add op bitwise_or [numpy] (#16801) * solve conflict * change test * add protocol * less blank line * update submodule * update 3rd party --- python/mxnet/ndarray/numpy/_op.py | 47 +++++++++++++++++-- python/mxnet/numpy/multiarray.py | 46 ++++++++++++++++-- python/mxnet/numpy_dispatch_protocol.py | 1 + python/mxnet/symbol/numpy/_symbol.py | 29 ++++++++++-- src/operator/mshadow_op.h | 2 + .../numpy/np_elemwise_broadcast_op.cc | 1 + .../np_elemwise_broadcast_op_extended.cc | 35 ++++++++++++++ .../np_elemwise_broadcast_op_extended.cu | 6 +++ src/operator/operator_tune.cc | 1 + .../unittest/test_numpy_interoperability.py | 11 +++++ tests/python/unittest/test_numpy_op.py | 1 + 11 files changed, 168 insertions(+), 12 deletions(-) diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 3cc5b85c8384..d31681113933 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -36,11 +36,10 @@ 'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate', 'append', 'stack', 'vstack', 'column_stack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', - 'around', 'hypot', 'bitwise_xor', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', - 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', - 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory', 'may_share_memory', 'diff', 'resize', - 'nan_to_num', 'where'] - + 'around', 'hypot', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', + 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', + 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory', 'may_share_memory', + 'diff', 'resize', 'nan_to_num', 'where'] @set_module('mxnet.ndarray.numpy') def shape(a): @@ -4365,6 +4364,44 @@ def bitwise_xor(x1, x2, out=None, **kwargs): return _ufunc_helper(x1, x2, _npi.bitwise_xor, _np.bitwise_xor, _npi.bitwise_xor_scalar, None, out) +@set_module('mxnet.ndarray.numpy') +@wrap_np_binary_func +def bitwise_or(x1, x2, out=None, **kwargs): + r""" + Compute the bit-wise OR of two arrays element-wise. + + Parameters + ---------- + x1, x2 : ndarray or scalar + Only integer and boolean types are handled. If x1.shape != x2.shape, + they must be broadcastable to a common shape (which becomes the shape of the output). + out : ndarray, optional + A location into which the result is stored. If provided, it must have a shape that the + inputs broadcast to. If not provided or None, a freshly-allocated array is returned. + + Returns + ------- + out : ndarray + Result. + + Examples + -------- + >>> np.bitwise_or(13, 17) + 29 + + >>> np.bitwise_or(31, 5) + 31 + >>> np.bitwise_or(np.array([31,3], dtype='int32'), 5) + array([31, 7]) + + >>> np.bitwise_or(np.array([31,3], dtype='int32'), np.array([5,6], dtype='int32')) + array([31, 7]) + >>> np.bitwise_or(np.array([True, True], dtype='bool'), np.array([False, True], dtype='bool')) + array([ True, True]) + """ + return _ufunc_helper(x1, x2, _npi.bitwise_or, _np.bitwise_or, _npi.bitwise_or_scalar, None, out) + + @set_module('mxnet.ndarray.numpy') @wrap_np_binary_func def ldexp(x1, x2, out=None, **kwargs): diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index e94d4c8341b4..3ad9254eb0e6 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -54,10 +54,10 @@ 'tensordot', 'histogram', 'eye', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate', 'stack', 'vstack', 'column_stack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', - 'blackman', 'flip', 'around', 'arctan2', 'hypot', 'bitwise_xor', 'rad2deg', 'deg2rad', 'unique', 'lcm', - 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', - 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory', - 'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where'] + 'blackman', 'flip', 'around', 'arctan2', 'hypot', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', + 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', + 'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', + 'shares_memory', 'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where'] # Return code for dispatching indexing function call _NDARRAY_UNSUPPORTED_INDEXING = -1 @@ -6277,6 +6277,44 @@ def bitwise_xor(x1, x2, out=None, **kwargs): return _mx_nd_np.bitwise_xor(x1, x2, out=out) +@set_module('mxnet.numpy') +@wrap_np_binary_func +def bitwise_or(x1, x2, out=None, **kwargs): + r""" + Compute the bit-wise OR of two arrays element-wise. + + Parameters + ---------- + x1, x2 : ndarray or scalar + Only integer and boolean types are handled. If x1.shape != x2.shape, + they must be broadcastable to a common shape (which becomes the shape of the output). + out : ndarray, optional + A location into which the result is stored. If provided, it must have a shape that the + inputs broadcast to. If not provided or None, a freshly-allocated array is returned. + + Returns + ------- + out : ndarray + Result. + + Examples + -------- + >>> np.bitwise_or(13, 17) + 29 + + >>> np.bitwise_or(31, 5) + 31 + >>> np.bitwise_or(np.array([31,3], dtype=np.int32), 5) + array([31, 7]) + + >>> np.bitwise_or(np.array([31,3], dtype='int32'), np.array([5,6], dtype='int32')) + array([31, 7]) + >>> np.bitwise_or(np.array([True, True], dtype='bool'), np.array([False, True], dtype='bool')) + array([ True, True]) + """ + return _mx_nd_np.bitwise_or(x1, x2, out=out) + + @set_module('mxnet.numpy') @wrap_np_binary_func def ldexp(x1, x2, out=None, **kwargs): diff --git a/python/mxnet/numpy_dispatch_protocol.py b/python/mxnet/numpy_dispatch_protocol.py index 67af2724503e..13e186f4008a 100644 --- a/python/mxnet/numpy_dispatch_protocol.py +++ b/python/mxnet/numpy_dispatch_protocol.py @@ -230,6 +230,7 @@ def _register_array_function(): 'trunc', 'floor', 'bitwise_xor', + 'bitwise_or', 'logical_not', 'equal', 'not_equal', diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index 7da771966f1f..734242f4e86d 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -38,12 +38,11 @@ 'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate', 'append', 'stack', 'vstack', 'column_stack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', - 'around', 'hypot', 'bitwise_xor', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', - 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', + 'around', 'hypot', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', + 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'shares_memory', 'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where'] - def _num_outputs(sym): return len(sym.as_nd_ndarray()) @@ -4103,6 +4102,30 @@ def bitwise_xor(x1, x2, out=None, **kwargs): return _ufunc_helper(x1, x2, _npi.bitwise_xor, _np.bitwise_xor, _npi.bitwise_xor_scalar, None, out) +@set_module('mxnet.symbol.numpy') +@wrap_np_binary_func +def bitwise_or(x1, x2, out=None, **kwargs): + r""" + Compute the bit-wise OR of two arrays element-wise. + + Parameters + ---------- + x1, x2 : _Symbol or scalar + Only integer and boolean types are handled. If x1.shape != x2.shape, + they must be broadcastable to a common shape (which becomes the shape of the output). + out : _Symbol or None, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or `None`, + a freshly-allocated array is returned. + + Returns + ------- + out : _Symbol or scalar + Result. + """ + return _ufunc_helper(x1, x2, _npi.bitwise_or, _np.bitwise_or, _npi.bitwise_or_scalar, None, out) + + @set_module('mxnet.symbol.numpy') def unique(ar, return_index=False, return_inverse=False, return_counts=False, axis=None): """ diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h index 4ae587188d1b..a0a424b6209b 100644 --- a/src/operator/mshadow_op.h +++ b/src/operator/mshadow_op.h @@ -562,6 +562,8 @@ MXNET_BINARY_MATH_OP(logical_xor, (a || b) && !(a && b) ? DType(1) : DType(0)); MXNET_BINARY_MATH_OP(bitwise_xor, static_cast(a) ^ static_cast(b)); +MXNET_BINARY_MATH_OP(bitwise_or, static_cast(a) | static_cast(b)); + MXNET_UNARY_MATH_OP(square_root, math::sqrt(a)); MXNET_UNARY_MATH_OP(square_root_grad, 0.5f / math::id(a)); diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cc b/src/operator/numpy/np_elemwise_broadcast_op.cc index f2adfc125d02..fbe04eeb4eff 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.cc +++ b/src/operator/numpy/np_elemwise_broadcast_op.cc @@ -161,6 +161,7 @@ MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_mod) MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_power) .set_attr("FCompute", BinaryBroadcastCompute) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_broadcast_power"}); + MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_add_scalar) .set_attr("FCompute", BinaryScalarOp::Compute) .set_attr("FGradient", ElemwiseGradUseNone{"_copy"}); diff --git a/src/operator/numpy/np_elemwise_broadcast_op_extended.cc b/src/operator/numpy/np_elemwise_broadcast_op_extended.cc index 84c47e597883..34ff20e733e5 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op_extended.cc +++ b/src/operator/numpy/np_elemwise_broadcast_op_extended.cc @@ -118,6 +118,24 @@ NNVM_REGISTER_OP(_npi_bitwise_xor) .add_argument("lhs", "NDArray-or-Symbol", "First input to the function") .add_argument("rhs", "NDArray-or-Symbol", "Second input to the function"); +NNVM_REGISTER_OP(_npi_bitwise_or) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr("FListInputNames", +[](const NodeAttrs& attrs) { + return std::vector{"lhs", "rhs"}; +}) +.set_attr("FInferShape", BinaryBroadcastShape) +.set_attr("FInferType", ElemwiseIntType<2, 1>) +.set_attr("FInplaceOption", +[](const NodeAttrs& attrs){ + return std::vector >{{0, 0}, {1, 0}}; +}) +.set_attr("FGradient", MakeZeroGradNodes) +.set_attr("FCompute", BinaryBroadcastIntCompute) +.add_argument("lhs", "NDArray-or-Symbol", "First input to the function") +.add_argument("rhs", "NDArray-or-Symbol", "Second input to the function"); + NNVM_REGISTER_OP(_npi_bitwise_xor_scalar) .set_num_inputs(1) .set_num_outputs(1) @@ -135,6 +153,23 @@ NNVM_REGISTER_OP(_npi_bitwise_xor_scalar) .add_argument("scalar", "int", "scalar input") .set_attr("FCompute", BinaryScalarOp::ComputeInt); +NNVM_REGISTER_OP(_npi_bitwise_or_scalar) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr_parser([](NodeAttrs* attrs) { + attrs->parsed = std::stod(attrs->dict["scalar"]); + }) +.set_attr("FInferShape", ElemwiseShape<1, 1>) +.set_attr("FInferType", ElemwiseIntType<1, 1>) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs){ + return std::vector >{{0, 0}}; + }) +.set_attr("FGradient", MakeZeroGradNodes) +.add_argument("data", "NDArray-or-Symbol", "source input") +.add_argument("scalar", "int", "scalar input") +.set_attr("FCompute", BinaryScalarOp::ComputeInt); + MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_copysign_scalar) .set_attr("FCompute", BinaryScalarOp::Compute) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_copysign_scalar"}); diff --git a/src/operator/numpy/np_elemwise_broadcast_op_extended.cu b/src/operator/numpy/np_elemwise_broadcast_op_extended.cu index f858fb4a4e79..90f11b1cd93a 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op_extended.cu +++ b/src/operator/numpy/np_elemwise_broadcast_op_extended.cu @@ -37,6 +37,9 @@ NNVM_REGISTER_OP(_npi_lcm) NNVM_REGISTER_OP(_npi_bitwise_xor) .set_attr("FCompute", BinaryBroadcastIntCompute); +NNVM_REGISTER_OP(_npi_bitwise_or) +.set_attr("FCompute", BinaryBroadcastIntCompute); + NNVM_REGISTER_OP(_backward_npi_copysign) .set_attr("FCompute", BinaryBroadcastBackwardUseIn); @@ -85,6 +88,9 @@ NNVM_REGISTER_OP(_npi_lcm_scalar) NNVM_REGISTER_OP(_npi_bitwise_xor_scalar) .set_attr("FCompute", BinaryScalarOp::ComputeInt); +NNVM_REGISTER_OP(_npi_bitwise_or_scalar) +.set_attr("FCompute", BinaryScalarOp::ComputeInt); + NNVM_REGISTER_OP(_npi_ldexp) .set_attr("FCompute", BinaryBroadcastCompute); diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc index e2a4c8af3099..249da3d049c1 100644 --- a/src/operator/operator_tune.cc +++ b/src/operator/operator_tune.cc @@ -397,6 +397,7 @@ IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::logical_or); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::logical_xor); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::logical_xor); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::bitwise_xor); // NOLINT() +IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::bitwise_or); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::smooth_l1_loss); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::smooth_l1_gradient); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::lcm); // NOLINT() diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py index 088306ff01c8..87c48ab5fd79 100644 --- a/tests/python/unittest/test_numpy_interoperability.py +++ b/tests/python/unittest/test_numpy_interoperability.py @@ -845,6 +845,16 @@ def _add_workload_lcm(): OpArgMngr.add_workload('lcm', np.array([12, 120], dtype=np.uint8), np.array([20, 200], dtype=np.uint8)) OpArgMngr.add_workload('lcm', np.array(195225786*2, dtype=np.int32), np.array(195225786*5, dtype=np.int32)) +def _add_workload_bitwise_or(): + OpArgMngr.add_workload('bitwise_or', np.array([False, False, True, True], dtype=np.bool), + np.array([False, True, False, True], dtype=np.bool)) + for dtype in [np.int8, np.int32, np.int64]: + zeros = np.array([0], dtype=dtype) + ones = np.array([-1], dtype=dtype) + OpArgMngr.add_workload('bitwise_or', zeros, zeros) + OpArgMngr.add_workload('bitwise_or', ones, zeros) + OpArgMngr.add_workload('bitwise_or', zeros, ones) + OpArgMngr.add_workload('bitwise_or', ones, ones) def _add_workload_bitwise_xor(): OpArgMngr.add_workload('bitwise_xor', np.array([False, False, True, True], dtype=np.bool), @@ -1337,6 +1347,7 @@ def _prepare_workloads(): _add_workload_hypot() _add_workload_lcm() _add_workload_bitwise_xor() + _add_workload_bitwise_or() _add_workload_ldexp() _add_workload_subtract(array_pool) _add_workload_multiply(array_pool) diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 7dd165b5421f..dc99fc6fc251 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -1671,6 +1671,7 @@ def hybrid_forward(self, F, a, b, *args, **kwargs): [lambda y, x1, x2: _np.power(x1, x2) * _np.log(x1)]), 'lcm': (-100, 100, [None], None, [[_np.int32]]), 'bitwise_xor': (-100, 100, [None], None, [[_np.int32]]), + 'bitwise_or': (-100, 100, [None], None, [[_np.int32]]), 'maximum': (-1, 1, [lambda y, x1, x2: _np.ones(y.shape) * (x1 >= x2)], [lambda y, x1, x2: _np.ones(y.shape) * (x1 < x2)]), 'minimum': (-1, 1, [lambda y, x1, x2: _np.ones(y.shape) * (x1 <= x2)],