This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Numpy] Numpy compatible bitwise_and operator #16009

Status: Open · wants to merge 5 commits into base: master
70 changes: 69 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
@@ -32,7 +32,8 @@
'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p',
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack']
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack',
'bitwise_and']


@set_module('mxnet.ndarray.numpy')
@@ -1905,3 +1906,70 @@ def get_list(arrays):

arrays = get_list(arrays)
return _npi.stack(*arrays, axis=axis, out=out)


@set_module('mxnet.ndarray.numpy')
def bitwise_and(x1, x2, out=None):
"""
Compute the bit-wise AND of two arrays element-wise.

Computes the bit-wise AND of the underlying binary representation of
the integers in the input arrays. This ufunc implements the C/Python
operator ``&``.

Parameters
----------
x1, x2 : ndarray or int
Only integer types are currently handled.
out : ndarray or None, optional
A location into which the result is stored.
If provided, it must have the same shape and dtype as the input ndarrays.
If not provided or `None`, a freshly-allocated array is returned.

Returns
-------
y : ndarray of integer dtypes, or scalar if both inputs are scalars.

See Also
--------
bitwise_or
bitwise_xor

Notes
-----
This function differs from the original `numpy.bitwise_and
<https://docs.scipy.org/doc/numpy/reference/generated/numpy.bitwise_and.html>`_ in
the following aspects:
- Input type currently does not support Python native iterables (list, tuple, ...).
- Input type currently does not support boolean arrays.

Examples
--------
The number 13 is represented by ``00001101``. Likewise, 17 is
represented by ``00010001``. The bit-wise AND of 13 and 17 is
therefore ``00000001``, or 1:

>>> np.bitwise_and(13, 17)
1

>>> np.bitwise_and(14, 13)
12

>>> x = np.array([14, 13], dtype=np.int32)
>>> np.bitwise_and(x, 13)
array([12, 13], dtype=int32)

>>> x = np.array([11, 7], dtype=np.int32)
>>> y = np.array([4, 25], dtype=np.int32)
>>> np.bitwise_and(x, y)
array([0, 1], dtype=int32)

>>> x1 = np.array(13, dtype=np.int32)
>>> x2 = np.array(17, dtype=np.int32)
>>> out = np.array(0, dtype=np.int32)
>>> np.bitwise_and(x1, x2, out)
array(1, dtype=int32)
>>> out
array(1, dtype=int32)

"""
return _ufunc_helper(x1, x2, _npi.bitwise_and, _np.bitwise_and, _npi.bitwise_and_scalar, None, out)
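
For context, _ufunc_helper selects one of three execution paths based on the argument types. The sketch below is an approximation of that dispatch (argument names follow this call site; the real implementation also normalizes scalar types), not the exact code:

from mxnet.base import numeric_types   # Python scalars and numpy scalar types

# A simplified sketch, not the exact implementation, of the dispatch performed
# by _ufunc_helper for bitwise_and:
def _ufunc_helper_sketch(x1, x2, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None, out=None):
    if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
        return fn_scalar(x1, x2)                # both scalars: fall back to official numpy
    if isinstance(x2, numeric_types):
        return lfn_scalar(x1, x2, out=out)      # ndarray & scalar: _npi.bitwise_and_scalar
    if isinstance(x1, numeric_types):
        return lfn_scalar(x2, x1, out=out)      # rfn_scalar is None since & is commutative
    return fn_array(x1, x2, out=out)            # ndarray & ndarray: _npi.bitwise_and

So np.bitwise_and(13, 17) returns a scalar computed by official numpy, while any call involving an ndarray returns an mxnet ndarray.
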
66 changes: 65 additions & 1 deletion python/mxnet/numpy/multiarray.py
@@ -49,7 +49,7 @@
'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative',
'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh',
'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate',
'stack']
'stack', 'bitwise_and']


# This function is copied from ndarray.py since pylint
@@ -3086,3 +3086,67 @@ def stack(arrays, axis=0, out=None):
stacked : ndarray
The stacked array has one more dimension than the input arrays."""
return _mx_nd_np.stack(arrays, axis=axis, out=out)


@set_module('mxnet.numpy')
def bitwise_and(x1, x2, out=None):
"""
Computes the bit-wise AND of the underlying binary representation of
the integers in the input arrays. This method implements the C/Python
operator ``&``.

Parameters
----------
x1, x2 : ndarray of integer dtypes, or ints.
out : ndarray or None, optional
A location into which the result is stored.
If provided, it must have the same shape and dtype as the input ndarrays.
If not provided or `None`, a freshly-allocated array is returned.

Returns
-------
y : ndarray of integer dtypes, or scalar if both inputs are scalars.

See Also
--------
bitwise_or
bitwise_xor

Notes
-----
This function differs from the original `numpy.bitwise_and
<https://docs.scipy.org/doc/numpy/reference/generated/numpy.bitwise_and.html>`_ in
the following aspects:
- Input type currently does not support Python native iterables (list, tuple, ...).
- Input type currently does not support boolean arrays.

Examples
--------
The number 13 is represented by ``00001101``. Likewise, 17 is
represented by ``00010001``. The bit-wise AND of 13 and 17 is
therefore ``00000001``, or 1:

>>> np.bitwise_and(13, 17)
1

>>> np.bitwise_and(14, 13)
12

>>> x = np.array([14, 13], dtype=np.int32)
>>> np.bitwise_and(x, 13)
array([12, 13], dtype=int32)

>>> x = np.array([11, 7], dtype=np.int32)
>>> y = np.array([4, 25], dtype=np.int32)
>>> np.bitwise_and(x, y)
array([0, 1], dtype=int32)

>>> x1 = np.array(13, dtype=np.int32)
>>> x2 = np.array(17, dtype=np.int32)
>>> out = np.array(0, dtype=np.int32)
>>> np.bitwise_and(x1, x2, out)
array(1, dtype=int32)
>>> out
array(1, dtype=int32)

"""
return _mx_nd_np.bitwise_and(x1, x2, out)
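
A quick interactive sanity check of this imperative front end against official NumPy, including broadcasting (a sketch; assumes an MXNet build that contains this PR):

import numpy as onp                  # official numpy, used only for reference results
from mxnet import np, npx
npx.set_np()

a = np.array([[11, 7]], dtype='int32')      # shape (1, 2)
b = np.array([[4], [25]], dtype='int32')    # shape (2, 1), broadcasts to (2, 2)
c = np.bitwise_and(a, b)
assert (c.asnumpy() == onp.bitwise_and(a.asnumpy(), b.asnumpy())).all()
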
37 changes: 36 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
@@ -34,7 +34,8 @@
'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p',
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack']
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack',
'bitwise_and']


def _num_outputs(sym):
@@ -2328,4 +2329,38 @@ def get_list(arrays):
return _npi.stack(*arrays, axis=axis, out=out)


@set_module('mxnet.symbol.numpy')
def bitwise_and(x1, x2, out=None):
"""
Computes the bit-wise AND of the underlying binary representation of
the integers in the input arrays. This method implements the C/Python
operator ``&``.

Parameters
----------
x1, x2 : _Symbol
out : _Symbol or None, optional
Dummy parameter kept for consistency with the ndarray counterpart.

Returns
-------
y : _Symbol of integer dtypes

See Also
--------
bitwise_or
bitwise_xor

Notes
-----
This function differs from the original `numpy.bitwise_and
<https://docs.scipy.org/doc/numpy/reference/generated/numpy.bitwise_and.html>`_ in
the following aspects:
- Input type currently does not support Python native iterables (list, tuple, ...).
- Input type currently does not support boolean arrays.

"""
return _ufunc_helper(x1, x2, _npi.bitwise_and, _np.bitwise_and, _npi.bitwise_and_scalar, None, out)


_set_np_symbol_class(_Symbol)
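
The symbol-level bitwise_and above is what a hybridized Gluon block reaches through F.np; a minimal usage sketch (assumes a build that contains this PR; the block name is made up for illustration):

from mxnet import np, npx
from mxnet.gluon import HybridBlock
npx.set_np()

class AndWithMask(HybridBlock):
    """Hypothetical block that ANDs its input with a fixed integer mask."""
    def __init__(self, mask, **kwargs):
        super(AndWithMask, self).__init__(**kwargs)
        self._mask = mask

    def hybrid_forward(self, F, x):
        # F is mxnet.ndarray before hybridize() and mxnet.symbol afterwards, so this
        # line reaches either the ndarray bitwise_and or the _Symbol version above.
        return F.np.bitwise_and(x, self._mask)

block = AndWithMask(mask=0x0F)
block.hybridize()
print(block(np.array([13, 17, 255], dtype='int32')))   # -> [13  1 15]
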
19 changes: 19 additions & 0 deletions src/operator/elemwise_op_common.h
@@ -186,6 +186,25 @@ inline bool ElemwiseType(const nnvm::NodeAttrs& attrs,
attrs, in_attrs, out_attrs, -1);
}

// Special case of ElemwiseType. Constrains dtype to integer types
template<index_t n_in, index_t n_out>
inline bool ElemwiseIntType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
Contributor (inline review comment): alignment
std::vector<int> *out_attrs) {
CHECK(in_attrs->at(0) == mshadow::kInt64 ||
in_attrs->at(0) == mshadow::kInt32 ||
in_attrs->at(0) == mshadow::kInt8 ||
in_attrs->at(0) == mshadow::kUint8) << "Only supports integer types.";
if (n_in != -1) {
CHECK_EQ(in_attrs->size(), static_cast<size_t>(n_in)) << " in operator " << attrs.name;
}
if (n_out != -1) {
CHECK_EQ(out_attrs->size(), static_cast<size_t>(n_out)) << " in operator " << attrs.name;
}
return ElemwiseAttr<int, type_is_none, type_assign, true, type_string>(
attrs, in_attrs, out_attrs, -1);
}

// Transfer gradient and input to FGradient function
struct ElemwiseGradUseIn {
const char *op_name;
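
Because the ElemwiseIntType check above only accepts int8, uint8, int32 and int64, float inputs should be rejected during type inference. Roughly, from Python (a sketch; the exact error text and the point at which it surfaces may differ):

from mxnet import np, npx
from mxnet.base import MXNetError
npx.set_np()

a = np.array([1.0, 2.0])             # float32 by default
b = np.array([3.0, 4.0])
try:
    np.bitwise_and(a, b)
except MXNetError as err:
    # the CHECK above should produce a message containing "Only supports integer types."
    print('rejected as expected:', err)
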
3 changes: 3 additions & 0 deletions src/operator/mshadow_op.h
@@ -147,6 +147,9 @@ MXNET_BINARY_MATH_OP_NC(elu, a > DType(0) ? a :

MXNET_BINARY_MATH_OP_NC(elu_grad, a > DType(0) ? DType(1) : DType(b + a));

MXNET_BINARY_MATH_OP(bitwise_and,
static_cast<int64_t>(static_cast<int64_t>(a) & static_cast<int64_t>(b)));

MXNET_SIMPLE_UNARY_MATH_OP(tanh);

MXNET_UNARY_MATH_OP(tanh_grad, 1.0f - math::sqr(a));
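
The mshadow_op::bitwise_and kernel added in this file widens both operands to int64, ANDs them, and relies on the output DType for the final cast. A rough Python equivalent, with a hypothetical helper name, for illustration only:

import numpy as onp   # official numpy, used only to mimic the C++ casts

def bitwise_and_kernel(a, b, out_dtype=onp.int32):
    # Mirrors MXNET_BINARY_MATH_OP(bitwise_and, ...): widen both operands to int64,
    # apply &, then cast the result to the output dtype (done by DType in the C++ macro).
    return out_dtype(onp.int64(a) & onp.int64(b))

print(bitwise_and_kernel(onp.int8(14), onp.int8(13)))   # -> 12
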
31 changes: 31 additions & 0 deletions src/operator/numpy/np_elemwise_broadcast_op.cc
@@ -182,5 +182,36 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rpower_scalar)
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::rpower>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_rpower_scalar"});

NNVM_REGISTER_OP(_npi_bitwise_and)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr<mxnet::FInferShape>("FInferShape", BinaryBroadcastShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseIntType<2, 1>) // TODO(reminisce) boolean support
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"x1", "x2"};
})
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::bitwise_and>)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.add_argument("x1", "NDArray-or-Symbol", "Input ndarray")
.add_argument("x2", "NDArray-or-Symbol", "Input ndarray");

NNVM_REGISTER_OP(_npi_bitwise_and_scalar)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser([](NodeAttrs* attrs) {
attrs->parsed = std::stod(attrs->dict["scalar"]);
})
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseIntType<1, 1>) // TODO(reminisce) boolean support
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"data"};
})
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::bitwise_and>)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.add_argument("data", "NDArray-or-Symbol", "source input")
.add_argument("scalar", "int", "scalar input");

} // namespace op
} // namespace mxnet
6 changes: 6 additions & 0 deletions src/operator/numpy/np_elemwise_broadcast_op.cu
@@ -48,6 +48,9 @@ NNVM_REGISTER_OP(_npi_maximum)
NNVM_REGISTER_OP(_npi_minimum)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::minimum>);

NNVM_REGISTER_OP(_npi_bitwise_and)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::bitwise_and>);

NNVM_REGISTER_OP(_npi_add_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, op::mshadow_op::plus>);

@@ -78,5 +81,8 @@ NNVM_REGISTER_OP(_npi_maximum_scalar)
NNVM_REGISTER_OP(_npi_minimum_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::minimum>);

NNVM_REGISTER_OP(_npi_bitwise_and_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::bitwise_and>);

} // namespace op
} // namespace mxnet
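
With a CUDA build, the same front-end calls dispatch to the GPU kernels registered here; a minimal sketch (assumes at least one visible GPU and that mxnet.numpy array creation accepts a ctx argument, as the rest of the front end does):

import mxnet as mx
from mxnet import np, npx
npx.set_np()

x = np.array([14, 13], dtype='int32', ctx=mx.gpu(0))
print(np.bitwise_and(x, 13))    # scalar path: _npi_bitwise_and_scalar on GPU -> [12 13]
print(np.bitwise_and(x, x))     # tensor path: _npi_bitwise_and on GPU -> [14 13]
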
1 change: 1 addition & 0 deletions src/operator/operator_tune.cc
@@ -359,6 +359,7 @@ IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::logical_xor); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::logical_xor); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::smooth_l1_loss); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::smooth_l1_gradient); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::bitwise_and); // NOLINT()
IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::mxnet_op::set_to_int<0>); // NOLINT()
IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::mxnet_op::set_to_int<1>); // NOLINT()
IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::PopulateFullIdxRspKernel); // NOLINT()
47 changes: 47 additions & 0 deletions tests/python/unittest/test_numpy_op.py
@@ -1080,6 +1080,53 @@ def hybrid_forward(self, F, a, *args):
assert same(mx_out.asnumpy(), np_out)


@with_seed()
@use_np
def test_np_bitwise_and():
class TestBitwiseAnd(HybridBlock):
def __init__(self):
super(TestBitwiseAnd, self).__init__()

def hybrid_forward(self, F, x1, x2):
return F.np.bitwise_and(x1, x2)

shapes = [
((3, 1), (3, 1)),
((3, 1, 2), (3, 1, 2)),
((1, ), (1, )),
((3, 0), (3, 0)), # zero-size shape
((0, 1), (0, 1)), # zero-size shape
((2, 0, 2), (2, 0, 2)), # zero-size shape
((1, ), (3, )), # broadcast
((2, 3), (2, 1)), # broadcast
((1, 3), (2, 3)), # broadcast
((1, 3), (2, 0, 3)), # broadcast to zero-size shape
((1, 0, 1), (3, 0, 1)), # broadcast of zero-size shape
((), ()), # zero-dim shape
]

for hybridize in [True, False]:
for shape in shapes:
x1_shape, x2_shape = shape

test_bitwise_and = TestBitwiseAnd()
if hybridize:
test_bitwise_and.hybridize()

x1 = rand_ndarray(x1_shape, dtype=_np.dtype(int)).as_np_ndarray()
x2 = rand_ndarray(x2_shape, dtype=_np.dtype(int)).as_np_ndarray()
Member (inline review comment): Should we test for all the supported dtypes? I think we support uint8, int8, int32, and int64. (A dtype-loop sketch follows after this test.)

np_out = _np.bitwise_and(x1.asnumpy(), x2.asnumpy())
with mx.autograd.record():
mx_out = test_bitwise_and(x1, x2)
assert mx_out.shape == np_out.shape
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)

mx_out = np.bitwise_and(x1, x2)
np_out = _np.bitwise_and(x1.asnumpy(), x2.asnumpy())
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
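
As a follow-up to the inline review comment above, a sketch of how dtype coverage could be extended to every type ElemwiseIntType accepts; the test name is hypothetical, and rand_ndarray, assert_almost_equal, np and _np are the helpers already imported in this test file:

@with_seed()
@use_np
def test_np_bitwise_and_dtypes():   # hypothetical follow-up test
    for itype in ['uint8', 'int8', 'int32', 'int64']:
        x1 = rand_ndarray((2, 3), dtype=itype).as_np_ndarray()
        x2 = rand_ndarray((2, 1), dtype=itype).as_np_ndarray()   # broadcast along axis 1
        mx_out = np.bitwise_and(x1, x2)
        np_out = _np.bitwise_and(x1.asnumpy(), x2.asnumpy())
        assert mx_out.shape == np_out.shape
        assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
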


if __name__ == '__main__':
import nose
nose.runmodule()