Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Refactor NumPy-compatible elemwise broadcast operators #16827

Merged
merged 1 commit into from
Nov 16, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 39 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate', 'append',
'stack', 'vstack', 'column_stack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax',
'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip',
'around', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take',
'around', 'hypot', 'bitwise_xor', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take',
'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal',
'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory', 'may_share_memory', 'diff', 'resize',
'nan_to_num']
Expand Down Expand Up @@ -4291,6 +4291,44 @@ def hypot(x1, x2, out=None, **kwargs):
return _ufunc_helper(x1, x2, _npi.hypot, _np.hypot, _npi.hypot_scalar, None, out)


@set_module('mxnet.ndarray.numpy')
@wrap_np_binary_func
def bitwise_xor(x1, x2, out=None, **kwargs):
r"""
Compute the bit-wise XOR of two arrays element-wise.

Parameters
----------
x1, x2 : ndarray or scalar
Only integer and boolean types are handled. If x1.shape != x2.shape,
they must be broadcastable to a common shape (which becomes the shape of the output).
out : ndarray, optional
A location into which the result is stored. If provided, it must have a shape that the
inputs broadcast to. If not provided or None, a freshly-allocated array is returned.

Returns
-------
out : ndarray
Result.

Examples
--------
>>> np.bitwise_xor(13, 17)
28

>>> np.bitwise_xor(31, 5)
26
>>> np.bitwise_xor(np.array([31,3], dtype='int32'), 5)
array([26, 6])

>>> np.bitwise_xor(np.array([31,3], dtype='int32'), np.array([5,6], dtype='int32'))
array([26, 5])
>>> np.bitwise_xor(np.array([True, True], dtype='bool'), np.array([False, True], dtype='bool'))
array([ True, False])
"""
return _ufunc_helper(x1, x2, _npi.bitwise_xor, _np.bitwise_xor, _npi.bitwise_xor_scalar, None, out)


@set_module('mxnet.ndarray.numpy')
@wrap_np_binary_func
def ldexp(x1, x2, out=None, **kwargs):
Expand Down
42 changes: 40 additions & 2 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@
'tensordot', 'histogram', 'eye', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange',
'split', 'vsplit', 'concatenate', 'stack', 'vstack', 'column_stack', 'dstack', 'mean', 'maximum', 'minimum',
'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming',
'blackman', 'flip', 'around', 'arctan2', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril',
'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less',
'blackman', 'flip', 'around', 'arctan2', 'hypot', 'bitwise_xor', 'rad2deg', 'deg2rad', 'unique', 'lcm',
'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less',
'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory',
'may_share_memory', 'diff', 'resize', 'nan_to_num']

Expand Down Expand Up @@ -6198,6 +6198,44 @@ def hypot(x1, x2, out=None, **kwargs):
return _mx_nd_np.hypot(x1, x2, out=out)


@set_module('mxnet.numpy')
@wrap_np_binary_func
def bitwise_xor(x1, x2, out=None, **kwargs):
r"""
Compute the bit-wise XOR of two arrays element-wise.

Parameters
----------
x1, x2 : ndarray or scalar
Only integer and boolean types are handled. If x1.shape != x2.shape,
they must be broadcastable to a common shape (which becomes the shape of the output).
out : ndarray, optional
A location into which the result is stored. If provided, it must have a shape that the
inputs broadcast to. If not provided or None, a freshly-allocated array is returned.

Returns
-------
out : ndarray
Result.

Examples
--------
>>> np.bitwise_xor(13, 17)
28

>>> np.bitwise_xor(31, 5)
26
>>> np.bitwise_xor(np.array([31,3], dtype=np.int32), 5)
array([26, 6])

>>> np.bitwise_xor(np.array([31,3], dtype='int32'), np.array([5,6], dtype='int32'))
array([26, 5])
>>> np.bitwise_xor(np.array([True, True], dtype='bool'), np.array([False, True], dtype='bool'))
array([ True, False])
"""
return _mx_nd_np.bitwise_xor(x1, x2, out=out)


@set_module('mxnet.numpy')
@wrap_np_binary_func
def ldexp(x1, x2, out=None, **kwargs):
Expand Down
1 change: 1 addition & 0 deletions python/mxnet/numpy_dispatch_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ def _register_array_function():
'ceil',
'trunc',
'floor',
'bitwise_xor',
'logical_not',
'equal',
'not_equal',
Expand Down
35 changes: 29 additions & 6 deletions python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate', 'append',
'stack', 'vstack', 'column_stack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax',
'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip',
'around', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take',
'around', 'hypot', 'bitwise_xor', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take',
'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal',
'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'shares_memory', 'may_share_memory', 'diff',
'resize', 'nan_to_num']
Expand Down Expand Up @@ -4058,17 +4058,16 @@ def hypot(x1, x2, out=None, **kwargs):

Parameters
----------
x1, x2 : array_like
x1, x2 : _Symbol or scalar
Leg of the triangle(s).
out : ndarray, None, or tuple of ndarray and None, optional
out : _Symbol or None, optional
A location into which the result is stored. If provided, it must have
a shape that the inputs broadcast to. If not provided or `None`,
a freshly-allocated array is returned. A tuple (possible only as a
keyword argument) must have length equal to the number of outputs.
a freshly-allocated array is returned.

Returns
-------
z : ndarray
z : _Symbol or scalar
The hypotenuse of the triangle(s).
This is a scalar if both `x1` and `x2` are scalars.

Expand All @@ -4080,6 +4079,30 @@ def hypot(x1, x2, out=None, **kwargs):
return _ufunc_helper(x1, x2, _npi.hypot, _np.hypot, _npi.hypot_scalar, None, out)


@set_module('mxnet.symbol.numpy')
@wrap_np_binary_func
def bitwise_xor(x1, x2, out=None, **kwargs):
r"""
Compute the bit-wise XOR of two arrays element-wise.

Parameters
----------
x1, x2 : _Symbol or scalar
Only integer and boolean types are handled. If x1.shape != x2.shape,
they must be broadcastable to a common shape (which becomes the shape of the output).
out : _Symbol or None, optional
A location into which the result is stored. If provided, it must have
a shape that the inputs broadcast to. If not provided or `None`,
a freshly-allocated array is returned.

Returns
-------
out : _Symbol or scalar
Result.
"""
return _ufunc_helper(x1, x2, _npi.bitwise_xor, _np.bitwise_xor, _npi.bitwise_xor_scalar, None, out)


@set_module('mxnet.symbol.numpy')
def unique(ar, return_index=False, return_inverse=False, return_counts=False, axis=None):
"""
Expand Down
3 changes: 2 additions & 1 deletion src/operator/elemwise_op_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,8 @@ inline bool ElemwiseIntType(const nnvm::NodeAttrs& attrs,
CHECK(in_attrs->at(0) == mshadow::kInt64 ||
in_attrs->at(0) == mshadow::kInt32 ||
in_attrs->at(0) == mshadow::kInt8 ||
in_attrs->at(0) == mshadow::kUint8) << "Only supports integer types.";
in_attrs->at(0) == mshadow::kUint8 ||
in_attrs->at(0) == mshadow::kBool) << "Only supports integer types.";
if (n_in != -1) {
CHECK_EQ(in_attrs->size(), static_cast<size_t>(n_in)) << " in operator " << attrs.name;
}
Expand Down
Loading