From 7a2e6cd444d6176523668945823a831256b166ab Mon Sep 17 00:00:00 2001 From: zoeygxy Date: Wed, 19 Jun 2019 15:02:13 +0800 Subject: [PATCH 1/5] Numpy compatible bitiwse_and moved from numpy branch --- python/mxnet/_numpy_op_doc.py | 70 +++++++++++++++++++ src/operator/mshadow_op.h | 3 + .../numpy/np_elemwise_broadcast_op.cc | 14 ++++ .../numpy/np_elemwise_broadcast_op.cu | 3 + src/operator/operator_tune.cc | 1 + tests/python/unittest/test_numpy_op.py | 47 +++++++++++++ 6 files changed, 138 insertions(+) diff --git a/python/mxnet/_numpy_op_doc.py b/python/mxnet/_numpy_op_doc.py index 5543ebc8e8c9..26bdad40f52a 100644 --- a/python/mxnet/_numpy_op_doc.py +++ b/python/mxnet/_numpy_op_doc.py @@ -52,3 +52,73 @@ def _np_zeros_like(a): Array of zeros with the same shape and type as `a`. """ pass + + +def _np_bitwise_and(): + """ + Compute the bit-wise AND of two arrays element-wise. + + Computes the bit-wise AND of the underlying binary representation of + the integers in the input arrays. This ufunc implements the C/Python + operator ``&``. + + Parameters + ---------- + x1, x2 : ndarray or boolean + Only integer and boolean types are handled. + out : ndarray, None, or tuple of ndarray and None, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or `None`, + a freshly-allocated array is returned. A tuple (possible only as a + keyword argument) must have length equal to the number of outputs. + where : array_like, optional + Values of True indicate to calculate the ufunc at that position, values + of False indicate to leave the value in the output alone. + **kwargs + For other keyword-only arguments, see the + :ref:`ufunc docs `. + + Returns + ------- + out : ndarray or scalar + Result. + + See Also + -------- + bitwise_or + bitwise_xor + + Notes + ------- + This function differs from the original `numpy.bitwise_and + `_ in + the following aspects: + - Input type does not support Python native iterables(list, tuple, ...). + + + Examples + -------- + The number 13 is represented by ``00001101``. Likewise, 17 is + represented by ``00010001``. The bit-wise AND of 13 and 17 is + therefore ``000000001``, or 1: + + >>> x1 = np.array(13) + >>> x2 = np.array(17) + >>> np.bitwise_and(x1, x2) + array(1.) + + >>> x1 = np.array([True, False, False]) + >>> x2 = np.array([True, True, False]) + >>> np.bitwise_and(x1, x2) + array([1., 0., 0.]) + + Only support ints as input dtype. + + >>> x1 = np.array(13.0) + >>> x2 = np.array(17.0) + >>> np.bitwise_and(x1, x2) + Traceback (most recent call last): + File "", line 1, in + TypeError: ufunc 'bitwise_and' not supported for the input types, and the inputs could not be safely coerced to any supported types according to the casting rule ''safe'' + """ + pass diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h index ab53e7733066..b1f8bd56cbba 100644 --- a/src/operator/mshadow_op.h +++ b/src/operator/mshadow_op.h @@ -147,6 +147,9 @@ MXNET_BINARY_MATH_OP_NC(elu, a > DType(0) ? a : MXNET_BINARY_MATH_OP_NC(elu_grad, a > DType(0) ? DType(1) : DType(b + a)); +MXNET_BINARY_MATH_OP(bitwise_and, + static_cast(static_cast(a) & static_cast(b))); + MXNET_SIMPLE_UNARY_MATH_OP(tanh); MXNET_UNARY_MATH_OP(tanh_grad, 1.0f - math::sqr(a)); diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cc b/src/operator/numpy/np_elemwise_broadcast_op.cc index c36423dff9fd..b8a0b9352f89 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.cc +++ b/src/operator/numpy/np_elemwise_broadcast_op.cc @@ -182,5 +182,19 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rpower_scalar) .set_attr("FCompute", BinaryScalarOp::Compute) .set_attr("FGradient", ElemwiseGradUseOut{"_backward_rpower_scalar"}); +NNVM_REGISTER_OP(_np_bitwise_and) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr("FInferShape", BinaryBroadcastShape) +.set_attr("FInferType", ElemwiseType<2, 1>) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"x1", "x2"}; +}) +.set_attr("FCompute", BinaryBroadcastCompute) +.set_attr("FGradient", MakeZeroGradNodes) +.add_argument("x1", "NDArray-or-Symbol", "Input ndarray") +.add_argument("x2", "NDArray-or-Symbol", "Input ndarray"); + } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cu b/src/operator/numpy/np_elemwise_broadcast_op.cu index c858b3a4987a..ea71f2a70ecf 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.cu +++ b/src/operator/numpy/np_elemwise_broadcast_op.cu @@ -48,6 +48,9 @@ NNVM_REGISTER_OP(_npi_maximum) NNVM_REGISTER_OP(_npi_minimum) .set_attr("FCompute", BinaryBroadcastCompute); +NNVM_REGISTER_OP(_np_bitwise_and) +.set_attr("FCompute", BinaryBroadcastCompute); + NNVM_REGISTER_OP(_npi_add_scalar) .set_attr("FCompute", BinaryScalarOp::Compute); diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc index 98ce14e7bf05..63ed4b11290d 100644 --- a/src/operator/operator_tune.cc +++ b/src/operator/operator_tune.cc @@ -359,6 +359,7 @@ IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::logical_xor); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::logical_xor); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::smooth_l1_loss); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::smooth_l1_gradient); // NOLINT() +IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::bitwise_and); // NOLINT() IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::mxnet_op::set_to_int<0>); // NOLINT() IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::mxnet_op::set_to_int<1>); // NOLINT() IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::PopulateFullIdxRspKernel); // NOLINT() diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 399cdead6177..4af563b21b6f 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -1080,6 +1080,53 @@ def hybrid_forward(self, F, a, *args): assert same(mx_out.asnumpy(), np_out) +@with_seed() +@use_np +def test_np_bitwise_and(): + class TestBitwiseAnd(HybridBlock): + def __init__(self): + super(TestBitwiseAnd, self).__init__() + + def hybrid_forward(self, F, x1, x2): + return F.np.bitwise_and(x1, x2) + + shapes = [ + ((3, 1), (3, 1)), + ((3, 1, 2), (3, 1, 2)), + ((1, ),(1, )), + ((3, 0), (3, 0)), # zero-size shape + ((0, 1), (0, 1)), # zero-size shape + ((2, 0, 2), (2, 0, 2)), # zero-size shape + ((1, ), (3, )), # broadcast + ((2, 3), (2, 1)), # broadcast + ((1, 3), (2, 3)), # broadcast + ((1, 3), (2, 0, 3)), # broadcast to zero-size shape + ((1, 0, 1), (3, 0, 1)), # broadcast of zero-size shape + ((), ()), # zero-dim shape + ] + + for hybridize in [True, False]: + for shape in shapes: + x1_shape, x2_shape = shape + + test_bitwise_and = TestBitwiseAnd() + if hybridize: + test_bitwise_and.hybridize() + + x1 = rand_ndarray(x1_shape, dtype=_np.dtype(int)).as_np_ndarray() + x2 = rand_ndarray(x2_shape, dtype=_np.dtype(int)).as_np_ndarray() + + np_out = _np.bitwise_and(x1.asnumpy(), x2.asnumpy()) + with mx.autograd.record(): + mx_out = test_bitwise_and(x1, x2) + assert mx_out.shape == np_out.shape + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) + + mx_out = np.bitwise_and(x1, x2) + np_out = _np.bitwise_and(x1.asnumpy(), x2.asnumpy()) + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) + + if __name__ == '__main__': import nose nose.runmodule() From d852fead818ba0f3f332fa9ade07761aa13318ef Mon Sep 17 00:00:00 2001 From: zoeygxy Date: Mon, 26 Aug 2019 16:46:55 +0800 Subject: [PATCH 2/5] Refactors to add scalar support --- python/mxnet/_numpy_op_doc.py | 70 ------------------- python/mxnet/ndarray/numpy/_op.py | 70 ++++++++++++++++++- python/mxnet/numpy/multiarray.py | 66 ++++++++++++++++- python/mxnet/symbol/numpy/_symbol.py | 34 +++++++++ src/operator/elemwise_op_common.h | 19 +++++ .../numpy/np_elemwise_broadcast_op.cc | 21 +++++- 6 files changed, 206 insertions(+), 74 deletions(-) diff --git a/python/mxnet/_numpy_op_doc.py b/python/mxnet/_numpy_op_doc.py index 26bdad40f52a..5543ebc8e8c9 100644 --- a/python/mxnet/_numpy_op_doc.py +++ b/python/mxnet/_numpy_op_doc.py @@ -52,73 +52,3 @@ def _np_zeros_like(a): Array of zeros with the same shape and type as `a`. """ pass - - -def _np_bitwise_and(): - """ - Compute the bit-wise AND of two arrays element-wise. - - Computes the bit-wise AND of the underlying binary representation of - the integers in the input arrays. This ufunc implements the C/Python - operator ``&``. - - Parameters - ---------- - x1, x2 : ndarray or boolean - Only integer and boolean types are handled. - out : ndarray, None, or tuple of ndarray and None, optional - A location into which the result is stored. If provided, it must have - a shape that the inputs broadcast to. If not provided or `None`, - a freshly-allocated array is returned. A tuple (possible only as a - keyword argument) must have length equal to the number of outputs. - where : array_like, optional - Values of True indicate to calculate the ufunc at that position, values - of False indicate to leave the value in the output alone. - **kwargs - For other keyword-only arguments, see the - :ref:`ufunc docs `. - - Returns - ------- - out : ndarray or scalar - Result. - - See Also - -------- - bitwise_or - bitwise_xor - - Notes - ------- - This function differs from the original `numpy.bitwise_and - `_ in - the following aspects: - - Input type does not support Python native iterables(list, tuple, ...). - - - Examples - -------- - The number 13 is represented by ``00001101``. Likewise, 17 is - represented by ``00010001``. The bit-wise AND of 13 and 17 is - therefore ``000000001``, or 1: - - >>> x1 = np.array(13) - >>> x2 = np.array(17) - >>> np.bitwise_and(x1, x2) - array(1.) - - >>> x1 = np.array([True, False, False]) - >>> x2 = np.array([True, True, False]) - >>> np.bitwise_and(x1, x2) - array([1., 0., 0.]) - - Only support ints as input dtype. - - >>> x1 = np.array(13.0) - >>> x2 = np.array(17.0) - >>> np.bitwise_and(x1, x2) - Traceback (most recent call last): - File "", line 1, in - TypeError: ufunc 'bitwise_and' not supported for the input types, and the inputs could not be safely coerced to any supported types according to the casting rule ''safe'' - """ - pass diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 04b3b19bcf2e..62d102188c58 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -32,7 +32,8 @@ 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', - 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack'] + 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', + 'bitwise_and'] @set_module('mxnet.ndarray.numpy') @@ -1905,3 +1906,70 @@ def get_list(arrays): arrays = get_list(arrays) return _npi.stack(*arrays, axis=axis, out=out) + + +@set_module('mxnet.ndarray.numpy') +def bitwise_and(x1, x2, out=None): + """ + Compute the bit-wise AND of two arrays element-wise. + + Computes the bit-wise AND of the underlying binary representation of + the integers in the input arrays. This ufunc implements the C/Python + operator ``&``. + + Parameters + ---------- + x1, x2 : ndarray or boolean + Only integer and boolean types are handled. + out : ndarray or None, optional + A location into which the result is stored. + If provided, it must have the same shape and dtype as input ndarray. + If not provided or `None`, a freshly-allocated array is returned. + + Returns + ------- + y : ndarray of integer dtypes, or scalar if both inputs are scalars. + + See Also + -------- + bitwise_or + bitwise_xor + + Notes + ------- + This function differs from the original `numpy.bitwise_and + `_ in + the following aspects: + - Input type currently does not support Python native iterables(list, tuple, ...). + - Input type currently does not support boolean arrays. + + Examples + -------- + The number 13 is represented by ``00001101``. Likewise, 17 is + represented by ``00010001``. The bit-wise AND of 13 and 17 is + therefore ``000000001``, or 1: + >>> np.bitwise_and(13, 17) + 1 + + >>> np.bitwise_and(14, 13) + 12 + + >>> x = np.array([14, 13], dtype=np.int32) + >>> np.bitwise_and(x, 13) + array([12, 13], dtype=int32) + + >>> x = np.array([11,7], dtype=np.int32) + >>> y = np.array([4, 25], dtype=np.int32) + >>> np.bitwise_and(x, y) + array([0, 1], dtype=int32) + + >>> x1 = np.array(13, dtype=np.int32) + >>> x2 = np.array(17, dtype=np.int32) + >>> out = np.array(0, dtype=np.int32) + >>> np.bitwise_and(x1, x2, out) + array(1, dtype=int32) + >>> out + array(1, dtype=int32) + + """ + return _ufunc_helper(x1, x2, _npi.bitwise_and, _np.bitwise_and, _npi.bitwise_and_scalar, None, out) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index a47a9c01b7c4..7662b2fb8d77 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -49,7 +49,7 @@ 'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', - 'stack'] + 'stack', 'bitwise_and'] # This function is copied from ndarray.py since pylint @@ -3086,3 +3086,67 @@ def stack(arrays, axis=0, out=None): stacked : ndarray The stacked array has one more dimension than the input arrays.""" return _mx_nd_np.stack(arrays, axis=axis, out=out) + + +@set_module('mxnet.numpy') +def bitwise_and(x1, x2, out=None): + """ + Computes the bit-wise AND of the underlying binary representation of + the integers in the input arrays. This method implements the C/Python + operator ``&``. + + Parameters + ---------- + x1, x2 : ndarray of integer dtypes, or ints. + out : ndarray or None, optional + A location into which the result is stored. + If provided, it must have the same shape and dtype as input ndarray. + If not provided or `None`, a freshly-allocated array is returned. + + Returns + ------- + y : ndarray of integer dtypes, or scalar if both inputs are scalars. + + See Also + -------- + bitwise_or + bitwise_xor + + Notes + ------- + This function differs from the original `numpy.bitwise_and + `_ in + the following aspects: + - Input type currently does not support Python native iterables(list, tuple, ...). + - Input type currently does not support boolean arrays. + + Examples + -------- + The number 13 is represented by ``00001101``. Likewise, 17 is + represented by ``00010001``. The bit-wise AND of 13 and 17 is + therefore ``000000001``, or 1: + >>> np.bitwise_and(13, 17) + 1 + + >>> np.bitwise_and(14, 13) + 12 + + >>> x = np.array([14, 13], dtype=np.int32) + >>> np.bitwise_and(x, 13) + array([12, 13], dtype=int32) + + >>> x = np.array([11,7], dtype=np.int32) + >>> y = np.array([4, 25], dtype=np.int32) + >>> np.bitwise_and(x, y) + array([0, 1], dtype=int32) + + >>> x1 = np.array(13, dtype=np.int32) + >>> x2 = np.array(17, dtype=np.int32) + >>> out = np.array(0, dtype=np.int32) + >>> np.bitwise_and(x1, x2, out) + array(1, dtype=int32) + >>> out + array(1, dtype=int32) + + """ + return _mx_nd_np.bitwise_and(x1, x2, out) diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index 251a8a1b8e56..f60037595245 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -2328,4 +2328,38 @@ def get_list(arrays): return _npi.stack(*arrays, axis=axis, out=out) +@set_module('mxnet.symbol.numpy') +def bitwise_and(x1, x2, out=None): + """ + Computes the bit-wise AND of the underlying binary representation of + the integers in the input arrays. This method implements the C/Python + operator ``&``. + + Parameters + ---------- + x1, x2 : _Symbol + out : _Symbol or None, optional + Dummy parameter to keep the consistency with the ndarray counterpart. + + Returns + ------- + y : _Symbol of integer dtypes + + See Also + -------- + bitwise_or + bitwise_xor + + Notes + ------- + This function differs from the original `numpy.bitwise_and + `_ in + the following aspects: + - Input type currently does not support Python native iterables(list, tuple, ...). + - Input type currently does not support boolean arrays. + + """ + return _ufunc_helper(x1, x2, _npi.bitwise_and, _np.bitwise_and, _npi.bitwise_and_scalar, None, out) + + _set_np_symbol_class(_Symbol) diff --git a/src/operator/elemwise_op_common.h b/src/operator/elemwise_op_common.h index 6dae2dfa20c4..c8e94fbe57b2 100644 --- a/src/operator/elemwise_op_common.h +++ b/src/operator/elemwise_op_common.h @@ -186,6 +186,25 @@ inline bool ElemwiseType(const nnvm::NodeAttrs& attrs, attrs, in_attrs, out_attrs, -1); } +// Special case of ElemwiseType. Constrains dtype to integer types +template +inline bool ElemwiseIntType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK(in_attrs->at(0) == mshadow::kInt64 || + in_attrs->at(0) == mshadow::kInt32 || + in_attrs->at(0) == mshadow::kInt8 || + in_attrs->at(0) == mshadow::kUint8) << "Only supports integer types."; + if (n_in != -1) { + CHECK_EQ(in_attrs->size(), static_cast(n_in)) << " in operator " << attrs.name; + } + if (n_out != -1) { + CHECK_EQ(out_attrs->size(), static_cast(n_out)) << " in operator " << attrs.name; + } + return ElemwiseAttr( + attrs, in_attrs, out_attrs, -1); +} + // Transfer gradient and input to FGradient function struct ElemwiseGradUseIn { const char *op_name; diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cc b/src/operator/numpy/np_elemwise_broadcast_op.cc index b8a0b9352f89..0d22c0803557 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.cc +++ b/src/operator/numpy/np_elemwise_broadcast_op.cc @@ -182,11 +182,11 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rpower_scalar) .set_attr("FCompute", BinaryScalarOp::Compute) .set_attr("FGradient", ElemwiseGradUseOut{"_backward_rpower_scalar"}); -NNVM_REGISTER_OP(_np_bitwise_and) +NNVM_REGISTER_OP(_npi_bitwise_and) .set_num_inputs(2) .set_num_outputs(1) .set_attr("FInferShape", BinaryBroadcastShape) -.set_attr("FInferType", ElemwiseType<2, 1>) +.set_attr("FInferType", ElemwiseIntType<2, 1>) //TODO(reminisce):boolean support .set_attr("FListInputNames", [](const NodeAttrs& attrs) { return std::vector{"x1", "x2"}; @@ -196,5 +196,22 @@ NNVM_REGISTER_OP(_np_bitwise_and) .add_argument("x1", "NDArray-or-Symbol", "Input ndarray") .add_argument("x2", "NDArray-or-Symbol", "Input ndarray"); +NNVM_REGISTER_OP(_npi_bitwise_and_scalar) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr_parser([](NodeAttrs* attrs) { + attrs->parsed = std::stod(attrs->dict["scalar"]); + }) +.set_attr("FInferShape", ElemwiseShape<1, 1>) +.set_attr("FInferType", ElemwiseIntType<1, 1>) //TODO(reminisce):boolean support +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"x1", "x2"}; +}) +.set_attr("FCompute", BinaryScalarOp::Compute) +.set_attr("FGradient", MakeZeroGradNodes) +.add_argument("data", "NDArray-or-Symbol", "source input") +.add_argument("scalar", "int", "scalar input"); + } // namespace op } // namespace mxnet From 33aab4f7b0ef7eb62f364d921a0f1acf7ad2c7d9 Mon Sep 17 00:00:00 2001 From: zoeygxy Date: Mon, 26 Aug 2019 16:53:00 +0800 Subject: [PATCH 3/5] style fixed --- src/operator/numpy/np_elemwise_broadcast_op.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cc b/src/operator/numpy/np_elemwise_broadcast_op.cc index 0d22c0803557..0e3f5f6e228b 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.cc +++ b/src/operator/numpy/np_elemwise_broadcast_op.cc @@ -186,7 +186,7 @@ NNVM_REGISTER_OP(_npi_bitwise_and) .set_num_inputs(2) .set_num_outputs(1) .set_attr("FInferShape", BinaryBroadcastShape) -.set_attr("FInferType", ElemwiseIntType<2, 1>) //TODO(reminisce):boolean support +.set_attr("FInferType", ElemwiseIntType<2, 1>) // TODO(reminisce) boolean support .set_attr("FListInputNames", [](const NodeAttrs& attrs) { return std::vector{"x1", "x2"}; @@ -203,7 +203,7 @@ NNVM_REGISTER_OP(_npi_bitwise_and_scalar) attrs->parsed = std::stod(attrs->dict["scalar"]); }) .set_attr("FInferShape", ElemwiseShape<1, 1>) -.set_attr("FInferType", ElemwiseIntType<1, 1>) //TODO(reminisce):boolean support +.set_attr("FInferType", ElemwiseIntType<1, 1>) // TODO(reminisce) boolean support .set_attr("FListInputNames", [](const NodeAttrs& attrs) { return std::vector{"x1", "x2"}; From a2782e798a5ca1cb9c6c867e5f31fa2e158048d9 Mon Sep 17 00:00:00 2001 From: zoeygxy Date: Mon, 26 Aug 2019 19:45:33 +0800 Subject: [PATCH 4/5] fix symbol calc --- python/mxnet/symbol/numpy/_symbol.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index f60037595245..df2b5ea25508 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -34,7 +34,8 @@ 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', - 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack'] + 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', + 'bitwise_and'] def _num_outputs(sym): From 00769312246ebb18bb92a657a3a7fdf88d0eafb7 Mon Sep 17 00:00:00 2001 From: Zoey Xinyi Ge Date: Tue, 27 Aug 2019 00:54:25 +0800 Subject: [PATCH 5/5] Add gpu implementation --- src/operator/numpy/np_elemwise_broadcast_op.cu | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cu b/src/operator/numpy/np_elemwise_broadcast_op.cu index ea71f2a70ecf..2b1edb481100 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.cu +++ b/src/operator/numpy/np_elemwise_broadcast_op.cu @@ -48,7 +48,7 @@ NNVM_REGISTER_OP(_npi_maximum) NNVM_REGISTER_OP(_npi_minimum) .set_attr("FCompute", BinaryBroadcastCompute); -NNVM_REGISTER_OP(_np_bitwise_and) +NNVM_REGISTER_OP(_npi_bitwise_and) .set_attr("FCompute", BinaryBroadcastCompute); NNVM_REGISTER_OP(_npi_add_scalar) @@ -81,5 +81,8 @@ NNVM_REGISTER_OP(_npi_maximum_scalar) NNVM_REGISTER_OP(_npi_minimum_scalar) .set_attr("FCompute", BinaryScalarOp::Compute); +NNVM_REGISTER_OP(_npi_bitwise_and_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute); + } // namespace op } // namespace mxnet