Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
support bitwise_and
Browse files Browse the repository at this point in the history
  • Loading branch information
haojin2 committed Nov 27, 2019
1 parent 8f10d55 commit 135bd8f
Show file tree
Hide file tree
Showing 10 changed files with 177 additions and 13 deletions.
48 changes: 44 additions & 4 deletions python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,10 @@
'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate', 'append',
'stack', 'vstack', 'column_stack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax',
'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index', 'hanning', 'hamming', 'blackman',
'flip', 'around', 'hypot', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril',
'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less',
'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory',
'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where']
'flip', 'around', 'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique',
'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater',
'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero',
'shares_memory', 'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where']

@set_module('mxnet.ndarray.numpy')
def shape(a):
Expand Down Expand Up @@ -4364,6 +4364,46 @@ def hypot(x1, x2, out=None, **kwargs):
return _ufunc_helper(x1, x2, _npi.hypot, _np.hypot, _npi.hypot_scalar, None, out)


@set_module('mxnet.ndarray.numpy')
@wrap_np_binary_func
def bitwise_and(x1, x2, out=None, **kwargs):
r"""
Compute the bit-wise XOR of two arrays element-wise.
Parameters
----------
x1, x2 : ndarray or scalar
Only integer and boolean types are handled. If x1.shape != x2.shape,
they must be broadcastable to a common shape (which becomes the shape of the output).
out : ndarray, optional
A location into which the result is stored. If provided, it must have a shape that the
inputs broadcast to. If not provided or None, a freshly-allocated array is returned.
Returns
-------
out : ndarray
Result.
Examples
--------
>>> np.bitwise_and(13, 17)
1
>>> np.bitwise_and(14, 13)
12
>>> np.bitwise_and(np.array([14,3], dtype='int32'), 13)
array([12, 1], dtype=int32)
>>> np.bitwise_and(np.array([11,7], dtype='int32'), np.array([4,25], dtype='int32'))
array([0, 1], dtype=int32)
>>> np.bitwise_and(np.array([2,5,255], dtype='int32'), np.array([3,14,16], dtype='int32'))
array([ 2, 4, 16], dtype=int32)
>>> np.bitwise_and(np.array([True, True], dtype='bool'), np.array([False, True], dtype='bool'))
array([False, True])
"""
return _ufunc_helper(x1, x2, _npi.bitwise_and, _np.bitwise_and, _npi.bitwise_and_scalar, None, out)


@set_module('mxnet.ndarray.numpy')
@wrap_np_binary_func
def bitwise_xor(x1, x2, out=None, **kwargs):
Expand Down
53 changes: 47 additions & 6 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,11 @@
'tensordot', 'histogram', 'eye', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange',
'split', 'vsplit', 'concatenate', 'stack', 'vstack', 'column_stack', 'dstack', 'mean', 'maximum', 'minimum',
'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index',
'hanning', 'hamming', 'blackman', 'flip', 'around', 'arctan2', 'hypot', 'bitwise_xor', 'bitwise_or',
'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum',
'true_divide', 'nonzero', 'shares_memory', 'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where']
'hanning', 'hamming', 'blackman', 'flip', 'around', 'arctan2', 'hypot', 'bitwise_and', 'bitwise_xor',
'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot',
'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90',
'einsum', 'true_divide', 'nonzero', 'shares_memory', 'may_share_memory', 'diff', 'resize', 'nan_to_num',
'where']

# Return code for dispatching indexing function call
_NDARRAY_UNSUPPORTED_INDEXING = -1
Expand Down Expand Up @@ -6269,6 +6270,46 @@ def hypot(x1, x2, out=None, **kwargs):
return _mx_nd_np.hypot(x1, x2, out=out)


@set_module('mxnet.numpy')
@wrap_np_binary_func
def bitwise_and(x1, x2, out=None, **kwargs):
r"""
Compute the bit-wise XOR of two arrays element-wise.
Parameters
----------
x1, x2 : ndarray or scalar
Only integer and boolean types are handled. If x1.shape != x2.shape,
they must be broadcastable to a common shape (which becomes the shape of the output).
out : ndarray, optional
A location into which the result is stored. If provided, it must have a shape that the
inputs broadcast to. If not provided or None, a freshly-allocated array is returned.
Returns
-------
out : ndarray
Result.
Examples
--------
>>> np.bitwise_and(13, 17)
1
>>> np.bitwise_and(14, 13)
12
>>> np.bitwise_and(np.array([14,3], dtype='int32'), 13)
array([26, 5], dtype=int32)
>>> np.bitwise_and(np.array([11,7], dtype='int32'), np.array([4,25], dtype='int32'))
array([0, 1], dtype=int32)
>>> np.bitwise_and(np.array([2,5,255], dtype='int32'), np.array([3,14,16], dtype='int32'))
array([ 2, 4, 16], dtype=int32)
>>> np.bitwise_and(np.array([True, True], dtype='bool'), np.array([False, True], dtype='bool'))
array([False, True])
"""
return _mx_nd_np.bitwise_and(x1, x2, out=out)


@set_module('mxnet.numpy')
@wrap_np_binary_func
def bitwise_xor(x1, x2, out=None, **kwargs):
Expand Down Expand Up @@ -6297,10 +6338,10 @@ def bitwise_xor(x1, x2, out=None, **kwargs):
>>> np.bitwise_xor(31, 5)
26
>>> np.bitwise_xor(np.array([31,3], dtype=np.int32), 5)
array([26, 6])
array([26, 6], dtype=int32)
>>> np.bitwise_xor(np.array([31,3], dtype='int32'), np.array([5,6], dtype='int32'))
array([26, 5])
array([26, 5], dtype=int32)
>>> np.bitwise_xor(np.array([True, True], dtype='bool'), np.array([False, True], dtype='bool'))
array([ True, False])
"""
Expand Down
1 change: 1 addition & 0 deletions python/mxnet/numpy_dispatch_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ def _register_array_function():
'ceil',
'trunc',
'floor',
'bitwise_and',
'bitwise_xor',
'bitwise_or',
'logical_not',
Expand Down
30 changes: 27 additions & 3 deletions python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@
'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate', 'append',
'stack', 'vstack', 'column_stack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax',
'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index', 'hanning', 'hamming', 'blackman',
'flip', 'around', 'hypot', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril',
'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less',
'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'shares_memory',
'flip', 'around', 'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique',
'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater',
'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'shares_memory',
'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where']

def _num_outputs(sym):
Expand Down Expand Up @@ -4110,6 +4110,30 @@ def hypot(x1, x2, out=None, **kwargs):
return _ufunc_helper(x1, x2, _npi.hypot, _np.hypot, _npi.hypot_scalar, None, out)


@set_module('mxnet.symbol.numpy')
@wrap_np_binary_func
def bitwise_and(x1, x2, out=None, **kwargs):
r"""
Compute the bit-wise XOR of two arrays element-wise.
Parameters
----------
x1, x2 : _Symbol or scalar
Only integer and boolean types are handled. If x1.shape != x2.shape,
they must be broadcastable to a common shape (which becomes the shape of the output).
out : _Symbol or None, optional
A location into which the result is stored. If provided, it must have
a shape that the inputs broadcast to. If not provided or `None`,
a freshly-allocated array is returned.
Returns
-------
out : _Symbol or scalar
Result.
"""
return _ufunc_helper(x1, x2, _npi.bitwise_and, _np.bitwise_and, _npi.bitwise_and_scalar, None, out)


@set_module('mxnet.symbol.numpy')
@wrap_np_binary_func
def bitwise_xor(x1, x2, out=None, **kwargs):
Expand Down
2 changes: 2 additions & 0 deletions src/operator/mshadow_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,8 @@ MXNET_BINARY_MATH_OP(logical_or, a || b ? DType(1) : DType(0));

MXNET_BINARY_MATH_OP(logical_xor, (a || b) && !(a && b) ? DType(1) : DType(0));

MXNET_BINARY_MATH_OP(bitwise_and, static_cast<int64_t>(a) & static_cast<int64_t>(b));

MXNET_BINARY_MATH_OP(bitwise_xor, static_cast<int64_t>(a) ^ static_cast<int64_t>(b));

MXNET_BINARY_MATH_OP(bitwise_or, static_cast<int64_t>(a) | static_cast<int64_t>(b));
Expand Down
35 changes: 35 additions & 0 deletions src/operator/numpy/np_elemwise_broadcast_op_extended.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,41 @@ NNVM_REGISTER_OP(_npi_lcm_scalar)
.add_argument("scalar", "int", "scalar input")
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::lcm>);

NNVM_REGISTER_OP(_npi_bitwise_and)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"lhs", "rhs"};
})
.set_attr<mxnet::FInferShape>("FInferShape", BinaryBroadcastShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseIntType<2, 1>)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs){
return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}};
})
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastIntCompute<cpu, mshadow_op::bitwise_and>)
.add_argument("lhs", "NDArray-or-Symbol", "First input to the function")
.add_argument("rhs", "NDArray-or-Symbol", "Second input to the function");

NNVM_REGISTER_OP(_npi_bitwise_and_scalar)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser([](NodeAttrs* attrs) {
attrs->parsed = std::stod(attrs->dict["scalar"]);
})
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseIntType<1, 1>)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs){
return std::vector<std::pair<int, int> >{{0, 0}};
})
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.add_argument("data", "NDArray-or-Symbol", "source input")
.add_argument("scalar", "int", "scalar input")
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::ComputeInt<cpu, mshadow_op::bitwise_and>);

NNVM_REGISTER_OP(_npi_bitwise_xor)
.set_num_inputs(2)
.set_num_outputs(1)
Expand Down
6 changes: 6 additions & 0 deletions src/operator/numpy/np_elemwise_broadcast_op_extended.cu
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ NNVM_REGISTER_OP(_npi_copysign)
NNVM_REGISTER_OP(_npi_lcm)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastIntCompute<gpu, mshadow_op::lcm>);

NNVM_REGISTER_OP(_npi_bitwise_and)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastIntCompute<gpu, mshadow_op::bitwise_and>);

NNVM_REGISTER_OP(_npi_bitwise_xor)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastIntCompute<gpu, mshadow_op::bitwise_xor>);

Expand Down Expand Up @@ -85,6 +88,9 @@ NNVM_REGISTER_OP(_backward_npi_rarctan2_scalar)
NNVM_REGISTER_OP(_npi_lcm_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::ComputeInt<gpu, mshadow_op::lcm>);

NNVM_REGISTER_OP(_npi_bitwise_and_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::ComputeInt<gpu, mshadow_op::bitwise_and>);

NNVM_REGISTER_OP(_npi_bitwise_xor_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::ComputeInt<gpu, mshadow_op::bitwise_xor>);

Expand Down
1 change: 1 addition & 0 deletions src/operator/operator_tune.cc
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,7 @@ IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::logical_or); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::logical_or); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::logical_xor); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::logical_xor); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::bitwise_and); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::bitwise_xor); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::bitwise_or); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::smooth_l1_loss); // NOLINT()
Expand Down
13 changes: 13 additions & 0 deletions tests/python/unittest/test_numpy_interoperability.py
Original file line number Diff line number Diff line change
Expand Up @@ -875,6 +875,18 @@ def _add_workload_bitwise_or():
OpArgMngr.add_workload('bitwise_or', zeros, ones)
OpArgMngr.add_workload('bitwise_or', ones, ones)

def _add_workload_bitwise_and():
OpArgMngr.add_workload('bitwise_and', np.array([False, False, True, True], dtype=np.bool),
np.array([False, True, False, True], dtype=np.bool))
for dtype in [np.int8, np.int32, np.int64]:
zeros = np.array([0], dtype=dtype)
ones = np.array([-1], dtype=dtype)
OpArgMngr.add_workload('bitwise_and', zeros, zeros)
OpArgMngr.add_workload('bitwise_and', ones, zeros)
OpArgMngr.add_workload('bitwise_and', zeros, ones)
OpArgMngr.add_workload('bitwise_and', ones, ones)


def _add_workload_bitwise_xor():
OpArgMngr.add_workload('bitwise_xor', np.array([False, False, True, True], dtype=np.bool),
np.array([False, True, False, True], dtype=np.bool))
Expand Down Expand Up @@ -1368,6 +1380,7 @@ def _prepare_workloads():
_add_workload_inner()
_add_workload_hypot()
_add_workload_lcm()
_add_workload_bitwise_and()
_add_workload_bitwise_xor()
_add_workload_bitwise_or()
_add_workload_ldexp()
Expand Down
1 change: 1 addition & 0 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1670,6 +1670,7 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
'power': (1.0, 2.0, [lambda y, x1, x2: _np.power(x1, x2 - 1.0) * x2],
[lambda y, x1, x2: _np.power(x1, x2) * _np.log(x1)]),
'lcm': (-100, 100, [None], None, [[_np.int32]]),
'bitwise_and': (-100, 100, [None], None, [[_np.int32]]),
'bitwise_xor': (-100, 100, [None], None, [[_np.int32]]),
'bitwise_or': (-100, 100, [None], None, [[_np.int32]]),
'maximum': (-1, 1, [lambda y, x1, x2: _np.ones(y.shape) * (x1 >= x2)],
Expand Down

0 comments on commit 135bd8f

Please sign in to comment.