diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index f9164855dfe9..70cf3f6d2333 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -36,7 +36,7 @@ 'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate', 'append', 'stack', 'vstack', 'column_stack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', - 'around', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', + 'around', 'hypot', 'bitwise_xor', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory', 'may_share_memory', 'diff', 'resize', 'nan_to_num'] @@ -4291,6 +4291,44 @@ def hypot(x1, x2, out=None, **kwargs): return _ufunc_helper(x1, x2, _npi.hypot, _np.hypot, _npi.hypot_scalar, None, out) +@set_module('mxnet.ndarray.numpy') +@wrap_np_binary_func +def bitwise_xor(x1, x2, out=None, **kwargs): + r""" + Compute the bit-wise XOR of two arrays element-wise. + + Parameters + ---------- + x1, x2 : ndarray or scalar + Only integer and boolean types are handled. If x1.shape != x2.shape, + they must be broadcastable to a common shape (which becomes the shape of the output). + out : ndarray, optional + A location into which the result is stored. If provided, it must have a shape that the + inputs broadcast to. If not provided or None, a freshly-allocated array is returned. + + Returns + ------- + out : ndarray + Result. + + Examples + -------- + >>> np.bitwise_xor(13, 17) + 28 + + >>> np.bitwise_xor(31, 5) + 26 + >>> np.bitwise_xor(np.array([31,3], dtype='int32'), 5) + array([26, 6]) + + >>> np.bitwise_xor(np.array([31,3], dtype='int32'), np.array([5,6], dtype='int32')) + array([26, 5]) + >>> np.bitwise_xor(np.array([True, True], dtype='bool'), np.array([False, True], dtype='bool')) + array([ True, False]) + """ + return _ufunc_helper(x1, x2, _npi.bitwise_xor, _np.bitwise_xor, _npi.bitwise_xor_scalar, None, out) + + @set_module('mxnet.ndarray.numpy') @wrap_np_binary_func def ldexp(x1, x2, out=None, **kwargs): diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 1df1a0360913..09ff22e66702 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -54,8 +54,8 @@ 'tensordot', 'histogram', 'eye', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate', 'stack', 'vstack', 'column_stack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', - 'blackman', 'flip', 'around', 'arctan2', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', - 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', + 'blackman', 'flip', 'around', 'arctan2', 'hypot', 'bitwise_xor', 'rad2deg', 'deg2rad', 'unique', 'lcm', + 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory', 'may_share_memory', 'diff', 'resize', 'nan_to_num'] @@ -6198,6 +6198,44 @@ def hypot(x1, x2, out=None, **kwargs): return _mx_nd_np.hypot(x1, x2, out=out) +@set_module('mxnet.numpy') +@wrap_np_binary_func +def bitwise_xor(x1, x2, out=None, **kwargs): + r""" + Compute the bit-wise XOR of two arrays element-wise. + + Parameters + ---------- + x1, x2 : ndarray or scalar + Only integer and boolean types are handled. If x1.shape != x2.shape, + they must be broadcastable to a common shape (which becomes the shape of the output). + out : ndarray, optional + A location into which the result is stored. If provided, it must have a shape that the + inputs broadcast to. If not provided or None, a freshly-allocated array is returned. + + Returns + ------- + out : ndarray + Result. + + Examples + -------- + >>> np.bitwise_xor(13, 17) + 28 + + >>> np.bitwise_xor(31, 5) + 26 + >>> np.bitwise_xor(np.array([31,3], dtype=np.int32), 5) + array([26, 6]) + + >>> np.bitwise_xor(np.array([31,3], dtype='int32'), np.array([5,6], dtype='int32')) + array([26, 5]) + >>> np.bitwise_xor(np.array([True, True], dtype='bool'), np.array([False, True], dtype='bool')) + array([ True, False]) + """ + return _mx_nd_np.bitwise_xor(x1, x2, out=out) + + @set_module('mxnet.numpy') @wrap_np_binary_func def ldexp(x1, x2, out=None, **kwargs): diff --git a/python/mxnet/numpy_dispatch_protocol.py b/python/mxnet/numpy_dispatch_protocol.py index 025982cfc7a5..694d1d40143b 100644 --- a/python/mxnet/numpy_dispatch_protocol.py +++ b/python/mxnet/numpy_dispatch_protocol.py @@ -225,6 +225,7 @@ def _register_array_function(): 'ceil', 'trunc', 'floor', + 'bitwise_xor', 'logical_not', 'equal', 'not_equal', diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index 0d7303865b92..0a42a620c14e 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -38,7 +38,7 @@ 'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate', 'append', 'stack', 'vstack', 'column_stack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', - 'around', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', + 'around', 'hypot', 'bitwise_xor', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'shares_memory', 'may_share_memory', 'diff', 'resize', 'nan_to_num'] @@ -4058,17 +4058,16 @@ def hypot(x1, x2, out=None, **kwargs): Parameters ---------- - x1, x2 : array_like + x1, x2 : _Symbol or scalar Leg of the triangle(s). - out : ndarray, None, or tuple of ndarray and None, optional + out : _Symbol or None, optional A location into which the result is stored. If provided, it must have a shape that the inputs broadcast to. If not provided or `None`, - a freshly-allocated array is returned. A tuple (possible only as a - keyword argument) must have length equal to the number of outputs. + a freshly-allocated array is returned. Returns ------- - z : ndarray + z : _Symbol or scalar The hypotenuse of the triangle(s). This is a scalar if both `x1` and `x2` are scalars. @@ -4080,6 +4079,30 @@ def hypot(x1, x2, out=None, **kwargs): return _ufunc_helper(x1, x2, _npi.hypot, _np.hypot, _npi.hypot_scalar, None, out) +@set_module('mxnet.symbol.numpy') +@wrap_np_binary_func +def bitwise_xor(x1, x2, out=None, **kwargs): + r""" + Compute the bit-wise XOR of two arrays element-wise. + + Parameters + ---------- + x1, x2 : _Symbol or scalar + Only integer and boolean types are handled. If x1.shape != x2.shape, + they must be broadcastable to a common shape (which becomes the shape of the output). + out : _Symbol or None, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or `None`, + a freshly-allocated array is returned. + + Returns + ------- + out : _Symbol or scalar + Result. + """ + return _ufunc_helper(x1, x2, _npi.bitwise_xor, _np.bitwise_xor, _npi.bitwise_xor_scalar, None, out) + + @set_module('mxnet.symbol.numpy') def unique(ar, return_index=False, return_inverse=False, return_counts=False, axis=None): """ diff --git a/src/operator/elemwise_op_common.h b/src/operator/elemwise_op_common.h index 6711297718b2..2cdd73a95801 100644 --- a/src/operator/elemwise_op_common.h +++ b/src/operator/elemwise_op_common.h @@ -209,7 +209,8 @@ inline bool ElemwiseIntType(const nnvm::NodeAttrs& attrs, CHECK(in_attrs->at(0) == mshadow::kInt64 || in_attrs->at(0) == mshadow::kInt32 || in_attrs->at(0) == mshadow::kInt8 || - in_attrs->at(0) == mshadow::kUint8) << "Only supports integer types."; + in_attrs->at(0) == mshadow::kUint8 || + in_attrs->at(0) == mshadow::kBool) << "Only supports integer types."; if (n_in != -1) { CHECK_EQ(in_attrs->size(), static_cast(n_in)) << " in operator " << attrs.name; } diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cc b/src/operator/numpy/np_elemwise_broadcast_op.cc index a76e59d30dc6..4e4734c143a4 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.cc +++ b/src/operator/numpy/np_elemwise_broadcast_op.cc @@ -28,16 +28,6 @@ namespace mxnet { namespace op { -bool NumpyBinaryScalarType(const nnvm::NodeAttrs& attrs, - std::vector* in_attrs, - std::vector* out_attrs) { - CHECK_EQ(in_attrs->size(), 1U); - CHECK_EQ(out_attrs->size(), 1U); - TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); - TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); - return in_attrs->at(0) != -1; -} - #define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(name) \ NNVM_REGISTER_OP(name) \ .set_num_inputs(1) \ @@ -156,62 +146,6 @@ MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_mod) MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_power) .set_attr("FCompute", BinaryBroadcastCompute) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_broadcast_power"}); - -MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_copysign) -.describe(R"code()code" ADD_FILELINE) -.set_attr("FCompute", BinaryBroadcastCompute) -.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_copysign"}); - -NNVM_REGISTER_OP(_backward_npi_copysign) -.set_num_inputs(3) -.set_num_outputs(2) -.set_attr("TIsBackward", true) -.set_attr("FInplaceOption", - [](const NodeAttrs& attrs){ - return std::vector >{{0, 1}}; - }) -.set_attr("FResourceRequest", - [](const NodeAttrs& attrs) { - return std::vector{ResourceRequest::kTempSpace}; - }) -.set_attr("FCompute", BinaryBroadcastBackwardUseIn); - -NNVM_REGISTER_OP(_npi_lcm) -.set_num_inputs(2) -.set_num_outputs(1) -.set_attr("FListInputNames", -[](const NodeAttrs& attrs) { - return std::vector{"lhs", "rhs"}; -}) -.set_attr("FInferShape", BinaryBroadcastShape) -.set_attr("FInferType", ElemwiseIntType<2, 1>) -.set_attr("FInplaceOption", -[](const NodeAttrs& attrs){ - return std::vector >{{0, 0}, {1, 0}}; -}) -.set_attr("FGradient", MakeZeroGradNodes) -.set_attr("FCompute", BinaryBroadcastCompute) -.add_argument("lhs", "NDArray-or-Symbol", "First input to the function") -.add_argument("rhs", "NDArray-or-Symbol", "Second input to the function"); - -NNVM_REGISTER_OP(_npi_lcm_scalar) -.set_num_inputs(1) -.set_num_outputs(1) -.set_attr_parser([](NodeAttrs* attrs) { - attrs->parsed = std::stod(attrs->dict["scalar"]); - }) -.set_attr("FInferShape", ElemwiseShape<1, 1>) -.set_attr("FInferType", ElemwiseIntType<1, 1>) -.set_attr("FInplaceOption", - [](const NodeAttrs& attrs){ - return std::vector >{{0, 0}}; - }) -.set_attr("FGradient", MakeZeroGradNodes) -.add_argument("data", "NDArray-or-Symbol", "source input") -.add_argument("scalar", "int", "scalar input") -.set_attr("FCompute", BinaryScalarOp::Compute); - MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_add_scalar) .set_attr("FCompute", BinaryScalarOp::Compute) .set_attr("FGradient", ElemwiseGradUseNone{"_copy"}); @@ -244,177 +178,5 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rpower_scalar) .set_attr("FCompute", BinaryScalarOp::Compute) .set_attr("FGradient", ElemwiseGradUseOut{"_backward_rpower_scalar"}); -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_copysign_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute) -.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_copysign_scalar"}); - -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rcopysign_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute) -.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_rcopysign_scalar"}); - -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_backward_npi_copysign_scalar) -.set_attr("FCompute", - BinaryScalarOp::Backward); - -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_backward_npi_rcopysign_scalar) -.set_attr("FCompute", - BinaryScalarOp::Backward); - -inline bool IsFloatType(const int dtype) { - return (dtype == mshadow::kFloat16 || - dtype == mshadow::kFloat32 || - dtype == mshadow::kFloat64); -} - -inline bool Arctan2OpType(const nnvm::NodeAttrs& attrs, - std::vector* in_attrs, - std::vector* out_attrs) { - CHECK_EQ(in_attrs->size(), 2U); - CHECK_EQ(out_attrs->size(), 1U); - - TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); - TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(1)); - TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); - TYPE_ASSIGN_CHECK(*in_attrs, 1, out_attrs->at(0)); - // check if it is float16, float32 or float64. If not, raise error. - CHECK(IsFloatType(in_attrs->at(0))) << "Do not support `int` as input.\n"; - return out_attrs->at(0) != -1; -} - -NNVM_REGISTER_OP(_npi_arctan2) -.set_num_inputs(2) -.set_num_outputs(1) -.set_attr("FListInputNames", - [](const NodeAttrs& attrs) { - return std::vector{"x1", "x2"}; - }) -.set_attr("FInferShape", BinaryBroadcastShape) -.set_attr("FInferType", Arctan2OpType) -.set_attr("FCompute", BinaryBroadcastCompute) -.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_arctan2"}) -.set_attr("FInplaceOption", - [](const NodeAttrs& attrs) { - return std::vector >{{0, 0}}; - }) -.add_argument("x1", "NDArray-or-Symbol", "The input array") -.add_argument("x2", "NDArray-or-Symbol", "The input array"); - -NNVM_REGISTER_OP(_backward_npi_arctan2) -.set_num_inputs(3) -.set_num_outputs(2) -.set_attr("TIsBackward", true) -.set_attr("FResourceRequest", - [](const NodeAttrs& attrs) { - return std::vector{ResourceRequest::kTempSpace}; - }) -.set_attr("FCompute", BinaryBroadcastBackwardUseIn); - -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_arctan2_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute) -.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_arctan2_scalar"}); - -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rarctan2_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute) -.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_rarctan2_scalar"}); - -MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_arctan2_scalar) -.add_argument("scalar", "float", "scalar value") -.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); }) -.set_attr("FCompute", - BinaryScalarOp::Backward); - -MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_rarctan2_scalar) -.add_argument("scalar", "float", "scalar value") -.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); }) -.set_attr("FCompute", - BinaryScalarOp::Backward); - -bool HypotOpType(const nnvm::NodeAttrs& attrs, - std::vector* in_attrs, - std::vector* out_attrs) { - CHECK_EQ(in_attrs->size(), 2U); - CHECK_EQ(out_attrs->size(), 1U); - - TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); - TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(1)); - TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); - TYPE_ASSIGN_CHECK(*in_attrs, 1, out_attrs->at(0)); - - CHECK(IsFloatType(in_attrs->at(0))) << "Do not support `int` as input.\n"; - return out_attrs->at(0) != -1; -} - -// rigister hypot that do not support int here -NNVM_REGISTER_OP(_npi_hypot) -.set_num_inputs(2) -.set_num_outputs(1) -.set_attr("FListInputNames", - [](const NodeAttrs& attrs) { - return std::vector{"x1", "x2"}; - }) -.set_attr("FInferShape", BinaryBroadcastShape) -.set_attr("FInferType", HypotOpType) -.set_attr("FCompute", BinaryBroadcastCompute) -.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_hypot"}) -.set_attr("FInplaceOption", - [](const NodeAttrs& attrs) { - return std::vector >{{0, 0}, {1, 0}}; - }) -.add_argument("x1", "NDArray-or-Symbol", "The input array") -.add_argument("x2", "NDArray-or-Symbol", "The input array"); - -NNVM_REGISTER_OP(_backward_npi_hypot) -.set_num_inputs(3) -.set_num_outputs(2) -.set_attr("TIsBackward", true) -.set_attr("FInplaceOption", - [](const NodeAttrs& attrs) { - return std::vector > {{0, 1}}; - }) -.set_attr("FResourceRequest", - [](const NodeAttrs& attrs) { - return std::vector{ResourceRequest::kTempSpace}; - }) -.set_attr("FCompute", BinaryBroadcastBackwardUseIn); - -MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_ldexp) -.set_attr("FCompute", BinaryBroadcastCompute) -.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_ldexp"}); - -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_ldexp_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute) -.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_ldexp_scalar"}); - -MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rldexp_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute) -.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_rldexp_scalar"}); - -NNVM_REGISTER_OP(_backward_npi_ldexp) -.set_num_inputs(3) -.set_num_outputs(2) -.set_attr("TIsBackward", true) -.set_attr("FInplaceOption", - [](const NodeAttrs& attrs){ - return std::vector >{{0, 1}}; - }) -.set_attr("FResourceRequest", - [](const NodeAttrs& attrs) { - return std::vector{ResourceRequest::kTempSpace}; - }) -.set_attr("FCompute", BinaryBroadcastBackwardUseIn); - -MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_ldexp_scalar) -.add_argument("scalar", "float", "scalar value") -.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); }) -.set_attr("FCompute", BinaryScalarOp::Backward); - -MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_rldexp_scalar) -.add_argument("scalar", "float", "scalar value") -.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); }) -.set_attr("FCompute", BinaryScalarOp::Backward); - } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cu b/src/operator/numpy/np_elemwise_broadcast_op.cu index a0a277df211f..5c9dc97377cf 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.cu +++ b/src/operator/numpy/np_elemwise_broadcast_op.cu @@ -70,29 +70,6 @@ NNVM_REGISTER_OP(_npi_mod) NNVM_REGISTER_OP(_npi_power) .set_attr("FCompute", BinaryBroadcastCompute); -NNVM_REGISTER_OP(_npi_copysign) -.set_attr("FCompute", BinaryBroadcastCompute); - -NNVM_REGISTER_OP(_npi_lcm) -.set_attr("FCompute", BinaryBroadcastCompute); - -NNVM_REGISTER_OP(_backward_npi_copysign) -.set_attr("FCompute", BinaryBroadcastBackwardUseIn); - -NNVM_REGISTER_OP(_npi_arctan2) -.set_attr("FCompute", BinaryBroadcastCompute); - -NNVM_REGISTER_OP(_backward_npi_arctan2) -.set_attr("FCompute", BinaryBroadcastBackwardUseIn); -NNVM_REGISTER_OP(_npi_hypot) -.set_attr("FCompute", BinaryBroadcastCompute); - -NNVM_REGISTER_OP(_backward_npi_hypot) -.set_attr("FCompute", BinaryBroadcastBackwardUseIn); - NNVM_REGISTER_OP(_npi_add_scalar) .set_attr("FCompute", BinaryScalarOp::Compute); @@ -117,53 +94,5 @@ NNVM_REGISTER_OP(_npi_power_scalar) NNVM_REGISTER_OP(_npi_rpower_scalar) .set_attr("FCompute", BinaryScalarOp::Compute); -NNVM_REGISTER_OP(_npi_copysign_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); - -NNVM_REGISTER_OP(_npi_rcopysign_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); - -NNVM_REGISTER_OP(_backward_npi_copysign_scalar) -.set_attr("FCompute", - BinaryScalarOp::Backward); - -NNVM_REGISTER_OP(_backward_npi_rcopysign_scalar) -.set_attr("FCompute", - BinaryScalarOp::Backward); - -NNVM_REGISTER_OP(_npi_arctan2_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); - -NNVM_REGISTER_OP(_backward_npi_arctan2_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); - -NNVM_REGISTER_OP(_npi_rarctan2_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); - -NNVM_REGISTER_OP(_backward_npi_rarctan2_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); - -NNVM_REGISTER_OP(_npi_lcm_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); - -NNVM_REGISTER_OP(_npi_ldexp) -.set_attr("FCompute", BinaryBroadcastCompute); - -NNVM_REGISTER_OP(_npi_ldexp_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); - -NNVM_REGISTER_OP(_npi_rldexp_scalar) -.set_attr("FCompute", BinaryScalarOp::Compute); - -NNVM_REGISTER_OP(_backward_npi_ldexp) -.set_attr("FCompute", BinaryBroadcastBackwardUseIn); - -NNVM_REGISTER_OP(_backward_npi_ldexp_scalar) -.set_attr("FCompute", BinaryScalarOp::Backward); - -NNVM_REGISTER_OP(_backward_npi_rldexp_scalar) -.set_attr("FCompute", BinaryScalarOp::Backward); - } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_elemwise_broadcast_op.h b/src/operator/numpy/np_elemwise_broadcast_op.h index 1a4596fba91c..2ad8b88bdb1c 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.h +++ b/src/operator/numpy/np_elemwise_broadcast_op.h @@ -34,6 +34,16 @@ namespace mxnet { namespace op { +inline bool NumpyBinaryScalarType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); + TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); + return in_attrs->at(0) != -1; +} + inline void PrintErrorMessage(const std::string& op_name, const int dtype1, const int dtype2) { LOG(FATAL) << "Operator " << op_name << " does not support combination of " << common::dtype_string(dtype1) << " with " << common::dtype_string(dtype2) diff --git a/src/operator/numpy/np_elemwise_broadcast_op_extended.cc b/src/operator/numpy/np_elemwise_broadcast_op_extended.cc new file mode 100644 index 000000000000..84c47e597883 --- /dev/null +++ b/src/operator/numpy/np_elemwise_broadcast_op_extended.cc @@ -0,0 +1,305 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_elemwise_binary_op_extended.cc + * \brief CPU Implementation of extended functions for elementwise numpy binary broadcast operator. + */ + +#include "../../common/utils.h" +#include "./np_elemwise_broadcast_op.h" + +namespace mxnet { +namespace op { + +#define MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(name) \ + NNVM_REGISTER_OP(name) \ + .set_num_inputs(1) \ + .set_num_outputs(1) \ + .set_attr_parser([](NodeAttrs* attrs) { \ + attrs->parsed = std::stod(attrs->dict["scalar"]); \ + }) \ + .set_attr("FInferShape", ElemwiseShape<1, 1>) \ + .set_attr("FInferType", NumpyBinaryScalarType) \ + .set_attr("FInplaceOption", \ + [](const NodeAttrs& attrs){ \ + return std::vector >{{0, 0}}; \ + }) \ + .add_argument("data", "NDArray-or-Symbol", "source input") \ + .add_argument("scalar", "float", "scalar input") + +MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_copysign) +.describe(R"code()code" ADD_FILELINE) +.set_attr("FCompute", BinaryBroadcastCompute) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_copysign"}); + +NNVM_REGISTER_OP(_backward_npi_copysign) +.set_num_inputs(3) +.set_num_outputs(2) +.set_attr("TIsBackward", true) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs){ + return std::vector >{{0, 1}}; + }) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", BinaryBroadcastBackwardUseIn); + +NNVM_REGISTER_OP(_npi_lcm) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr("FListInputNames", +[](const NodeAttrs& attrs) { + return std::vector{"lhs", "rhs"}; +}) +.set_attr("FInferShape", BinaryBroadcastShape) +.set_attr("FInferType", ElemwiseIntType<2, 1>) +.set_attr("FInplaceOption", +[](const NodeAttrs& attrs){ + return std::vector >{{0, 0}, {1, 0}}; +}) +.set_attr("FGradient", MakeZeroGradNodes) +.set_attr("FCompute", BinaryBroadcastIntCompute) +.add_argument("lhs", "NDArray-or-Symbol", "First input to the function") +.add_argument("rhs", "NDArray-or-Symbol", "Second input to the function"); + +NNVM_REGISTER_OP(_npi_lcm_scalar) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr_parser([](NodeAttrs* attrs) { + attrs->parsed = std::stod(attrs->dict["scalar"]); + }) +.set_attr("FInferShape", ElemwiseShape<1, 1>) +.set_attr("FInferType", ElemwiseIntType<1, 1>) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs){ + return std::vector >{{0, 0}}; + }) +.set_attr("FGradient", MakeZeroGradNodes) +.add_argument("data", "NDArray-or-Symbol", "source input") +.add_argument("scalar", "int", "scalar input") +.set_attr("FCompute", BinaryScalarOp::Compute); + +NNVM_REGISTER_OP(_npi_bitwise_xor) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr("FListInputNames", +[](const NodeAttrs& attrs) { + return std::vector{"lhs", "rhs"}; +}) +.set_attr("FInferShape", BinaryBroadcastShape) +.set_attr("FInferType", ElemwiseIntType<2, 1>) +.set_attr("FInplaceOption", +[](const NodeAttrs& attrs){ + return std::vector >{{0, 0}, {1, 0}}; +}) +.set_attr("FGradient", MakeZeroGradNodes) +.set_attr("FCompute", BinaryBroadcastIntCompute) +.add_argument("lhs", "NDArray-or-Symbol", "First input to the function") +.add_argument("rhs", "NDArray-or-Symbol", "Second input to the function"); + +NNVM_REGISTER_OP(_npi_bitwise_xor_scalar) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr_parser([](NodeAttrs* attrs) { + attrs->parsed = std::stod(attrs->dict["scalar"]); + }) +.set_attr("FInferShape", ElemwiseShape<1, 1>) +.set_attr("FInferType", ElemwiseIntType<1, 1>) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs){ + return std::vector >{{0, 0}}; + }) +.set_attr("FGradient", MakeZeroGradNodes) +.add_argument("data", "NDArray-or-Symbol", "source input") +.add_argument("scalar", "int", "scalar input") +.set_attr("FCompute", BinaryScalarOp::ComputeInt); + +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_copysign_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_copysign_scalar"}); + +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rcopysign_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_rcopysign_scalar"}); + +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_backward_npi_copysign_scalar) +.set_attr("FCompute", + BinaryScalarOp::Backward); + +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_backward_npi_rcopysign_scalar) +.set_attr("FCompute", + BinaryScalarOp::Backward); + +inline bool Arctan2OpType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + + TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); + TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(1)); + TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); + TYPE_ASSIGN_CHECK(*in_attrs, 1, out_attrs->at(0)); + // check if it is float16, float32 or float64. If not, raise error. + CHECK(common::is_float(in_attrs->at(0))) << "Do not support `int` as input.\n"; + return out_attrs->at(0) != -1; +} + +NNVM_REGISTER_OP(_npi_arctan2) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"x1", "x2"}; + }) +.set_attr("FInferShape", BinaryBroadcastShape) +.set_attr("FInferType", Arctan2OpType) +.set_attr("FCompute", BinaryBroadcastCompute) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_arctan2"}) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs) { + return std::vector >{{0, 0}}; + }) +.add_argument("x1", "NDArray-or-Symbol", "The input array") +.add_argument("x2", "NDArray-or-Symbol", "The input array"); + +NNVM_REGISTER_OP(_backward_npi_arctan2) +.set_num_inputs(3) +.set_num_outputs(2) +.set_attr("TIsBackward", true) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", BinaryBroadcastBackwardUseIn); + +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_arctan2_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_arctan2_scalar"}); + +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rarctan2_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_rarctan2_scalar"}); + +MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_arctan2_scalar) +.add_argument("scalar", "float", "scalar value") +.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); }) +.set_attr("FCompute", + BinaryScalarOp::Backward); + +MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_rarctan2_scalar) +.add_argument("scalar", "float", "scalar value") +.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); }) +.set_attr("FCompute", + BinaryScalarOp::Backward); + +bool HypotOpType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + + TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); + TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(1)); + TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); + TYPE_ASSIGN_CHECK(*in_attrs, 1, out_attrs->at(0)); + + CHECK(common::is_float(in_attrs->at(0))) << "Do not support `int` as input.\n"; + return out_attrs->at(0) != -1; +} + +// rigister hypot that do not support int here +NNVM_REGISTER_OP(_npi_hypot) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"x1", "x2"}; + }) +.set_attr("FInferShape", BinaryBroadcastShape) +.set_attr("FInferType", HypotOpType) +.set_attr("FCompute", BinaryBroadcastCompute) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_hypot"}) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs) { + return std::vector >{{0, 0}, {1, 0}}; + }) +.add_argument("x1", "NDArray-or-Symbol", "The input array") +.add_argument("x2", "NDArray-or-Symbol", "The input array"); + +NNVM_REGISTER_OP(_backward_npi_hypot) +.set_num_inputs(3) +.set_num_outputs(2) +.set_attr("TIsBackward", true) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs) { + return std::vector > {{0, 1}}; + }) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", BinaryBroadcastBackwardUseIn); + +MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_ldexp) +.set_attr("FCompute", BinaryBroadcastCompute) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_ldexp"}); + +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_ldexp_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_ldexp_scalar"}); + +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rldexp_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_rldexp_scalar"}); + +NNVM_REGISTER_OP(_backward_npi_ldexp) +.set_num_inputs(3) +.set_num_outputs(2) +.set_attr("TIsBackward", true) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs){ + return std::vector >{{0, 1}}; + }) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", BinaryBroadcastBackwardUseIn); + +MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_ldexp_scalar) +.add_argument("scalar", "float", "scalar value") +.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); }) +.set_attr("FCompute", BinaryScalarOp::Backward); + +MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_rldexp_scalar) +.add_argument("scalar", "float", "scalar value") +.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); }) +.set_attr("FCompute", BinaryScalarOp::Backward); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/np_elemwise_broadcast_op_extended.cu b/src/operator/numpy/np_elemwise_broadcast_op_extended.cu new file mode 100644 index 000000000000..f858fb4a4e79 --- /dev/null +++ b/src/operator/numpy/np_elemwise_broadcast_op_extended.cu @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_elemwise_broadcast_op_extended.cu + * \brief GPU Implementation of extended functions for elementwise binary broadcast operator. + */ + +#include "./np_elemwise_broadcast_op.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(_npi_copysign) +.set_attr("FCompute", BinaryBroadcastCompute); + +NNVM_REGISTER_OP(_npi_lcm) +.set_attr("FCompute", BinaryBroadcastIntCompute); + +NNVM_REGISTER_OP(_npi_bitwise_xor) +.set_attr("FCompute", BinaryBroadcastIntCompute); + +NNVM_REGISTER_OP(_backward_npi_copysign) +.set_attr("FCompute", BinaryBroadcastBackwardUseIn); + +NNVM_REGISTER_OP(_npi_arctan2) +.set_attr("FCompute", BinaryBroadcastCompute); + +NNVM_REGISTER_OP(_backward_npi_arctan2) +.set_attr("FCompute", BinaryBroadcastBackwardUseIn); +NNVM_REGISTER_OP(_npi_hypot) +.set_attr("FCompute", BinaryBroadcastCompute); + +NNVM_REGISTER_OP(_backward_npi_hypot) +.set_attr("FCompute", BinaryBroadcastBackwardUseIn); +NNVM_REGISTER_OP(_npi_copysign_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute); + +NNVM_REGISTER_OP(_npi_rcopysign_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute); + +NNVM_REGISTER_OP(_backward_npi_copysign_scalar) +.set_attr("FCompute", + BinaryScalarOp::Backward); + +NNVM_REGISTER_OP(_backward_npi_rcopysign_scalar) +.set_attr("FCompute", + BinaryScalarOp::Backward); + +NNVM_REGISTER_OP(_npi_arctan2_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute); + +NNVM_REGISTER_OP(_backward_npi_arctan2_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute); + +NNVM_REGISTER_OP(_npi_rarctan2_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute); + +NNVM_REGISTER_OP(_backward_npi_rarctan2_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute); + +NNVM_REGISTER_OP(_npi_lcm_scalar) +.set_attr("FCompute", BinaryScalarOp::ComputeInt); + +NNVM_REGISTER_OP(_npi_bitwise_xor_scalar) +.set_attr("FCompute", BinaryScalarOp::ComputeInt); + +NNVM_REGISTER_OP(_npi_ldexp) +.set_attr("FCompute", BinaryBroadcastCompute); + +NNVM_REGISTER_OP(_npi_ldexp_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute); + +NNVM_REGISTER_OP(_npi_rldexp_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute); + +NNVM_REGISTER_OP(_backward_npi_ldexp) +.set_attr("FCompute", BinaryBroadcastBackwardUseIn); + +NNVM_REGISTER_OP(_backward_npi_ldexp_scalar) +.set_attr("FCompute", BinaryScalarOp::Backward); + +NNVM_REGISTER_OP(_backward_npi_rldexp_scalar) +.set_attr("FCompute", BinaryScalarOp::Backward); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc index 633f63026bc0..e2a4c8af3099 100644 --- a/src/operator/operator_tune.cc +++ b/src/operator/operator_tune.cc @@ -396,10 +396,10 @@ IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::logical_or); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::logical_or); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::logical_xor); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::logical_xor); // NOLINT() -IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::bitwise_xor); // NOLINT() +IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::bitwise_xor); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::smooth_l1_loss); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::smooth_l1_gradient); // NOLINT() -IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::lcm); // NOLINT() +IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::lcm); // NOLINT() IMPLEMENT_BLANK_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mxnet_op::set_to_int<0>); // NOLINT() IMPLEMENT_BLANK_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mxnet_op::set_to_int<1>); // NOLINT() IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::PopulateFullIdxRspKernel); // NOLINT() diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h b/src/operator/tensor/elemwise_binary_broadcast_op.h index b48ed389ba98..958432b2c4b2 100644 --- a/src/operator/tensor/elemwise_binary_broadcast_op.h +++ b/src/operator/tensor/elemwise_binary_broadcast_op.h @@ -332,6 +332,37 @@ struct csr_dns_map_kernel { } // namespace mxnet_op +template +void BinaryBroadcastIntCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + if (outputs[0].shape_.Size() == 0U) return; + mxnet::TShape new_lshape, new_rshape, new_oshape; + int ndim = BinaryBroadcastShapeCompact(inputs[0].shape_, inputs[1].shape_, outputs[0].shape_, + &new_lshape, &new_rshape, &new_oshape); + if (!ndim) { + ElemwiseBinaryOp::ComputeInt(attrs, ctx, inputs, req, outputs); + } else { + if (req[0] == kNullOp) return; + mshadow::Stream *s = ctx.get_stream(); + if (outputs[0].type_flag_ == mshadow::kBool) { + LOG(FATAL) << "Operator " << attrs.op->name << " does not support boolean type"; + } + MXNET_INT_TYPE_SWITCH(outputs[0].type_flag_, DType, { + BROADCAST_NDIM_SWITCH(ndim, NDim, { + mshadow::Shape oshape = new_oshape.get(); + mshadow::Shape lstride = mxnet_op::calc_stride(new_lshape.get()); + mshadow::Shape rstride = mxnet_op::calc_stride(new_rshape.get()); + mxnet_op::Kernel, xpu>:: + template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape, + inputs[0].dptr(), inputs[1].dptr(), outputs[0].dptr()); + }); + }); + } +} + template void BinaryBroadcastCompute(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -345,22 +376,21 @@ void BinaryBroadcastCompute(const nnvm::NodeAttrs& attrs, if (!ndim) { ElemwiseBinaryOp::Compute(attrs, ctx, inputs, req, outputs); } else { - if (req[0] != kNullOp) { - mshadow::Stream *s = ctx.get_stream(); - if (outputs[0].type_flag_ == mshadow::kBool) { - LOG(FATAL) << "Operator " << attrs.op->name << " does not support boolean type"; - } - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { - BROADCAST_NDIM_SWITCH(ndim, NDim, { - mshadow::Shape oshape = new_oshape.get(); - mshadow::Shape lstride = mxnet_op::calc_stride(new_lshape.get()); - mshadow::Shape rstride = mxnet_op::calc_stride(new_rshape.get()); - mxnet_op::Kernel, xpu>:: - template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape, - inputs[0].dptr(), inputs[1].dptr(), outputs[0].dptr()); - }); - }); + if (req[0] == kNullOp) return; + mshadow::Stream *s = ctx.get_stream(); + if (outputs[0].type_flag_ == mshadow::kBool) { + LOG(FATAL) << "Operator " << attrs.op->name << " does not support boolean type"; } + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { + BROADCAST_NDIM_SWITCH(ndim, NDim, { + mshadow::Shape oshape = new_oshape.get(); + mshadow::Shape lstride = mxnet_op::calc_stride(new_lshape.get()); + mshadow::Shape rstride = mxnet_op::calc_stride(new_rshape.get()); + mxnet_op::Kernel, xpu>:: + template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape, + inputs[0].dptr(), inputs[1].dptr(), outputs[0].dptr()); + }); + }); } } @@ -377,19 +407,18 @@ void BinaryBroadcastComputeWithBool(const nnvm::NodeAttrs& attrs, if (!ndim) { ElemwiseBinaryOp::ComputeWithBool(attrs, ctx, inputs, req, outputs); } else { - if (req[0] != kNullOp) { - mshadow::Stream *s = ctx.get_stream(); - MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, DType, { - BROADCAST_NDIM_SWITCH(ndim, NDim, { - mshadow::Shape oshape = new_oshape.get(); - mshadow::Shape lstride = mxnet_op::calc_stride(new_lshape.get()); - mshadow::Shape rstride = mxnet_op::calc_stride(new_rshape.get()); - mxnet_op::Kernel, xpu>:: - template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape, - inputs[0].dptr(), inputs[1].dptr(), outputs[0].dptr()); - }); + if (req[0] == kNullOp) return; + mshadow::Stream *s = ctx.get_stream(); + MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, DType, { + BROADCAST_NDIM_SWITCH(ndim, NDim, { + mshadow::Shape oshape = new_oshape.get(); + mshadow::Shape lstride = mxnet_op::calc_stride(new_lshape.get()); + mshadow::Shape rstride = mxnet_op::calc_stride(new_rshape.get()); + mxnet_op::Kernel, xpu>:: + template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape, + inputs[0].dptr(), inputs[1].dptr(), outputs[0].dptr()); }); - } + }); } } @@ -406,20 +435,19 @@ void BinaryBroadcastComputeLogic(const nnvm::NodeAttrs& attrs, if (!ndim) { ElemwiseBinaryOp::ComputeLogic(attrs, ctx, inputs, req, outputs); } else { - if (req[0] != kNullOp) { - mshadow::Stream *s = ctx.get_stream(); - MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, { - BROADCAST_NDIM_SWITCH(ndim, NDim, { - mshadow::Shape oshape = new_oshape.get(); - mshadow::Shape lstride = mxnet_op::calc_stride(new_lshape.get()); - mshadow::Shape rstride = mxnet_op::calc_stride(new_rshape.get()); - mxnet_op::Kernel, xpu>:: - template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape, - inputs[0].dptr(), inputs[1].dptr(), - outputs[0].dptr()); - }); - }); - } + if (req[0] == kNullOp) return; + mshadow::Stream *s = ctx.get_stream(); + MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, { + BROADCAST_NDIM_SWITCH(ndim, NDim, { + mshadow::Shape oshape = new_oshape.get(); + mshadow::Shape lstride = mxnet_op::calc_stride(new_lshape.get()); + mshadow::Shape rstride = mxnet_op::calc_stride(new_rshape.get()); + mxnet_op::Kernel, xpu>:: + template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape, + inputs[0].dptr(), inputs[1].dptr(), + outputs[0].dptr()); + }); + }); } } diff --git a/src/operator/tensor/elemwise_binary_op.h b/src/operator/tensor/elemwise_binary_op.h index c046a28f16b2..bc5140a5d75f 100644 --- a/src/operator/tensor/elemwise_binary_op.h +++ b/src/operator/tensor/elemwise_binary_op.h @@ -474,6 +474,30 @@ class ElemwiseBinaryOp : public OpBase { std::vector *in_attrs, std::vector *out_attrs); + template + static void ComputeInt(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mxnet_op; + if (req[0] == kNullOp) return; + Stream *s = ctx.get_stream(); + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { + MXNET_INT_TYPE_SWITCH(outputs[0].type_flag_, DType, { + const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), inputs[1].Size()) + + DataType::kLanes - 1) / DataType::kLanes; + if (size != 0) { + Kernel, xpu>::Launch(s, size, + outputs[0].dptr(), + inputs[0].dptr(), inputs[1].dptr()); + } + }); + }); + } + template static void Compute(const nnvm::NodeAttrs &attrs, const OpContext &ctx, @@ -481,25 +505,24 @@ class ElemwiseBinaryOp : public OpBase { const std::vector &req, const std::vector &outputs) { using namespace mxnet_op; - if (req[0] != kNullOp) { - Stream *s = ctx.get_stream(); - CHECK_EQ(inputs.size(), 2U); - CHECK_EQ(outputs.size(), 1U); - if (outputs[0].type_flag_ == mshadow::kBool) { - LOG(FATAL) << "Operator " << attrs.op->name << " does not support boolean type"; - } - MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { - const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), inputs[1].Size()) - + DataType::kLanes - 1) / DataType::kLanes; - if (size != 0) { - Kernel, xpu>::Launch(s, size, - outputs[0].dptr(), - inputs[0].dptr(), inputs[1].dptr()); - } - }); - }); + if (req[0] == kNullOp) return; + Stream *s = ctx.get_stream(); + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + if (outputs[0].type_flag_ == mshadow::kBool) { + LOG(FATAL) << "Operator " << attrs.op->name << " does not support boolean type"; } + MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { + const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), inputs[1].Size()) + + DataType::kLanes - 1) / DataType::kLanes; + if (size != 0) { + Kernel, xpu>::Launch(s, size, + outputs[0].dptr(), + inputs[0].dptr(), inputs[1].dptr()); + } + }); + }); } template @@ -509,22 +532,21 @@ class ElemwiseBinaryOp : public OpBase { const std::vector &req, const std::vector &outputs) { using namespace mxnet_op; - if (req[0] != kNullOp) { - Stream *s = ctx.get_stream(); - CHECK_EQ(inputs.size(), 2U); - CHECK_EQ(outputs.size(), 1U); - MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { - MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, DType, { - const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), inputs[1].Size()) - + DataType::kLanes - 1) / DataType::kLanes; - if (size != 0) { - Kernel, xpu>::Launch(s, size, - outputs[0].dptr(), - inputs[0].dptr(), inputs[1].dptr()); - } - }); + if (req[0] == kNullOp) return; + Stream *s = ctx.get_stream(); + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, DType, { + const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), inputs[1].Size()) + + DataType::kLanes - 1) / DataType::kLanes; + if (size != 0) { + Kernel, xpu>::Launch(s, size, + outputs[0].dptr(), + inputs[0].dptr(), inputs[1].dptr()); + } }); - } + }); } template @@ -534,23 +556,22 @@ class ElemwiseBinaryOp : public OpBase { const std::vector &req, const std::vector &outputs) { using namespace mxnet_op; - if (req[0] != kNullOp) { - Stream *s = ctx.get_stream(); - CHECK_EQ(inputs.size(), 2U); - CHECK_EQ(outputs.size(), 1U); - MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { - MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, { - const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), inputs[1].Size()) - + DataType::kLanes - 1) / DataType::kLanes; - if (size != 0) { - Kernel, xpu>::Launch(s, size, - outputs[0].dptr(), - inputs[0].dptr(), - inputs[1].dptr()); - } - }); + if (req[0] == kNullOp) return; + Stream *s = ctx.get_stream(); + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, { + const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), inputs[1].Size()) + + DataType::kLanes - 1) / DataType::kLanes; + if (size != 0) { + Kernel, xpu>::Launch(s, size, + outputs[0].dptr(), + inputs[0].dptr(), + inputs[1].dptr()); + } }); - } + }); } template @@ -560,22 +581,21 @@ class ElemwiseBinaryOp : public OpBase { const std::vector &req, const std::vector &outputs) { using namespace mxnet_op; - if (req[0] != kNullOp) { - Stream *s = ctx.get_stream(); - CHECK_EQ(inputs.size(), 2U); - CHECK_EQ(outputs.size(), 1U); - MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { - MSHADOW_TYPE_SWITCH_WITH_HALF2(outputs[0].type_flag_, DType, { - const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), inputs[1].Size()) - + DataType::kLanes - 1) / DataType::kLanes; - if (size != 0) { - Kernel, xpu>::Launch(s, size, - outputs[0].dptr(), - inputs[0].dptr(), inputs[1].dptr()); - } - }); + if (req[0] == kNullOp) return; + Stream *s = ctx.get_stream(); + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { + MSHADOW_TYPE_SWITCH_WITH_HALF2(outputs[0].type_flag_, DType, { + const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), inputs[1].Size()) + + DataType::kLanes - 1) / DataType::kLanes; + if (size != 0) { + Kernel, xpu>::Launch(s, size, + outputs[0].dptr(), + inputs[0].dptr(), inputs[1].dptr()); + } }); - } + }); } template diff --git a/src/operator/tensor/elemwise_binary_scalar_op.h b/src/operator/tensor/elemwise_binary_scalar_op.h index 834bbdbfc3d1..3e8702813a7c 100644 --- a/src/operator/tensor/elemwise_binary_scalar_op.h +++ b/src/operator/tensor/elemwise_binary_scalar_op.h @@ -244,6 +244,26 @@ class BinaryScalarOp : public UnaryOp { }); } + template + static void ComputeInt(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + DCHECK_EQ(inputs.size(), 1); + DCHECK_EQ(outputs.size(), 1); + using namespace mshadow; + using namespace mshadow::expr; + Stream *s = ctx.get_stream(); + const double alpha = nnvm::get(attrs.parsed); + MXNET_INT_TYPE_SWITCH(outputs[0].type_flag_, DType, { + MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { + mxnet_op::Kernel, xpu>::Launch( + s, inputs[0].Size(), outputs[0].dptr(), inputs[0].dptr(), DType(alpha)); + }); + }); + } + template static void ComputeLogic(const nnvm::NodeAttrs &attrs, const OpContext &ctx, diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py index 6d9c63f9f857..658c7a37e354 100644 --- a/tests/python/unittest/test_numpy_interoperability.py +++ b/tests/python/unittest/test_numpy_interoperability.py @@ -264,7 +264,7 @@ def _add_workload_linalg_cholesky(): a = _np.matmul(a.transpose(t).conj(), a) OpArgMngr.add_workload('linalg.cholesky', np.array(a, dtype=dtype)) - + # test_0_size for dtype in dtypes: a = np.zeros((0, 1, 1)) @@ -816,6 +816,18 @@ def _add_workload_lcm(): OpArgMngr.add_workload('lcm', np.array(195225786*2, dtype=np.int32), np.array(195225786*5, dtype=np.int32)) +def _add_workload_bitwise_xor(): + OpArgMngr.add_workload('bitwise_xor', np.array([False, False, True, True], dtype=np.bool), + np.array([False, True, False, True], dtype=np.bool)) + for dtype in [np.int8, np.int32, np.int64]: + zeros = np.array([0], dtype=dtype) + ones = np.array([-1], dtype=dtype) + OpArgMngr.add_workload('bitwise_xor', zeros, zeros) + OpArgMngr.add_workload('bitwise_xor', ones, zeros) + OpArgMngr.add_workload('bitwise_xor', zeros, ones) + OpArgMngr.add_workload('bitwise_xor', ones, ones) + + def _add_workload_ldexp(): OpArgMngr.add_workload('ldexp', np.array(2., np.float32), np.array(3, np.int8)) OpArgMngr.add_workload('ldexp', np.array(2., np.float64), np.array(3, np.int8)) @@ -1236,6 +1248,7 @@ def _prepare_workloads(): _add_workload_inner() _add_workload_hypot() _add_workload_lcm() + _add_workload_bitwise_xor() _add_workload_ldexp() _add_workload_subtract(array_pool) _add_workload_multiply(array_pool) diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 2a23c976a092..405d4a4cca6d 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -1650,6 +1650,7 @@ def hybrid_forward(self, F, a, b, *args, **kwargs): 'power': (1.0, 2.0, [lambda y, x1, x2: _np.power(x1, x2 - 1.0) * x2], [lambda y, x1, x2: _np.power(x1, x2) * _np.log(x1)]), 'lcm': (-100, 100, [None], None, [[_np.int32]]), + 'bitwise_xor': (-100, 100, [None], None, [[_np.int32]]), 'maximum': (-1, 1, [lambda y, x1, x2: _np.ones(y.shape) * (x1 >= x2)], [lambda y, x1, x2: _np.ones(y.shape) * (x1 < x2)]), 'minimum': (-1, 1, [lambda y, x1, x2: _np.ones(y.shape) * (x1 <= x2)], @@ -4389,7 +4390,7 @@ def hybrid_forward(self, F, a): mx_out.backward() if (np_out.size == 0): np_backward = _np.zeros(shape) - else: + else: np_backward = np_diff_backward(_np.ones(np_out.shape, dtype=itype), n=n, axis=axis) assert x.grad.shape == np_backward.shape assert_almost_equal(x.grad.asnumpy(), np_backward, rtol=rtol, atol=atol) @@ -4535,7 +4536,7 @@ def hybrid_forward(self, F, a): copy_list = [True, False] hybridize_list = [True, False] atol, rtol = 1e-5, 1e-3 - + src_dtype_comb = list(itertools.product(src_list,dtype_list)) # check the dtype = int case in both imperative and sympolic expression src_dtype_comb.append((1,'int32')) @@ -4582,10 +4583,10 @@ def hybrid_forward(self, F, a): assert_almost_equal(mx_out_gluon.asnumpy(), np_out, rtol, atol) mx_out_gluon.backward() assert_almost_equal(x2.grad.asnumpy(), expected_grad, rtol=1e-3, atol=1e-5) - + # Test imperative once again # if copy = False, the value of x1 and x2 has changed - if copy == True: + if copy == True: np_out = _np.nan_to_num(x1) mx_out = np.nan_to_num(x3) assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, use_broadcast=False)