diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index f9164855dfe9..e1d62f8d96d5 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -39,7 +39,7 @@ 'around', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory', 'may_share_memory', 'diff', 'resize', - 'nan_to_num'] + 'nan_to_num','bitwise_or'] @set_module('mxnet.ndarray.numpy') @@ -1148,6 +1148,55 @@ def lcm(x1, x2, out=None, **kwargs): array([ 0, 20, 20, 60, 20, 20], dtype=int64) """ return _ufunc_helper(x1, x2, _npi.lcm, _np.lcm, _npi.lcm_scalar, None, out) + +@set_module('mxnet.ndarray.numpy') +@wrap_np_binary_func +def bitwise_or(x1, x2, out=None, **kwargs): + """ + Returns the bit-wise OR of two arrays element-wise or the bit-wise OR of + the underlying binary representation of the integers in the input arrays + + Parameters + ---------- + *x : array_like + Input arrays. + out : ndarray, None, or tuple of ndarray and None, optional + Alternate array object(s) in which to put the result; if provided, it + must have a shape that the inputs broadcast to. A tuple of arrays + (possible only as a keyword argument) must have length equal to the + number of outputs; use `None` for uninitialized outputs to be + allocated by the ufunc. + where : array_like, optional + This condition is broadcast over the input. At locations where the + condition is True, the `out` array will be set to the ufunc result. + Elsewhere, the `out` array will retain its original value. + Note that if an uninitialized `out` array is created via the default + ``out=None``, locations within it where the condition is False will + remain uninitialized. + **kwargs + For other keyword-only arguments, see the :ref:`ufunc docs `. + + Returns + ------- + r : ndarray or tuple of ndarray + `r` will have the shape that the arrays in `x` broadcast to; if `out` is + provided, it will be returned. If not, `r` will be allocated and + may contain uninitialized values. If the function has more than one + output, then the result will be a tuple of arrays. + + See Also + -------- + logical_or, bitwise_and, bitwise_xor + binary_repr: Return the binary representation of the input number as a string. + + Examples + -------- + >>> np.bitwise_or(13, 16) + 29 + >>> np.bitwise_or([33, 4], 1) + array([33, 5]) + """ + return _ufunc_helper(x1, x2, _npi.bitwise_or, _np.bitwise_or, _npi.bitwise_or_scalar, None, out) @set_module('mxnet.ndarray.numpy') diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index b6816d75a98e..a11a02c0de85 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -57,7 +57,7 @@ 'blackman', 'flip', 'around', 'arctan2', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory', - 'may_share_memory', 'diff', 'resize', 'nan_to_num'] + 'may_share_memory', 'diff', 'resize', 'nan_to_num','bitwise_or'] # Return code for dispatching indexing function call _NDARRAY_UNSUPPORTED_INDEXING = -1 @@ -2749,6 +2749,55 @@ def lcm(x1, x2, out=None, **kwargs): """ return _mx_nd_np.lcm(x1, x2, out=out) +@set_module('mxnet.numpy') +@wrap_np_binary_func +def bitwise_or(x1, x2, out=None, **kwargs): + """ + Returns the bit-wise OR of two arrays element-wise or the bit-wise OR of + the underlying binary representation of the integers in the input arrays + + Parameters + ---------- + *x : array_like + Input arrays. + out : ndarray, None, or tuple of ndarray and None, optional + Alternate array object(s) in which to put the result; if provided, it + must have a shape that the inputs broadcast to. A tuple of arrays + (possible only as a keyword argument) must have length equal to the + number of outputs; use `None` for uninitialized outputs to be + allocated by the ufunc. + where : array_like, optional + This condition is broadcast over the input. At locations where the + condition is True, the `out` array will be set to the ufunc result. + Elsewhere, the `out` array will retain its original value. + Note that if an uninitialized `out` array is created via the default + ``out=None``, locations within it where the condition is False will + remain uninitialized. + **kwargs + For other keyword-only arguments, see the :ref:`ufunc docs `. + + Returns + ------- + r : ndarray or tuple of ndarray + `r` will have the shape that the arrays in `x` broadcast to; if `out` is + provided, it will be returned. If not, `r` will be allocated and + may contain uninitialized values. If the function has more than one + output, then the result will be a tuple of arrays. + + See Also + -------- + logical_or, bitwise_and, bitwise_xor + binary_repr: Return the binary representation of the input number as a string. + + Examples + -------- + >>> np.bitwise_or(13, 16) + 29 + >>> np.bitwise_or([33, 4], 1) + array([33, 5]) + """ + return _mx_nd_np.bitwise_or(x1, x2, out=out) + @set_module('mxnet.numpy') @wrap_np_unary_func diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index 0d7303865b92..bc29e607ef1a 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -41,7 +41,7 @@ 'around', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'shares_memory', 'may_share_memory', 'diff', - 'resize', 'nan_to_num'] + 'resize', 'nan_to_num','bitwise_or'] def _num_outputs(sym): @@ -1209,6 +1209,48 @@ def lcm(x1, x2, out=None, **kwargs): gcd : The greatest common divisor """ return _ufunc_helper(x1, x2, _npi.lcm, _np.lcm, _npi.lcm_scalar, None, out) + +@set_module('mxnet.symbol.numpy') +@wrap_np_binary_func +def bitwise_or(x1, x2, out=None, **kwargs): + """ + Returns the bit-wise OR of two arrays element-wise or the bit-wise OR of + the underlying binary representation of the integers in the input arrays + + Parameters + ---------- + *x : array_like + Input arrays. + out : ndarray, None, or tuple of ndarray and None, optional + Alternate array object(s) in which to put the result; if provided, it + must have a shape that the inputs broadcast to. A tuple of arrays + (possible only as a keyword argument) must have length equal to the + number of outputs; use `None` for uninitialized outputs to be + allocated by the ufunc. + where : array_like, optional + This condition is broadcast over the input. At locations where the + condition is True, the `out` array will be set to the ufunc result. + Elsewhere, the `out` array will retain its original value. + Note that if an uninitialized `out` array is created via the default + ``out=None``, locations within it where the condition is False will + remain uninitialized. + **kwargs + For other keyword-only arguments, see the :ref:`ufunc docs `. + + Returns + ------- + r : ndarray or tuple of ndarray + `r` will have the shape that the arrays in `x` broadcast to; if `out` is + provided, it will be returned. If not, `r` will be allocated and + may contain uninitialized values. If the function has more than one + output, then the result will be a tuple of arrays. + + See Also + -------- + logical_or, bitwise_and, bitwise_xor + binary_repr: Return the binary representation of the input number as a string. + """ + return _ufunc_helper(x1, x2, _npi.bitwise_or, _np.bitwise_or, _npi.bitwise_or_scalar, None, out) @set_module('mxnet.symbol.numpy') diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h index 4ae587188d1b..bc24ef335b1a 100644 --- a/src/operator/mshadow_op.h +++ b/src/operator/mshadow_op.h @@ -1327,6 +1327,22 @@ struct lcm : public mxnet_op::tunable { } }; +struct bitwise_or : public mxnet_op::tunable { + template + MSHADOW_XINLINE static typename enable_if::value, DType>::type + Map(DType a, DType b) { + DType c; + c= a | b; + return c; + } + template + MSHADOW_XINLINE static typename enable_if::value, DType>::type + Map(DType a, DType b) { + return DType(0.0f); + } +}; + + } // namespace mshadow_op } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cc b/src/operator/numpy/np_elemwise_broadcast_op.cc index a76e59d30dc6..8bdc57134c40 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.cc +++ b/src/operator/numpy/np_elemwise_broadcast_op.cc @@ -212,6 +212,41 @@ NNVM_REGISTER_OP(_npi_lcm_scalar) .add_argument("scalar", "int", "scalar input") .set_attr("FCompute", BinaryScalarOp::Compute); +NNVM_REGISTER_OP(_npi_bitwise_or) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr("FListInputNames", +[](const NodeAttrs& attrs) { +return std::vector{"lhs", "rhs"}; +}) +.set_attr("FInferShape", BinaryBroadcastShape) +.set_attr("FInferType", ElemwiseIntType<2, 1>) +.set_attr("FInplaceOption", +[](const NodeAttrs& attrs){ +return std::vector >{{0, 0}, {1, 0}}; +}) +.set_attr("FGradient", MakeZeroGradNodes) +.set_attr("FCompute", BinaryBroadcastCompute) +.add_argument("lhs", "NDArray-or-Symbol", "First input to the function") +.add_argument("rhs", "NDArray-or-Symbol", "Second input to the function"); + +NNVM_REGISTER_OP(_npi_bitwise_or_scalar) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr_parser([](NodeAttrs* attrs) { +attrs->parsed = std::stod(attrs->dict["scalar"]); +}) +.set_attr("FInferShape", ElemwiseShape<1, 1>) +.set_attr("FInferType", ElemwiseIntType<1, 1>) +.set_attr("FInplaceOption", +[](const NodeAttrs& attrs){ +return std::vector >{{0, 0}}; +}) +.set_attr("FGradient", MakeZeroGradNodes) +.add_argument("data", "NDArray-or-Symbol", "source input") +.add_argument("scalar", "int", "scalar input") +.set_attr("FCompute", BinaryScalarOp::Compute); + MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_add_scalar) .set_attr("FCompute", BinaryScalarOp::Compute) .set_attr("FGradient", ElemwiseGradUseNone{"_copy"}); diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cu b/src/operator/numpy/np_elemwise_broadcast_op.cu index a0a277df211f..2f2a5be0acff 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.cu +++ b/src/operator/numpy/np_elemwise_broadcast_op.cu @@ -76,6 +76,9 @@ NNVM_REGISTER_OP(_npi_copysign) NNVM_REGISTER_OP(_npi_lcm) .set_attr("FCompute", BinaryBroadcastCompute); +NNVM_REGISTER_OP(_npi_bitwise_or) +.set_attr("FCompute", BinaryBroadcastCompute); + NNVM_REGISTER_OP(_backward_npi_copysign) .set_attr("FCompute", BinaryBroadcastBackwardUseIn); @@ -146,6 +149,9 @@ NNVM_REGISTER_OP(_backward_npi_rarctan2_scalar) NNVM_REGISTER_OP(_npi_lcm_scalar) .set_attr("FCompute", BinaryScalarOp::Compute); +NNVM_REGISTER_OP(_npi_bitwise_or_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute); + NNVM_REGISTER_OP(_npi_ldexp) .set_attr("FCompute", BinaryBroadcastCompute); diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc index 633f63026bc0..22619c7824db 100644 --- a/src/operator/operator_tune.cc +++ b/src/operator/operator_tune.cc @@ -397,6 +397,7 @@ IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::logical_or); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::logical_xor); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::logical_xor); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::bitwise_xor); // NOLINT() +IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::bitwise_or); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::smooth_l1_loss); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::smooth_l1_gradient); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::lcm); // NOLINT() diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py index 6d9c63f9f857..e2e59d613944 100644 --- a/tests/python/unittest/test_numpy_interoperability.py +++ b/tests/python/unittest/test_numpy_interoperability.py @@ -815,6 +815,10 @@ def _add_workload_lcm(): OpArgMngr.add_workload('lcm', np.array([12, 120], dtype=np.uint8), np.array([20, 200], dtype=np.uint8)) OpArgMngr.add_workload('lcm', np.array(195225786*2, dtype=np.int32), np.array(195225786*5, dtype=np.int32)) +def _add_workload_bitwise_or(): + OpArgMngr.add_workload('bitwise_or', np.array([12, 120], dtype=np.int8), np.array([20, 200], dtype=np.int8)) + OpArgMngr.add_workload('bitwise_or', np.array([12, 120], dtype=np.uint8), np.array([20, 200], dtype=np.uint8)) + OpArgMngr.add_workload('bitwise_or', np.array(195225786*2, dtype=np.int32), np.array(195225786*5, dtype=np.int32)) def _add_workload_ldexp(): OpArgMngr.add_workload('ldexp', np.array(2., np.float32), np.array(3, np.int8)) @@ -1236,6 +1240,7 @@ def _prepare_workloads(): _add_workload_inner() _add_workload_hypot() _add_workload_lcm() + _add_workload_bitwise_or() _add_workload_ldexp() _add_workload_subtract(array_pool) _add_workload_multiply(array_pool) diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 9aabdfd4cabc..431efa172cc5 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -1618,6 +1618,7 @@ def hybrid_forward(self, F, a, b, *args, **kwargs): 'power': (1.0, 2.0, [lambda y, x1, x2: _np.power(x1, x2 - 1.0) * x2], [lambda y, x1, x2: _np.power(x1, x2) * _np.log(x1)]), 'lcm': (-100, 100, [None], None, [[_np.int32]]), + 'bitwise_or': (-100, 100, [None], None, [[_np.int32]]), 'maximum': (-1, 1, [lambda y, x1, x2: _np.ones(y.shape) * (x1 >= x2)], [lambda y, x1, x2: _np.ones(y.shape) * (x1 < x2)]), 'minimum': (-1, 1, [lambda y, x1, x2: _np.ones(y.shape) * (x1 <= x2)],