Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add op bitwise_or [numpy] (#16801)
Browse files Browse the repository at this point in the history
* solve conflict

* change test

* add protocol

* less blank line

* update submodule

* update 3rd party
  • Loading branch information
Yiyan66 authored and haojin2 committed Nov 21, 2019
1 parent ece027c commit 4da14a2
Show file tree
Hide file tree
Showing 11 changed files with 168 additions and 12 deletions.
47 changes: 42 additions & 5 deletions python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,10 @@
'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate', 'append',
'stack', 'vstack', 'column_stack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax',
'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip',
'around', 'hypot', 'bitwise_xor', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take',
'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal',
'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory', 'may_share_memory', 'diff', 'resize',
'nan_to_num', 'where']

'around', 'hypot', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity',
'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal',
'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory', 'may_share_memory',
'diff', 'resize', 'nan_to_num', 'where']

@set_module('mxnet.ndarray.numpy')
def shape(a):
Expand Down Expand Up @@ -4365,6 +4364,44 @@ def bitwise_xor(x1, x2, out=None, **kwargs):
return _ufunc_helper(x1, x2, _npi.bitwise_xor, _np.bitwise_xor, _npi.bitwise_xor_scalar, None, out)


@set_module('mxnet.ndarray.numpy')
@wrap_np_binary_func
def bitwise_or(x1, x2, out=None, **kwargs):
r"""
Compute the bit-wise OR of two arrays element-wise.
Parameters
----------
x1, x2 : ndarray or scalar
Only integer and boolean types are handled. If x1.shape != x2.shape,
they must be broadcastable to a common shape (which becomes the shape of the output).
out : ndarray, optional
A location into which the result is stored. If provided, it must have a shape that the
inputs broadcast to. If not provided or None, a freshly-allocated array is returned.
Returns
-------
out : ndarray
Result.
Examples
--------
>>> np.bitwise_or(13, 17)
29
>>> np.bitwise_or(31, 5)
31
>>> np.bitwise_or(np.array([31,3], dtype='int32'), 5)
array([31, 7])
>>> np.bitwise_or(np.array([31,3], dtype='int32'), np.array([5,6], dtype='int32'))
array([31, 7])
>>> np.bitwise_or(np.array([True, True], dtype='bool'), np.array([False, True], dtype='bool'))
array([ True, True])
"""
return _ufunc_helper(x1, x2, _npi.bitwise_or, _np.bitwise_or, _npi.bitwise_or_scalar, None, out)


@set_module('mxnet.ndarray.numpy')
@wrap_np_binary_func
def ldexp(x1, x2, out=None, **kwargs):
Expand Down
46 changes: 42 additions & 4 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@
'tensordot', 'histogram', 'eye', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange',
'split', 'vsplit', 'concatenate', 'stack', 'vstack', 'column_stack', 'dstack', 'mean', 'maximum', 'minimum',
'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming',
'blackman', 'flip', 'around', 'arctan2', 'hypot', 'bitwise_xor', 'rad2deg', 'deg2rad', 'unique', 'lcm',
'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less',
'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory',
'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where']
'blackman', 'flip', 'around', 'arctan2', 'hypot', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad',
'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal',
'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero',
'shares_memory', 'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where']

# Return code for dispatching indexing function call
_NDARRAY_UNSUPPORTED_INDEXING = -1
Expand Down Expand Up @@ -6277,6 +6277,44 @@ def bitwise_xor(x1, x2, out=None, **kwargs):
return _mx_nd_np.bitwise_xor(x1, x2, out=out)


@set_module('mxnet.numpy')
@wrap_np_binary_func
def bitwise_or(x1, x2, out=None, **kwargs):
r"""
Compute the bit-wise OR of two arrays element-wise.
Parameters
----------
x1, x2 : ndarray or scalar
Only integer and boolean types are handled. If x1.shape != x2.shape,
they must be broadcastable to a common shape (which becomes the shape of the output).
out : ndarray, optional
A location into which the result is stored. If provided, it must have a shape that the
inputs broadcast to. If not provided or None, a freshly-allocated array is returned.
Returns
-------
out : ndarray
Result.
Examples
--------
>>> np.bitwise_or(13, 17)
29
>>> np.bitwise_or(31, 5)
31
>>> np.bitwise_or(np.array([31,3], dtype=np.int32), 5)
array([31, 7])
>>> np.bitwise_or(np.array([31,3], dtype='int32'), np.array([5,6], dtype='int32'))
array([31, 7])
>>> np.bitwise_or(np.array([True, True], dtype='bool'), np.array([False, True], dtype='bool'))
array([ True, True])
"""
return _mx_nd_np.bitwise_or(x1, x2, out=out)


@set_module('mxnet.numpy')
@wrap_np_binary_func
def ldexp(x1, x2, out=None, **kwargs):
Expand Down
1 change: 1 addition & 0 deletions python/mxnet/numpy_dispatch_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ def _register_array_function():
'trunc',
'floor',
'bitwise_xor',
'bitwise_or',
'logical_not',
'equal',
'not_equal',
Expand Down
29 changes: 26 additions & 3 deletions python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,11 @@
'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate', 'append',
'stack', 'vstack', 'column_stack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax',
'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip',
'around', 'hypot', 'bitwise_xor', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take',
'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal',
'around', 'hypot', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity',
'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal',
'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'shares_memory', 'may_share_memory', 'diff',
'resize', 'nan_to_num', 'where']


def _num_outputs(sym):
return len(sym.as_nd_ndarray())

Expand Down Expand Up @@ -4103,6 +4102,30 @@ def bitwise_xor(x1, x2, out=None, **kwargs):
return _ufunc_helper(x1, x2, _npi.bitwise_xor, _np.bitwise_xor, _npi.bitwise_xor_scalar, None, out)


@set_module('mxnet.symbol.numpy')
@wrap_np_binary_func
def bitwise_or(x1, x2, out=None, **kwargs):
r"""
Compute the bit-wise OR of two arrays element-wise.
Parameters
----------
x1, x2 : _Symbol or scalar
Only integer and boolean types are handled. If x1.shape != x2.shape,
they must be broadcastable to a common shape (which becomes the shape of the output).
out : _Symbol or None, optional
A location into which the result is stored. If provided, it must have
a shape that the inputs broadcast to. If not provided or `None`,
a freshly-allocated array is returned.
Returns
-------
out : _Symbol or scalar
Result.
"""
return _ufunc_helper(x1, x2, _npi.bitwise_or, _np.bitwise_or, _npi.bitwise_or_scalar, None, out)


@set_module('mxnet.symbol.numpy')
def unique(ar, return_index=False, return_inverse=False, return_counts=False, axis=None):
"""
Expand Down
2 changes: 2 additions & 0 deletions src/operator/mshadow_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,8 @@ MXNET_BINARY_MATH_OP(logical_xor, (a || b) && !(a && b) ? DType(1) : DType(0));

MXNET_BINARY_MATH_OP(bitwise_xor, static_cast<int64_t>(a) ^ static_cast<int64_t>(b));

MXNET_BINARY_MATH_OP(bitwise_or, static_cast<int64_t>(a) | static_cast<int64_t>(b));

MXNET_UNARY_MATH_OP(square_root, math::sqrt(a));

MXNET_UNARY_MATH_OP(square_root_grad, 0.5f / math::id(a));
Expand Down
1 change: 1 addition & 0 deletions src/operator/numpy/np_elemwise_broadcast_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_mod)
MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_power)
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::power>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_power"});

MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_add_scalar)
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, op::mshadow_op::plus>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_copy"});
Expand Down
35 changes: 35 additions & 0 deletions src/operator/numpy/np_elemwise_broadcast_op_extended.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,24 @@ NNVM_REGISTER_OP(_npi_bitwise_xor)
.add_argument("lhs", "NDArray-or-Symbol", "First input to the function")
.add_argument("rhs", "NDArray-or-Symbol", "Second input to the function");

NNVM_REGISTER_OP(_npi_bitwise_or)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"lhs", "rhs"};
})
.set_attr<mxnet::FInferShape>("FInferShape", BinaryBroadcastShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseIntType<2, 1>)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs){
return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}};
})
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastIntCompute<cpu, mshadow_op::bitwise_or>)
.add_argument("lhs", "NDArray-or-Symbol", "First input to the function")
.add_argument("rhs", "NDArray-or-Symbol", "Second input to the function");

NNVM_REGISTER_OP(_npi_bitwise_xor_scalar)
.set_num_inputs(1)
.set_num_outputs(1)
Expand All @@ -135,6 +153,23 @@ NNVM_REGISTER_OP(_npi_bitwise_xor_scalar)
.add_argument("scalar", "int", "scalar input")
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::ComputeInt<cpu, mshadow_op::bitwise_xor>);

NNVM_REGISTER_OP(_npi_bitwise_or_scalar)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser([](NodeAttrs* attrs) {
attrs->parsed = std::stod(attrs->dict["scalar"]);
})
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseIntType<1, 1>)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs){
return std::vector<std::pair<int, int> >{{0, 0}};
})
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.add_argument("data", "NDArray-or-Symbol", "source input")
.add_argument("scalar", "int", "scalar input")
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::ComputeInt<cpu, mshadow_op::bitwise_or>);

MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_copysign_scalar)
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::copysign>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_copysign_scalar"});
Expand Down
6 changes: 6 additions & 0 deletions src/operator/numpy/np_elemwise_broadcast_op_extended.cu
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ NNVM_REGISTER_OP(_npi_lcm)
NNVM_REGISTER_OP(_npi_bitwise_xor)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastIntCompute<gpu, mshadow_op::bitwise_xor>);

NNVM_REGISTER_OP(_npi_bitwise_or)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastIntCompute<gpu, mshadow_op::bitwise_or>);

NNVM_REGISTER_OP(_backward_npi_copysign)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastBackwardUseIn<gpu, mshadow_op::copysign_grad,
mshadow_op::copysign_rgrad>);
Expand Down Expand Up @@ -85,6 +88,9 @@ NNVM_REGISTER_OP(_npi_lcm_scalar)
NNVM_REGISTER_OP(_npi_bitwise_xor_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::ComputeInt<gpu, mshadow_op::bitwise_xor>);

NNVM_REGISTER_OP(_npi_bitwise_or_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::ComputeInt<gpu, mshadow_op::bitwise_or>);

NNVM_REGISTER_OP(_npi_ldexp)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::ldexp>);

Expand Down
1 change: 1 addition & 0 deletions src/operator/operator_tune.cc
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,7 @@ IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::logical_or); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::logical_xor); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::logical_xor); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::bitwise_xor); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::bitwise_or); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::smooth_l1_loss); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::smooth_l1_gradient); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::lcm); // NOLINT()
Expand Down
11 changes: 11 additions & 0 deletions tests/python/unittest/test_numpy_interoperability.py
Original file line number Diff line number Diff line change
Expand Up @@ -845,6 +845,16 @@ def _add_workload_lcm():
OpArgMngr.add_workload('lcm', np.array([12, 120], dtype=np.uint8), np.array([20, 200], dtype=np.uint8))
OpArgMngr.add_workload('lcm', np.array(195225786*2, dtype=np.int32), np.array(195225786*5, dtype=np.int32))

def _add_workload_bitwise_or():
OpArgMngr.add_workload('bitwise_or', np.array([False, False, True, True], dtype=np.bool),
np.array([False, True, False, True], dtype=np.bool))
for dtype in [np.int8, np.int32, np.int64]:
zeros = np.array([0], dtype=dtype)
ones = np.array([-1], dtype=dtype)
OpArgMngr.add_workload('bitwise_or', zeros, zeros)
OpArgMngr.add_workload('bitwise_or', ones, zeros)
OpArgMngr.add_workload('bitwise_or', zeros, ones)
OpArgMngr.add_workload('bitwise_or', ones, ones)

def _add_workload_bitwise_xor():
OpArgMngr.add_workload('bitwise_xor', np.array([False, False, True, True], dtype=np.bool),
Expand Down Expand Up @@ -1337,6 +1347,7 @@ def _prepare_workloads():
_add_workload_hypot()
_add_workload_lcm()
_add_workload_bitwise_xor()
_add_workload_bitwise_or()
_add_workload_ldexp()
_add_workload_subtract(array_pool)
_add_workload_multiply(array_pool)
Expand Down
1 change: 1 addition & 0 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1671,6 +1671,7 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
[lambda y, x1, x2: _np.power(x1, x2) * _np.log(x1)]),
'lcm': (-100, 100, [None], None, [[_np.int32]]),
'bitwise_xor': (-100, 100, [None], None, [[_np.int32]]),
'bitwise_or': (-100, 100, [None], None, [[_np.int32]]),
'maximum': (-1, 1, [lambda y, x1, x2: _np.ones(y.shape) * (x1 >= x2)],
[lambda y, x1, x2: _np.ones(y.shape) * (x1 < x2)]),
'minimum': (-1, 1, [lambda y, x1, x2: _np.ones(y.shape) * (x1 <= x2)],
Expand Down

0 comments on commit 4da14a2

Please sign in to comment.