diff --git a/python/mxnet/_numpy_op_doc.py b/python/mxnet/_numpy_op_doc.py
index 6e2f5fa15919..9594297611de 100644
--- a/python/mxnet/_numpy_op_doc.py
+++ b/python/mxnet/_numpy_op_doc.py
@@ -653,3 +653,50 @@ def _np_trace(a, offset=0, axis1=0, axis2=1, out=None):
     (2, 3)
     """
     pass
+
+
+def _np_moveaxis(a, source, destination):
+    """Move axes of an array to new positions.
+
+    Other axes remain in their original order.
+
+    Parameters
+    ----------
+    a : ndarray
+        The array whose axes should be reordered.
+    source : int or sequence of int
+        Original positions of the axes to move. These must be unique.
+    destination : int or sequence of int
+        Destination positions for each of the original axes. These must also be
+        unique.
+
+    Returns
+    -------
+    result : ndarray
+        Array with moved axes. This array is a view of the input array.
+
+    See Also
+    --------
+    transpose : Permute the dimensions of an array.
+    swapaxes : Interchange two axes of an array.
+
+    Examples
+    --------
+    >>> x = np.zeros((3, 4, 5))
+    >>> np.moveaxis(x, 0, -1).shape
+    (4, 5, 3)
+    >>> np.moveaxis(x, -1, 0).shape
+    (5, 3, 4)
+
+    These all achieve the same result:
+
+    >>> np.transpose(x).shape
+    (5, 4, 3)
+    >>> np.swapaxes(x, 0, -1).shape
+    (5, 4, 3)
+    >>> np.moveaxis(x, [0, 1], [-1, -2]).shape
+    (5, 4, 3)
+    >>> np.moveaxis(x, [0, 1, 2], [-1, -2, -3]).shape
+    (5, 4, 3)
+    """
+    pass
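The permutation that `moveaxis` hands to the transpose kernel (constructed by `NumpyMoveaxisShapeImpl` further down in this diff) can be checked against plain NumPy. A minimal sketch; `moveaxis_perm` is a hypothetical helper written for illustration, not part of this patch:

import numpy as onp

def moveaxis_perm(ndim, source, destination):
    # Normalize negative axes, then remove the moved axes from the order.
    src = [ax % ndim for ax in source]
    dst = [ax % ndim for ax in destination]
    order = [ax for ax in range(ndim) if ax not in src]
    # Re-insert each moved axis at its destination, lowest destination first.
    for d, s in sorted(zip(dst, src)):
        order.insert(d, s)
    return order

x = onp.zeros((3, 4, 5))
assert moveaxis_perm(x.ndim, [0], [-1]) == [1, 2, 0]
assert onp.transpose(x, moveaxis_perm(x.ndim, [0], [-1])).shape == onp.moveaxis(x, 0, -1).shape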
diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py
index e8332f1a83ef..e0214327c0fc 100644
--- a/python/mxnet/ndarray/numpy/_op.py
+++ b/python/mxnet/ndarray/numpy/_op.py
@@ -32,10 +32,10 @@
            'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees',
            'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil',
            'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
-           'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean',
+           'linspace', 'expand_dims', 'tile', 'arange', 'split', 'hsplit', 'concatenate', 'stack', 'vstack', 'mean',
            'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
            'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad',
-           'unique', 'lcm', 'tril', 'identity', 'take']
+           'unique', 'lcm', 'tril', 'identity', 'take', 'rot90']


 @set_module('mxnet.ndarray.numpy')
@@ -2318,6 +2318,121 @@ def split(ary, indices_or_sections, axis=0):
 # pylint: enable=redefined-outer-name


+# pylint: disable=redefined-outer-name
+@set_module('mxnet.ndarray.numpy')
+def hsplit(ary, indices_or_sections):
+    """Split an array into multiple sub-arrays horizontally (column-wise).
+
+    This is equivalent to ``split`` with ``axis=0`` if ``ary`` has one
+    dimension, and ``axis=1`` otherwise.
+
+    Parameters
+    ----------
+    ary : ndarray
+        Array to be divided into sub-arrays.
+    indices_or_sections : int, list of ints or tuple of ints
+        If `indices_or_sections` is an integer, N, the array will be divided
+        into N equal arrays along `axis`. If such a split is not possible,
+        an error is raised.
+
+        If `indices_or_sections` is a list of sorted integers, the entries
+        indicate where along `axis` the array is split.
+
+        If an index exceeds the dimension of the array along `axis`, an
+        error is raised, so each index must be less than or equal to the
+        dimension of the array along `axis`.
+
+    Returns
+    -------
+    sub-arrays : list of ndarrays
+        A list of sub-arrays.
+
+    Notes
+    -----
+    - If `indices_or_sections` is given as an integer but a split does not
+      result in equal division, a ValueError is raised.
+
+    - If `indices_or_sections` is the integer 1, an error is raised, because
+      a single output from split is not supported yet.
+
+    See Also
+    --------
+    split : Split an array into multiple sub-arrays of equal size.
+
+    Examples
+    --------
+    >>> x = np.arange(16.0).reshape(4, 4)
+    >>> x
+    array([[ 0.,  1.,  2.,  3.],
+           [ 4.,  5.,  6.,  7.],
+           [ 8.,  9., 10., 11.],
+           [12., 13., 14., 15.]])
+    >>> np.hsplit(x, 2)
+    [array([[ 0.,  1.],
+           [ 4.,  5.],
+           [ 8.,  9.],
+           [12., 13.]]),
+    array([[ 2.,  3.],
+           [ 6.,  7.],
+           [10., 11.],
+           [14., 15.]])]
+    >>> np.hsplit(x, [3, 6])
+    [array([[ 0.,  1.,  2.],
+           [ 4.,  5.,  6.],
+           [ 8.,  9., 10.],
+           [12., 13., 14.]]),
+    array([[ 3.],
+           [ 7.],
+           [11.],
+           [15.]]),
+    array([], shape=(4, 0), dtype=float32)]
+
+    With a higher dimensional array the split is still along the second axis.
+
+    >>> x = np.arange(8.0).reshape(2, 2, 2)
+    >>> x
+    array([[[ 0.,  1.],
+            [ 2.,  3.]],
+           [[ 4.,  5.],
+            [ 6.,  7.]]])
+    >>> np.hsplit(x, 2)
+    [array([[[ 0.,  1.]],
+           [[ 4.,  5.]]]),
+    array([[[ 2.,  3.]],
+           [[ 6.,  7.]]])]
+
+    If ``ary`` has one dimension, the split is along ``axis = 0``.
+
+    >>> x = np.arange(4)
+    >>> x
+    array([0., 1., 2., 3.])
+    >>> np.hsplit(x, 2)
+    [array([0., 1.]), array([2., 3.])]
+
+    Repeated indices produce an empty sub-array:
+
+    >>> np.hsplit(x, [2, 2])
+    [array([0., 1.]), array([], dtype=float32), array([2., 3.])]
+    """
+    indices = []
+    axis = 1
+    if len(ary.shape) == 1:
+        axis = 0
+    axis_size = ary.shape[axis]
+    if isinstance(indices_or_sections, int):
+        sections = indices_or_sections
+        if axis_size % sections:
+            raise ValueError('array hsplit does not result in an equal division')
+        section_size = int(axis_size / sections)
+        indices = [i * section_size for i in range(sections)]
+    elif isinstance(indices_or_sections, (list, set, tuple)):
+        indices = [0] + list(indices_or_sections)
+    else:
+        raise ValueError('indices_or_sections must be either an int or a tuple of ints')
+    ret = _npi.hsplit(ary, indices, axis, False)
+    if not isinstance(ret, list):
+        raise NotImplementedError('single output from hsplit is not supported yet...')
+    return ret
+# pylint: enable=redefined-outer-name
+
+
 @set_module('mxnet.ndarray.numpy')
 def concatenate(seq, axis=0, out=None):
     """Join a sequence of arrays along an existing axis.
@@ -3464,3 +3579,49 @@ def hypot(x1, x2, out=None):
            [ 5.,  5.,  5.]])
     """
     return _ufunc_helper(x1, x2, _npi.hypot, _np.hypot, _npi.hypot_scalar, None, out)
+
+
+@set_module('mxnet.ndarray.numpy')
+def rot90(m, k=1, axes=(0, 1)):
+    """
+    Rotate an array by 90 degrees in the plane specified by axes.
+    Rotation direction is from the first towards the second axis.
+
+    Parameters
+    ----------
+    m : ndarray
+        Array of two or more dimensions.
+    k : integer
+        Number of times the array is rotated by 90 degrees.
+    axes : (2,) array_like
+        The array is rotated in the plane defined by the axes.
+        Axes must be different.
+
+    Returns
+    -------
+    y : ndarray
+        A rotated view of `m`.
+
+    Notes
+    -----
+    rot90(m, k=1, axes=(1,0)) is the reverse of rot90(m, k=1, axes=(0,1)).
+    rot90(m, k=1, axes=(1,0)) is equivalent to rot90(m, k=-1, axes=(0,1)).
+
+    Examples
+    --------
+    >>> m = np.array([[1, 2], [3, 4]], 'int')
+    >>> m
+    array([[1, 2],
+           [3, 4]], dtype=int64)
+    >>> np.rot90(m)
+    array([[2, 4],
+           [1, 3]], dtype=int64)
+    >>> np.rot90(m, 2)
+    array([[4, 3],
+           [2, 1]], dtype=int64)
+    >>> m = np.arange(8).reshape((2, 2, 2))
+    >>> np.rot90(m, 1, (1, 2))
+    array([[[1., 3.],
+            [0., 2.]],
+
+           [[5., 7.],
+            [4., 6.]]])
+    """
+    return _npi.rot90(m, k=k, axes=axes)
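The `hsplit` front end above only turns `indices_or_sections` into a list of boundary indices before calling `_npi.hsplit`. A pure-Python sketch of just that step; `hsplit_boundaries` is a hypothetical illustration helper following the same conventions as the code above:

def hsplit_boundaries(axis_size, indices_or_sections):
    # Integer: k equal sections, boundaries at multiples of the section size.
    if isinstance(indices_or_sections, int):
        k = indices_or_sections
        if axis_size % k:
            raise ValueError('array hsplit does not result in an equal division')
        step = axis_size // k
        return [i * step for i in range(k)]
    # Sequence: explicit split points, with an implicit leading 0.
    return [0] + list(indices_or_sections)

print(hsplit_boundaries(16, 4))       # [0, 4, 8, 12]
print(hsplit_boundaries(16, [3, 6]))  # [0, 3, 6]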
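`rot90` itself is computed in C++ (see `NumpyRot90Compute` later in this diff) by reducing `k % 4` to flips and at most one transpose. A hedged pure-NumPy reference of that decomposition, assuming non-negative `axes`:

import numpy as onp

def rot90_reference(m, k=1, axes=(0, 1)):
    k %= 4
    if k == 0:
        return m.copy()                                  # identity
    if k == 2:
        return onp.flip(onp.flip(m, axes[0]), axes[1])   # two flips
    perm = list(range(m.ndim))
    perm[axes[0]], perm[axes[1]] = perm[axes[1]], perm[axes[0]]
    if k == 1:
        return onp.transpose(onp.flip(m, axes[1]), perm)  # flip, then transpose
    return onp.flip(onp.transpose(m, perm), axes[1])      # k == 3: transpose, then flip

m = onp.arange(8).reshape(2, 2, 2)
assert (rot90_reference(m, 1, (1, 2)) == onp.rot90(m, 1, (1, 2))).all()
assert (rot90_reference(m, 5, (1, 2)) == onp.rot90(m, 5, (1, 2))).all()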
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index 7ba0f0d7d813..6de1d438ac5e 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -51,10 +51,10 @@
            'sqrt', 'cbrt', 'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log',
            'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative',
            'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh',
-           'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate',
+           'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'hsplit', 'concatenate',
            'stack', 'vstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices',
            'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'arctan2', 'hypot',
-           'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take']
+           'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'rot90']

 # Return code for dispatching indexing function call
 _NDARRAY_UNSUPPORTED_INDEXING = -1
@@ -3881,6 +3881,101 @@ def split(ary, indices_or_sections, axis=0):
     return _mx_nd_np.split(ary, indices_or_sections, axis=axis)


+@set_module('mxnet.numpy')
+def hsplit(ary, indices_or_sections):
+    """Split an array into multiple sub-arrays horizontally (column-wise).
+
+    This is equivalent to ``split`` with ``axis=0`` if ``ary`` has one
+    dimension, and ``axis=1`` otherwise.
+
+    Parameters
+    ----------
+    ary : ndarray
+        Array to be divided into sub-arrays.
+    indices_or_sections : int, list of ints or tuple of ints
+        If `indices_or_sections` is an integer, N, the array will be divided
+        into N equal arrays along `axis`. If such a split is not possible,
+        an error is raised.
+
+        If `indices_or_sections` is a list of sorted integers, the entries
+        indicate where along `axis` the array is split.
+
+        If an index exceeds the dimension of the array along `axis`, an
+        error is raised, so each index must be less than or equal to the
+        dimension of the array along `axis`.
+
+    Returns
+    -------
+    sub-arrays : list of ndarrays
+        A list of sub-arrays.
+
+    Notes
+    -----
+    - If `indices_or_sections` is given as an integer but a split does not
+      result in equal division, a ValueError is raised.
+
+    - If `indices_or_sections` is the integer 1, an error is raised, because
+      a single output from split is not supported yet.
+
+    See Also
+    --------
+    split : Split an array into multiple sub-arrays of equal size.
+
+    Examples
+    --------
+    >>> x = np.arange(16.0).reshape(4, 4)
+    >>> x
+    array([[ 0.,  1.,  2.,  3.],
+           [ 4.,  5.,  6.,  7.],
+           [ 8.,  9., 10., 11.],
+           [12., 13., 14., 15.]])
+    >>> np.hsplit(x, 2)
+    [array([[ 0.,  1.],
+           [ 4.,  5.],
+           [ 8.,  9.],
+           [12., 13.]]),
+    array([[ 2.,  3.],
+           [ 6.,  7.],
+           [10., 11.],
+           [14., 15.]])]
+    >>> np.hsplit(x, [3, 6])
+    [array([[ 0.,  1.,  2.],
+           [ 4.,  5.,  6.],
+           [ 8.,  9., 10.],
+           [12., 13., 14.]]),
+    array([[ 3.],
+           [ 7.],
+           [11.],
+           [15.]]),
+    array([], shape=(4, 0), dtype=float32)]
+
+    With a higher dimensional array the split is still along the second axis.
+
+    >>> x = np.arange(8.0).reshape(2, 2, 2)
+    >>> x
+    array([[[ 0.,  1.],
+            [ 2.,  3.]],
+           [[ 4.,  5.],
+            [ 6.,  7.]]])
+    >>> np.hsplit(x, 2)
+    [array([[[ 0.,  1.]],
+           [[ 4.,  5.]]]),
+    array([[[ 2.,  3.]],
+           [[ 6.,  7.]]])]
+
+    If ``ary`` has one dimension, the split is along ``axis = 0``.
+
+    >>> x = np.arange(4)
+    >>> x
+    array([0., 1., 2., 3.])
+    >>> np.hsplit(x, 2)
+    [array([0., 1.]), array([2., 3.])]
+
+    Repeated indices produce an empty sub-array:
+
+    >>> np.hsplit(x, [2, 2])
+    [array([0., 1.]), array([], dtype=float32), array([2., 3.])]
+    """
+    return _mx_nd_np.hsplit(ary, indices_or_sections)
+
+
 @set_module('mxnet.numpy')
 def concatenate(seq, axis=0, out=None):
     """Join a sequence of arrays along an existing axis.
@@ -4980,3 +5075,49 @@ def hypot(x1, x2, out=None):
            [ 5.,  5.,  5.]])
     """
     return _mx_nd_np.hypot(x1, x2, out=out)
+
+
+@set_module('mxnet.numpy')
+def rot90(m, k=1, axes=(0, 1)):
+    """
+    Rotate an array by 90 degrees in the plane specified by axes.
+    Rotation direction is from the first towards the second axis.
+
+    Parameters
+    ----------
+    m : ndarray
+        Array of two or more dimensions.
+    k : integer
+        Number of times the array is rotated by 90 degrees.
+    axes : (2,) array_like
+        The array is rotated in the plane defined by the axes.
+        Axes must be different.
+
+    Returns
+    -------
+    y : ndarray
+        A rotated view of `m`.
+
+    Notes
+    -----
+    rot90(m, k=1, axes=(1,0)) is the reverse of rot90(m, k=1, axes=(0,1)).
+    rot90(m, k=1, axes=(1,0)) is equivalent to rot90(m, k=-1, axes=(0,1)).
+
+    Examples
+    --------
+    >>> m = np.array([[1, 2], [3, 4]], 'int')
+    >>> m
+    array([[1, 2],
+           [3, 4]], dtype=int64)
+    >>> np.rot90(m)
+    array([[2, 4],
+           [1, 3]], dtype=int64)
+    >>> np.rot90(m, 2)
+    array([[4, 3],
+           [2, 1]], dtype=int64)
+    >>> m = np.arange(8).reshape((2, 2, 2))
+    >>> np.rot90(m, 1, (1, 2))
+    array([[[1., 3.],
+            [0., 2.]],
+
+           [[5., 7.],
+            [4., 6.]]])
+    """
+    return _mx_nd_np.rot90(m, k=k, axes=axes)
diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py
index 9c055c401b31..9b7264b37215 100644
--- a/python/mxnet/symbol/numpy/_symbol.py
+++ b/python/mxnet/symbol/numpy/_symbol.py
@@ -34,10 +34,10 @@
            'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p',
            'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
            'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
-           'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean',
+           'linspace', 'expand_dims', 'tile', 'arange', 'split', 'hsplit', 'concatenate', 'stack', 'vstack', 'mean',
            'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
            'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad',
-           'unique', 'lcm', 'tril', 'identity', 'take']
+           'unique', 'lcm', 'tril', 'identity', 'take', 'rot90']


 def _num_outputs(sym):
@@ -2565,6 +2565,112 @@ def split(ary, indices_or_sections, axis=0):
 # pylint: enable=redefined-outer-name


+# pylint: disable=redefined-outer-name
+@set_module('mxnet.symbol.numpy')
+def hsplit(ary, indices_or_sections):
+    """Split an array into multiple sub-arrays horizontally (column-wise).
+
+    This is equivalent to ``split`` with ``axis=0`` if ``ary`` has one
+    dimension, and ``axis=1`` otherwise.
+
+    Parameters
+    ----------
+    ary : _Symbol
+        Array to be divided into sub-arrays.
+    indices_or_sections : int, list of ints or tuple of ints
+        If `indices_or_sections` is an integer, N, the array will be divided
+        into N equal arrays along `axis`. If such a split is not possible,
+        an error is raised.
+
+        If `indices_or_sections` is a list of sorted integers, the entries
+        indicate where along `axis` the array is split.
+
+        If an index exceeds the dimension of the array along `axis`, an
+        error is raised, so each index must be less than or equal to the
+        dimension of the array along `axis`.
+
+    Returns
+    -------
+    sub-arrays : _Symbol
+        A list of sub-arrays.
+
+    Notes
+    -----
+    - If `indices_or_sections` is given as an integer but a split does not
+      result in equal division, a ValueError is raised.
+
+    - If `indices_or_sections` is the integer 1, an error is raised, because
+      a single output from split is not supported yet.
+
+    See Also
+    --------
+    split : Split an array into multiple sub-arrays of equal size.
+
+    Examples
+    --------
+    >>> x = np.arange(16.0).reshape(4, 4)
+    >>> x
+    array([[ 0.,  1.,  2.,  3.],
+           [ 4.,  5.,  6.,  7.],
+           [ 8.,  9., 10., 11.],
+           [12., 13., 14., 15.]])
+    >>> np.hsplit(x, 2)
+    [array([[ 0.,  1.],
+           [ 4.,  5.],
+           [ 8.,  9.],
+           [12., 13.]]),
+    array([[ 2.,  3.],
+           [ 6.,  7.],
+           [10., 11.],
+           [14., 15.]])]
+    >>> np.hsplit(x, [3, 6])
+    [array([[ 0.,  1.,  2.],
+           [ 4.,  5.,  6.],
+           [ 8.,  9., 10.],
+           [12., 13., 14.]]),
+    array([[ 3.],
+           [ 7.],
+           [11.],
+           [15.]]),
+    array([], shape=(4, 0), dtype=float32)]
+
+    With a higher dimensional array the split is still along the second axis.
+
+    >>> x = np.arange(8.0).reshape(2, 2, 2)
+    >>> x
+    array([[[ 0.,  1.],
+            [ 2.,  3.]],
+           [[ 4.,  5.],
+            [ 6.,  7.]]])
+    >>> np.hsplit(x, 2)
+    [array([[[ 0.,  1.]],
+           [[ 4.,  5.]]]),
+    array([[[ 2.,  3.]],
+           [[ 6.,  7.]]])]
+
+    If ``ary`` has one dimension, the split is along ``axis = 0``.
+
+    >>> x = np.arange(4)
+    >>> x
+    array([0., 1., 2., 3.])
+    >>> np.hsplit(x, 2)
+    [array([0., 1.]), array([2., 3.])]
+
+    Repeated indices produce an empty sub-array:
+
+    >>> np.hsplit(x, [2, 2])
+    [array([0., 1.]), array([], dtype=float32), array([2., 3.])]
+    """
+    indices = []
+    sections = 0
+    if isinstance(indices_or_sections, int):
+        sections = indices_or_sections
+    elif isinstance(indices_or_sections, (list, set, tuple)):
+        indices = [0] + list(indices_or_sections)
+    else:
+        raise ValueError('indices_or_sections must be either an int or a tuple of ints')
+    ret = _npi.hsplit(ary, indices, 1, False, sections)
+    return ret
+# pylint: enable=redefined-outer-name
+
+
 @set_module('mxnet.symbol.numpy')
 def concatenate(seq, axis=0, out=None):
     """Join a sequence of arrays along an existing axis.
@@ -3552,4 +3658,50 @@ def unique(ar, return_index=False, return_inverse=False, return_counts=False, axis=None):
     return _npi.unique(ar, return_index, return_inverse, return_counts, axis)


+@set_module('mxnet.symbol.numpy')
+def rot90(m, k=1, axes=(0, 1)):
+    """
+    Rotate an array by 90 degrees in the plane specified by axes.
+    Rotation direction is from the first towards the second axis.
+
+    Parameters
+    ----------
+    m : _Symbol
+        Array of two or more dimensions.
+    k : integer
+        Number of times the array is rotated by 90 degrees.
+    axes : (2,) array_like
+        The array is rotated in the plane defined by the axes.
+        Axes must be different.
+
+    Returns
+    -------
+    y : _Symbol
+        A rotated view of `m`.
+
+    Notes
+    -----
+    rot90(m, k=1, axes=(1,0)) is the reverse of rot90(m, k=1, axes=(0,1)).
+    rot90(m, k=1, axes=(1,0)) is equivalent to rot90(m, k=-1, axes=(0,1)).
+
+    Examples
+    --------
+    >>> m = np.array([[1, 2], [3, 4]], 'int')
+    >>> m
+    array([[1, 2],
+           [3, 4]], dtype=int64)
+    >>> np.rot90(m)
+    array([[2, 4],
+           [1, 3]], dtype=int64)
+    >>> np.rot90(m, 2)
+    array([[4, 3],
+           [2, 1]], dtype=int64)
+    >>> m = np.arange(8).reshape((2, 2, 2))
+    >>> np.rot90(m, 1, (1, 2))
+    array([[[1., 3.],
+            [0., 2.]],
+
+           [[5., 7.],
+            [4., 6.]]])
+    """
+    return _npi.rot90(m, k=k, axes=axes)
+
+
 _set_np_symbol_class(_Symbol)
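Since the symbol front end mirrors the ndarray one, the new operators also work under hybridization. A quick usage sketch, assuming an MXNet build that contains this patch with numpy semantics enabled via `npx.set_np()`:

from mxnet import np, npx
from mxnet.gluon import HybridBlock
npx.set_np()

class Rot90Block(HybridBlock):
    def hybrid_forward(self, F, x):
        # F is mxnet.ndarray or mxnet.symbol depending on hybridization.
        return F.np.rot90(x, k=1, axes=(0, 1))

net = Rot90Block()
net.hybridize()
print(net(np.arange(6).reshape(2, 3)).shape)  # (3, 2)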
diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h
index 92655c146193..a2bd02b224e0 100644
--- a/src/operator/mshadow_op.h
+++ b/src/operator/mshadow_op.h
@@ -353,6 +353,8 @@ MXNET_BINARY_MATH_OP(logical_or, a || b ? DType(1) : DType(0));

 MXNET_BINARY_MATH_OP(logical_xor, (a || b) && !(a && b) ? DType(1) : DType(0));

+MXNET_BINARY_MATH_OP(bitwise_xor, static_cast<int64_t>(a) ^ static_cast<int64_t>(b));
+
 MXNET_UNARY_MATH_OP(square_root, math::sqrt(a));

 MXNET_UNARY_MATH_OP(square_root_grad, 0.5f / math::id(a));
diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cu b/src/operator/numpy/np_elemwise_broadcast_op.cu
index 806debf431b5..ab9872eb66b8 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op.cu
+++ b/src/operator/numpy/np_elemwise_broadcast_op.cu
@@ -27,6 +27,7 @@
 namespace mxnet {
 namespace op {

+
 NNVM_REGISTER_OP(_npi_add)
 .set_attr<FCompute>("FCompute", BinaryBroadcastCompute<gpu, op::mshadow_op::plus>);

@@ -118,5 +119,8 @@ NNVM_REGISTER_OP(_backward_npi_rarctan2_scalar)
 NNVM_REGISTER_OP(_npi_lcm_scalar)
 .set_attr<FCompute>("FCompute", BinaryScalarOp::Compute<gpu, mshadow_op::lcm>);

+NNVM_REGISTER_OP(_np_bitwise_xor)
+.set_attr<FCompute>("FCompute", BinaryBroadcastCompute<gpu, mshadow_op::bitwise_xor>);
+
 }  // namespace op
 }  // namespace mxnet
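The new kernel casts both operands to `int64_t` before XOR-ing, so the result is well defined even when the tensors are stored in a floating-point dtype (fractional parts are truncated). A one-line Python rendering of the element-wise rule, for illustration only:

def bitwise_xor_elem(a, b):
    # Mirrors MXNET_BINARY_MATH_OP(bitwise_xor, ...): truncate to integer, then XOR.
    return int(a) ^ int(b)

print(bitwise_xor_elem(6.0, 3.0))  # 5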
diff --git a/src/operator/numpy/np_matrix_op-inl.h b/src/operator/numpy/np_matrix_op-inl.h
index 5e25192d9298..c6181df506d9 100644
--- a/src/operator/numpy/np_matrix_op-inl.h
+++ b/src/operator/numpy/np_matrix_op-inl.h
@@ -363,6 +363,373 @@ void NumpyFlipForward(const nnvm::NodeAttrs& attrs,
   NumpyFlipForwardImpl<xpu>(ctx, inputs, outputs, stride_, trailing_, flip_index);
 }

+struct NumpyMoveaxisParam : public dmlc::Parameter<NumpyMoveaxisParam> {
+  mxnet::TShape source;
+  mxnet::TShape destination;
+  DMLC_DECLARE_PARAMETER(NumpyMoveaxisParam) {
+    DMLC_DECLARE_FIELD(source)
+    .describe("Original positions of the axes to move. These must be unique.");
+    DMLC_DECLARE_FIELD(destination)
+    .describe("Destination positions for each of the original axes. "
+              "These must also be unique.");
+  }
+};
+
+inline mxnet::TShape NumpyMoveaxisShapeImpl(const nnvm::NodeAttrs& attrs,
+                                            const int& ndim) {
+  const NumpyMoveaxisParam& param = nnvm::get<NumpyMoveaxisParam>(attrs.parsed);
+  mxnet::TShape axes(ndim, -1);
+  std::vector<bool> state_axes(ndim, false);
+  mxnet::TShape real_src(param.source.ndim(), -1);
+  mxnet::TShape real_des(param.destination.ndim(), -1);
+  for (int i = 0; i < param.source.ndim(); ++i) {
+    if (param.source[i] >= 0) {
+      CHECK_LT(static_cast<int>(param.source[i]), ndim);
+      real_src[i] = param.source[i];
+    } else {
+      CHECK_LT(param.source[i] + ndim, ndim);
+      real_src[i] = param.source[i] + ndim;
+    }
+    if (param.destination[i] >= 0) {
+      CHECK_LT(static_cast<int>(param.destination[i]), ndim);
+      real_des[i] = param.destination[i];
+    } else {
+      CHECK_LT(param.destination[i] + ndim, ndim);
+      real_des[i] = param.destination[i] + ndim;
+    }
+  }
+  if (ndim > 1) {
+    for (int i = 0; i < param.source.ndim() - 1; ++i) {
+      for (int j = i + 1; j < param.source.ndim(); ++j) {
+        CHECK_NE(real_src[i], real_src[j])
+            << "repeated axis in `source` argument";
+        CHECK_NE(real_des[i], real_des[j])
+            << "repeated axis in `destination` argument";
+      }
+    }
+  }
+  for (int i = 0; i < param.source.ndim(); ++i) {
+    axes[real_des[i]] = real_src[i];
+    state_axes[real_src[i]] = true;
+  }
+  for (int i = 0; i < axes.ndim(); ++i) {
+    if (axes[i] < 0) {
+      for (int j = 0; j < axes.ndim(); ++j) {
+        if (state_axes[j] == false) {
+          axes[i] = j;
+          state_axes[j] = true;
+          break;
+        }
+      }
+    }
+  }
+  return axes;
+}
+
+template<typename xpu>
+void NumpyMoveaxisCompute(const nnvm::NodeAttrs& attrs,
+                          const OpContext& ctx,
+                          const std::vector<TBlob>& inputs,
+                          const std::vector<OpReqType>& req,
+                          const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  const NumpyMoveaxisParam& param = nnvm::get<NumpyMoveaxisParam>(attrs.parsed);
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req[0], kWriteTo) << "Moveaxis does not support inplace";
+  CHECK_EQ(param.source.ndim(), param.destination.ndim())
+      << "source and destination must have the same number of elements.";
+  mxnet::TShape axes;
+  axes = NumpyMoveaxisShapeImpl(attrs, inputs[0].ndim());
+  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, Dtype, {
+    TransposeImpl<xpu>(ctx.run_ctx, inputs[0], outputs[0], axes);
+  })
+}
+
+struct NumpyRot90Param : public dmlc::Parameter<NumpyRot90Param> {
+  int k;
+  dmlc::optional<mxnet::TShape> axes;
+  DMLC_DECLARE_PARAMETER(NumpyRot90Param) {
+    DMLC_DECLARE_FIELD(k)
+    .set_default(1)
+    .describe("Number of times the array is rotated by 90 degrees.");
+    DMLC_DECLARE_FIELD(axes)
+    .set_default(dmlc::optional<mxnet::TShape>())
+    .describe("The array is rotated in the plane defined by the axes. Axes must be different.");
+  }
+};
+
+struct rot90reverse {
+  MSHADOW_XINLINE static index_t ReverseIndex(index_t idx,
+                                              index_t nreversedim,
+                                              const index_t * stride_,
+                                              const index_t * trailing_) {
+    index_t outputIndex = idx;
+    for (index_t i = 0; i < nreversedim; ++i) {
+      const index_t low = outputIndex % trailing_[i];
+      index_t high = outputIndex / trailing_[i];
+      const index_t x = high % stride_[i];
+      high /= stride_[i];
+      outputIndex = (high * stride_[i] + stride_[i] - 1 - x) * trailing_[i] + low;
+    }
+    return outputIndex;
+  }
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(index_t index, index_t nreversedim, const DType *src, DType *dst,
+                                  const index_t * stride_,
+                                  const index_t * trailing_) {
+    index_t new_idx = ReverseIndex(index, nreversedim, stride_, trailing_);
+    dst[new_idx] = src[index];
+  }
+};
+
+template<typename xpu>
+void NumpyRot90ComputeFlipIml(const OpContext& ctx,
+                              const std::vector<TBlob>& inputs,
+                              const std::vector<OpReqType>& req,
+                              const std::vector<TBlob>& outputs,
+                              const index_t axis0, const index_t axis1) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+
+  const mxnet::TShape& ishape = inputs[0].shape_;
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+
+  std::vector<index_t> stride_(2);
+  std::vector<index_t> trailing_(2);
+  index_t reverse_index = 0;
+  std::vector<index_t> temp{axis0, axis1};
+  for (int axis : temp) {
+    stride_[reverse_index] = ishape[axis];
+    trailing_[reverse_index] = 1;
+    for (int i2 = axis + 1; i2 < ishape.ndim(); ++i2) {
+      trailing_[reverse_index] *= ishape[i2];
+    }
+    reverse_index++;
+  }
+
+  index_t workspace_size = 2 * sizeof(index_t);
+  Tensor<xpu, 1, char> workspace =
+      ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(2 * workspace_size), s);
+  Tensor<cpu, 1, index_t> stride_cpu_tensor(stride_.data(), Shape1(stride_.size()));
+  Tensor<xpu, 1, index_t> stride_xpu_tensor(
+      reinterpret_cast<index_t*>(workspace.dptr_), Shape1(stride_.size()));
+  Tensor<cpu, 1, index_t> trailing_cpu_tensor(trailing_.data(), Shape1(trailing_.size()));
+  Tensor<xpu, 1, index_t> trailing_xpu_tensor(
+      reinterpret_cast<index_t*>(workspace.dptr_ + workspace_size), Shape1(trailing_.size()));
+
+  mshadow::Copy(stride_xpu_tensor, stride_cpu_tensor, s);
+  mshadow::Copy(trailing_xpu_tensor, trailing_cpu_tensor, s);
+  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+    Kernel<rot90reverse, xpu>::Launch(s, inputs[0].Size(), reverse_index,
+                                      inputs[0].dptr<DType>(), outputs[0].dptr<DType>(),
+                                      stride_xpu_tensor.dptr_, trailing_xpu_tensor.dptr_);
+  });
+}
+
+struct rot90Transreverse {
+  MSHADOW_XINLINE static index_t ReverseIndex(index_t idx,
+                                              const index_t stride_,
+                                              const index_t trailing_) {
+    index_t outputIndex = idx;
+    const index_t low = outputIndex % trailing_;
+    index_t high = outputIndex / trailing_;
+    const index_t x = high % stride_;
+    high /= stride_;
+    outputIndex = (high * stride_ + stride_ - 1 - x) * trailing_ + low;
+
+    return outputIndex;
+  }
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(index_t index, const DType *src, DType *dst,
+                                  const index_t stride_,
+                                  const index_t trailing_) {
+    index_t new_idx = ReverseIndex(index, stride_, trailing_);
+    dst[new_idx] = src[index];
+  }
+};
+
+template<typename xpu>
+void NumpyRot90ComputeFlipTransposeIml(const OpContext& ctx,
+                                       const std::vector<TBlob>& inputs,
+                                       const std::vector<OpReqType>& req,
+                                       const std::vector<TBlob>& outputs,
+                                       const mxnet::TShape axes_list,
+                                       const index_t axis) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+
+  const mxnet::TShape& ishape = inputs[0].shape_;
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+
+  index_t stride_;
+  index_t trailing_;
+
+  stride_ = ishape[axis];
+  trailing_ = 1;
+  for (int i2 = axis + 1; i2 < ishape.ndim(); ++i2) {
+    trailing_ *= ishape[i2];
+  }
+
+  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+    index_t workspace_size = inputs[0].Size() * sizeof(DType);
+    Tensor<xpu, 1, char> workspace =
+        ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
+    DType* data_ptr = reinterpret_cast<DType*>(workspace.dptr_);
+    TBlob mid_data = TBlob(data_ptr, inputs[0].shape_, xpu::kDevMask);
+    Kernel<rot90Transreverse, xpu>::Launch(s, inputs[0].Size(), inputs[0].dptr<DType>(),
+                                           mid_data.dptr<DType>(),
+                                           stride_, trailing_);
+    mxnet::op::TransposeImpl<xpu>(ctx.run_ctx, mid_data, outputs[0], axes_list);
+  });
+}
+
+
+template<typename xpu>
+void NumpyRot90ComputeTransposeFlipIml(const OpContext& ctx,
+                                       const std::vector<TBlob>& inputs,
+                                       const std::vector<OpReqType>& req,
+                                       const std::vector<TBlob>& outputs,
+                                       const mxnet::TShape axes_list,
+                                       const index_t axis) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+
+  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+    index_t workspace_size = inputs[0].Size() * sizeof(DType);
+    Tensor<xpu, 1, char> workspace =
+        ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
+    DType* data_ptr = reinterpret_cast<DType*>(workspace.dptr_);
+    mxnet::TShape mid_shape(outputs[0].shape_);
+    TBlob mid_data = TBlob(data_ptr, mid_shape, xpu::kDevMask);
+    mxnet::op::TransposeImpl<xpu>(ctx.run_ctx, inputs[0], mid_data, axes_list);
+
+    index_t stride_;
+    index_t trailing_;
+    stride_ = mid_shape[axis];
+    trailing_ = 1;
+    for (int i2 = axis + 1; i2 < mid_shape.ndim(); ++i2) {
+      trailing_ *= mid_shape[i2];
+    }
+    Kernel<rot90Transreverse, xpu>::Launch(s, mid_data.Size(), mid_data.dptr<DType>(),
+                                           outputs[0].dptr<DType>(),
+                                           stride_, trailing_);
+  });
+}
+
+template<int req>
+struct rot90 {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(index_t i, const DType *in_data, DType *out_data) {
+    KERNEL_ASSIGN(out_data[i], req, in_data[i]);
+  }
+};
+
+template<typename xpu>
+void NumpyRot90Compute(const nnvm::NodeAttrs& attrs,
+                       const OpContext& ctx,
+                       const std::vector<TBlob>& inputs,
+                       const std::vector<OpReqType>& req,
+                       const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  const NumpyRot90Param& param = nnvm::get<NumpyRot90Param>(attrs.parsed);
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(inputs[0].type_flag_, outputs[0].type_flag_);
+  if (outputs[0].Size() == 0) return;
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+
+  int real_k(param.k);
+  real_k = real_k % 4;
+  if (real_k < 0) {
+    real_k += 4;
+  }
+
+  // axes has a value here; negative entries are normalized first
+  mxnet::TShape real_axes(param.axes.value());
+  for (index_t i = 0; i < real_axes.ndim(); i++) {
+    if (real_axes[i] < 0) {
+      real_axes[i] += inputs[0].shape_.ndim();
+    }
+  }
+  if (real_k == 0) {
+    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+      MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
+        Kernel<rot90<req_type>, xpu>::Launch(s, inputs[0].Size(), inputs[0].dptr<DType>(),
+                                             outputs[0].dptr<DType>());
+      });
+    });
+  } else if (real_k == 2) {
+    NumpyRot90ComputeFlipIml<xpu>(ctx, inputs, req, outputs, real_axes[0], real_axes[1]);
+  } else if (real_k == 1) {
+    mxnet::TShape axes_list(inputs[0].shape_.ndim(), -1);
+    for (int i = 0; i < inputs[0].shape_.ndim(); ++i) {
+      axes_list[i] = i;
+    }
+    axes_list[real_axes[0]] += axes_list[real_axes[1]];
+    axes_list[real_axes[1]] = axes_list[real_axes[0]] - axes_list[real_axes[1]];
+    axes_list[real_axes[0]] -= axes_list[real_axes[1]];
+    NumpyRot90ComputeFlipTransposeIml<xpu>(ctx, inputs, req, outputs, axes_list, real_axes[1]);
+  } else if (real_k == 3) {
+    mxnet::TShape axes_list(inputs[0].shape_.ndim(), -1);
+    for (int i = 0; i < inputs[0].shape_.ndim(); ++i) {
+      axes_list[i] = i;
+    }
+    axes_list[real_axes[0]] += axes_list[real_axes[1]];
+    axes_list[real_axes[1]] = axes_list[real_axes[0]] - axes_list[real_axes[1]];
+    axes_list[real_axes[0]] -= axes_list[real_axes[1]];
+    NumpyRot90ComputeTransposeFlipIml<xpu>(ctx, inputs, req, outputs, axes_list, real_axes[1]);
+  }
+}
+
+template<typename xpu>
+inline void HSplitOpForward(const nnvm::NodeAttrs &attrs,
+                            const OpContext &ctx,
+                            const std::vector<TBlob> &inputs,
+                            const std::vector<OpReqType> &req,
+                            const std::vector<TBlob> &outputs) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  using namespace mxnet_op;
+  const SplitParam &param = nnvm::get<SplitParam>(attrs.parsed);
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), param.sections > 0 ? param.sections : param.indices.ndim());
+  const TBlob &input_data = inputs[split_enum::kData];
+  int real_axis;
+  if (input_data.ndim() > 1) {
+    real_axis = 1;
+  } else {
+    real_axis = 0;
+  }
+  SplitOpForwardImpl<xpu>(attrs, ctx, inputs, req, outputs, real_axis);
+}
+
+template<typename xpu>
+inline void HSplitOpBackward(const nnvm::NodeAttrs &attrs,
+                             const OpContext &ctx,
+                             const std::vector<TBlob> &inputs,
+                             const std::vector<OpReqType> &req,
+                             const std::vector<TBlob> &outputs) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  using namespace mxnet_op;
+  const SplitParam &param = nnvm::get<SplitParam>(attrs.parsed);
+  CHECK_EQ(inputs.size(), (param.sections > 0) ? param.sections : param.indices.ndim())
+      << "out grad vector size must match the output size";
+  CHECK_EQ(outputs.size(), 1U);
+  int real_axis;
+  if (outputs[split_enum::kData].ndim() > 1) {
+    real_axis = 1;
+  } else {
+    real_axis = 0;
+  }
+  SplitOpBackwardImpl<xpu>(attrs, ctx, inputs, req, outputs, real_axis);
+}
+
 }  // namespace op
 }  // namespace mxnet
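The `ReverseIndex` helpers above mirror an element's coordinate along a reversed axis using only flat-index arithmetic. A Python transliteration of the multi-axis variant, handy for checking the stride/trailing bookkeeping (hypothetical helper, row-major layout assumed):

def reverse_index(idx, strides, trailings):
    # For each reversed axis: split the flat index into (high, coord, low),
    # mirror the coordinate, and reassemble.
    for stride, trailing in zip(strides, trailings):
        low = idx % trailing
        high, coord = divmod(idx // trailing, stride)
        idx = (high * stride + stride - 1 - coord) * trailing + low
    return idx

# Reversing axis 1 of a 2x3 row-major array: stride = 3, trailing = 1.
print([reverse_index(i, [3], [1]) for i in range(6)])  # [2, 1, 0, 5, 4, 3]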
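The three `+=`/`-=` lines in `NumpyRot90Compute` (and the matching ones in `NumpyRot90Shape` below) are an in-place arithmetic swap of the two rotation axes in the permutation, avoiding a temporary. The Python equivalent is a plain tuple swap:

axes_list = list(range(4))
a0, a1 = 1, 2
axes_list[a0], axes_list[a1] = axes_list[a1], axes_list[a0]  # same effect as the +=/-= sequence
print(axes_list)  # [0, 2, 1, 3]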
diff --git a/src/operator/numpy/np_matrix_op.cc b/src/operator/numpy/np_matrix_op.cc
index 96a10561be28..5fb82f03a671 100644
--- a/src/operator/numpy/np_matrix_op.cc
+++ b/src/operator/numpy/np_matrix_op.cc
@@ -32,6 +32,8 @@ namespace op {

 DMLC_REGISTER_PARAMETER(NumpyTransposeParam);
 DMLC_REGISTER_PARAMETER(NumpyRollParam);
+DMLC_REGISTER_PARAMETER(NumpyMoveaxisParam);
+DMLC_REGISTER_PARAMETER(NumpyRot90Param);

 bool NumpyTransposeShape(const nnvm::NodeAttrs& attrs,
                          mxnet::ShapeVector *in_attrs,
@@ -612,5 +614,171 @@ NNVM_REGISTER_OP(_backward_npi_flip)
 })
 .set_attr<FCompute>("FCompute", NumpyFlipForward<cpu>);

+bool NumpyMoveaxisShape(const nnvm::NodeAttrs& attrs,
+                        mxnet::ShapeVector *in_attrs,
+                        mxnet::ShapeVector *out_attrs) {
+  const NumpyMoveaxisParam& param = nnvm::get<NumpyMoveaxisParam>(attrs.parsed);
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  mxnet::TShape& shp = (*in_attrs)[0];
+  CHECK_LE(shp.ndim(), 6) << "Transpose supports at most 6 dimensions";
+  CHECK_EQ(param.source.ndim(), param.destination.ndim())
+      << "source and destination must have the same number of elements.";
+  mxnet::TShape ret(shp.ndim(), -1);
+  mxnet::TShape axes;
+  axes = NumpyMoveaxisShapeImpl(attrs, shp.ndim());
+  for (int i = 0; i < shp.ndim(); ++i) {
+    CHECK(axes[i] < static_cast<int64_t>(shp.ndim()));
+    ret[i] = shp[axes[i]];
+  }
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, ret);
+  return shape_is_known(ret);
+}
+
+NNVM_REGISTER_OP(_np_moveaxis)
+.describe(R"code(Move axes of an array to new positions.
+Other axes remain in their original order.
+)code" ADD_FILELINE)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<NumpyMoveaxisParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", NumpyMoveaxisShape)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<nnvm::FGradient>("FGradient",
+  [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
+    const NumpyMoveaxisParam& param = nnvm::get<NumpyMoveaxisParam>(n->attrs.parsed);
+    std::ostringstream os1;
+    os1 << param.source;
+    std::ostringstream os2;
+    os2 << param.destination;
+    return MakeNonlossGradNode("_np_moveaxis", n, ograds, {},
+                               {{"source", os2.str()}, {"destination", os1.str()}});
+})
+.set_attr<FCompute>("FCompute", NumpyMoveaxisCompute<cpu>)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"a"};
+})
+.add_argument("a", "NDArray-or-Symbol", "Source input")
+.add_arguments(NumpyMoveaxisParam::__FIELDS__());
+
+inline bool NumpyRot90Shape(const nnvm::NodeAttrs& attrs,
+                            mxnet::ShapeVector *in_attrs,
+                            mxnet::ShapeVector *out_attrs) {
+  using namespace mshadow;
+  const NumpyRot90Param& param = nnvm::get<NumpyRot90Param>(attrs.parsed);
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  mxnet::TShape& shp = (*in_attrs)[0];
+  if (!param.axes.has_value() || (param.axes.has_value() && param.axes.value().ndim() != 2)) {
+    LOG(FATAL) << "The length of axes must be 2.";
+  }
+  int real_k(param.k);
+  real_k = real_k % 4;
+  if (real_k < 0) {
+    real_k += 4;
+  }
+
+  mxnet::TShape res(shp);
+  mxnet::TShape real_axes(param.axes.value());
+  for (index_t i = 0; i < real_axes.ndim(); i++) {
+    if (real_axes[i] < 0) {
+      real_axes[i] += shp.ndim();
+    }
+  }
+
+  CHECK_NE(real_axes[0], real_axes[1])
+      << "axes have duplicates " << real_axes;
+  if (real_axes[0] > shp.ndim() || real_axes[1] > shp.ndim() ||
+      real_axes[0] < 0 || real_axes[1] < 0) {
+    LOG(FATAL) << "Axes out of range for array of ndim " << shp.ndim();
+  }
+
+  if (real_k % 2 == 0) {
+    SHAPE_ASSIGN_CHECK(*out_attrs, 0, res);
+    return shape_is_known(res);
+  }
+
+  res[real_axes[0]] += res[real_axes[1]];
+  res[real_axes[1]] = res[real_axes[0]] - res[real_axes[1]];
+  res[real_axes[0]] -= res[real_axes[1]];
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, res);
+  return shape_is_known(res);
+}
+
+NNVM_REGISTER_OP(_npi_rot90)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<NumpyRot90Param>)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"data"};
+})
+.set_attr<mxnet::FInferShape>("FInferShape", NumpyRot90Shape)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<FCompute>("FCompute", NumpyRot90Compute<cpu>)
+.set_attr<nnvm::FGradient>("FGradient",
+  [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
+    const NumpyRot90Param& param = nnvm::get<NumpyRot90Param>(n->attrs.parsed);
+    std::ostringstream os1;
+    os1 << param.k;
+    std::ostringstream os2;
+    os2 << param.axes;
+    return MakeNonlossGradNode("_npi_rot90", n, ograds, {},
+                               {{"k", os1.str()}, {"axes", os2.str()}});
+})
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& n) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+})
+.add_argument("data", "NDArray-or-Symbol", "Input ndarray")
+.add_arguments(NumpyRot90Param::__FIELDS__());
+
+inline bool HSplitOpShape(const nnvm::NodeAttrs& attrs,
+                          mxnet::ShapeVector* in_attrs,
+                          mxnet::ShapeVector* out_attrs) {
+  using namespace mshadow;
+  CHECK_EQ(in_attrs->size(), 1U);
+  mxnet::TShape dshape = in_attrs->at(split_enum::kData);
+  if (!mxnet::ndim_is_known(dshape)) return false;
+  int real_axis;
+  if (dshape.ndim() > 1) {
+    real_axis = 1;
+  } else {
+    real_axis = 0;
+  }
+  return SplitOpShapeImpl(attrs, in_attrs, out_attrs, real_axis);
+}
+
+NNVM_REGISTER_OP(_npi_hsplit)
+.set_attr_parser(ParamParser<SplitParam>)
+.set_num_inputs(1)
+.set_num_outputs(SplitNumOutputs)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"data"};
+})
+.set_attr<mxnet::FInferShape>("FInferShape", HSplitOpShape)
+.set_attr<nnvm::FInferType>("FInferType", SplitOpType)
+.set_attr<FCompute>("FCompute", HSplitOpForward<cpu>)
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& n) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+})
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_npi_hsplit_backward"})
+.add_argument("data", "NDArray-or-Symbol", "The input")
+.add_arguments(SplitParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_npi_hsplit_backward)
+.set_attr_parser(ParamParser<SplitParam>)
+.set_num_inputs(SplitNumOutputs)
+.set_num_outputs(1)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& n) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+})
+.set_attr<FCompute>("FCompute", HSplitOpBackward<cpu>);
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/numpy/np_matrix_op.cu b/src/operator/numpy/np_matrix_op.cu
index caab4108b40e..fed186263910 100644
--- a/src/operator/numpy/np_matrix_op.cu
+++ b/src/operator/numpy/np_matrix_op.cu
@@ -90,5 +90,18 @@ NNVM_REGISTER_OP(_npi_flip)
 NNVM_REGISTER_OP(_backward_npi_flip)
 .set_attr<FCompute>("FCompute", NumpyFlipForward<gpu>);

+
+NNVM_REGISTER_OP(_np_moveaxis)
+.set_attr<FCompute>("FCompute", NumpyMoveaxisCompute<gpu>);
+
+NNVM_REGISTER_OP(_npi_rot90)
+.set_attr<FCompute>("FCompute", NumpyRot90Compute<gpu>);
+
+NNVM_REGISTER_OP(_npi_hsplit)
+.set_attr<FCompute>("FCompute", HSplitOpForward<gpu>);
+
+NNVM_REGISTER_OP(_npi_hsplit_backward)
+.set_attr<FCompute>("FCompute", HSplitOpBackward<gpu>);
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc
index b81cd78ad507..4a681c110434 100644
--- a/src/operator/operator_tune.cc
+++ b/src/operator/operator_tune.cc
@@ -367,6 +367,7 @@ IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::logical_or);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::logical_or);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::logical_xor);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::logical_xor);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::bitwise_xor);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::smooth_l1_loss);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::smooth_l1_gradient);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::lcm);  // NOLINT()
diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h
index 5a2bd036c22b..43d86337f445 100644
--- a/src/operator/tensor/matrix_op-inl.h
+++ b/src/operator/tensor/matrix_op-inl.h
@@ -2646,24 +2646,14 @@ inline bool SplitOpType(const nnvm::NodeAttrs& attrs,
   return true;
 }

-inline bool SplitOpShape(const nnvm::NodeAttrs& attrs,
-                         mxnet::ShapeVector* in_attrs,
-                         mxnet::ShapeVector* out_attrs) {
+inline bool SplitOpShapeImpl(const nnvm::NodeAttrs& attrs,
+                             mxnet::ShapeVector* in_attrs,
+                             mxnet::ShapeVector* out_attrs,
+                             const int real_axis) {
   using namespace mshadow;
   const SplitParam& param = nnvm::get<SplitParam>(attrs.parsed);
-  CHECK_EQ(in_attrs->size(), 1U);
   mxnet::TShape dshape = in_attrs->at(split_enum::kData);
   mxnet::TShape ishape = in_attrs->at(split_enum::kData);
-  if (!mxnet::ndim_is_known(dshape)) return false;
-  if (param.axis >= 0) {
-    CHECK_LT(param.axis, dshape.ndim());
-  } else {
-    CHECK_LT(param.axis + dshape.ndim(), dshape.ndim());
-  }
-  int real_axis = param.axis;
-  if (real_axis < 0) {
-    real_axis += dshape.ndim();
-  }
   const mxnet::TShape indices =
       (param.sections > 0) ? GetSplitIndices(ishape, real_axis, param.sections) : param.indices;
   int num_outputs = (param.sections > 0) ? indices.ndim() - 1 : indices.ndim();
@@ -2680,7 +2670,7 @@ inline bool SplitOpShape(const nnvm::NodeAttrs& attrs,
     if (ishape[real_axis] == 0U) {
       end = start;
     } else {
-      CHECK(start < end)
+      CHECK(start <= end)
          << "start " << start << " is not less than end " << end << "for subarray " << i;
       CHECK(end <= ishape[real_axis])
          << "end " << end << " is no less than the size of the axis " << ishape[real_axis];
@@ -2716,6 +2706,26 @@ inline bool SplitOpShape(const nnvm::NodeAttrs& attrs,
   return true;
 }

+inline bool SplitOpShape(const nnvm::NodeAttrs& attrs,
+                         mxnet::ShapeVector* in_attrs,
+                         mxnet::ShapeVector* out_attrs) {
+  using namespace mshadow;
+  const SplitParam& param = nnvm::get<SplitParam>(attrs.parsed);
+  CHECK_EQ(in_attrs->size(), 1U);
+  mxnet::TShape dshape = in_attrs->at(split_enum::kData);
+  if (!mxnet::ndim_is_known(dshape)) return false;
+  if (param.axis >= 0) {
+    CHECK_LT(param.axis, dshape.ndim());
+  } else {
+    CHECK_LT(param.axis + dshape.ndim(), dshape.ndim());
+  }
+  int real_axis = param.axis;
+  if (real_axis < 0) {
+    real_axis += dshape.ndim();
+  }
+  return SplitOpShapeImpl(attrs, in_attrs, out_attrs, real_axis);
+}
+
 struct SplitKernel {
   /*!
    * \brief Map function for forward split_v2 operator
@@ -2781,24 +2791,19 @@ struct ConcatenateKernel {
 };

 template<typename xpu>
-inline void SplitOpForward(const nnvm::NodeAttrs& attrs,
-                           const OpContext& ctx,
-                           const std::vector<TBlob>& inputs,
-                           const std::vector<OpReqType>& req,
-                           const std::vector<TBlob>& outputs) {
+inline void SplitOpForwardImpl(const nnvm::NodeAttrs& attrs,
+                               const OpContext& ctx,
+                               const std::vector<TBlob>& inputs,
+                               const std::vector<OpReqType>& req,
+                               const std::vector<TBlob>& outputs,
+                               const int real_axis) {
   using namespace mshadow;
   using namespace mshadow::expr;
   using namespace mxnet_op;
   const SplitParam& param = nnvm::get<SplitParam>(attrs.parsed);
-  CHECK_EQ(inputs.size(), 1U);
-  CHECK_EQ(outputs.size(), (param.sections > 0) ? param.sections : param.indices.ndim());
   Stream<xpu> *s = ctx.get_stream<xpu>();
   const TBlob& input_data = inputs[split_enum::kData];
   size_t leading = 1, trailing = 1;
-  int real_axis = param.axis;
-  if (real_axis < 0) {
-    real_axis += input_data.ndim();
-  }
   CHECK_LT(real_axis, input_data.ndim());
   size_t mid = input_data.shape_[real_axis];
   for (int i = 0; i < real_axis; ++i) {
@@ -2844,25 +2849,39 @@ inline void SplitOpForward(const nnvm::NodeAttrs& attrs,
 }

 template<typename xpu>
-inline void SplitOpBackward(const nnvm::NodeAttrs& attrs,
-                            const OpContext& ctx,
-                            const std::vector<TBlob>& inputs,
-                            const std::vector<OpReqType>& req,
-                            const std::vector<TBlob>& outputs) {
+inline void SplitOpForward(const nnvm::NodeAttrs& attrs,
+                           const OpContext& ctx,
+                           const std::vector<TBlob>& inputs,
+                           const std::vector<OpReqType>& req,
+                           const std::vector<TBlob>& outputs) {
   using namespace mshadow;
   using namespace mshadow::expr;
   using namespace mxnet_op;
   const SplitParam& param = nnvm::get<SplitParam>(attrs.parsed);
-  CHECK_EQ(inputs.size(), (param.sections > 0) ? param.sections : param.indices.ndim())
-      << "out grad vector size mush match the output size";
-  CHECK_EQ(outputs.size(), 1U);
-  Stream<xpu> *s = ctx.get_stream<xpu>();
-  TBlob input_grad = outputs[split_enum::kData];
-  size_t leading = 1, trailing = 1;
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), (param.sections > 0) ? param.sections : param.indices.ndim());
+  const TBlob& input_data = inputs[split_enum::kData];
   int real_axis = param.axis;
   if (real_axis < 0) {
-    real_axis += input_grad.ndim();
+    real_axis += input_data.ndim();
   }
+  SplitOpForwardImpl<xpu>(attrs, ctx, inputs, req, outputs, real_axis);
+}
+
+template<typename xpu>
+inline void SplitOpBackwardImpl(const nnvm::NodeAttrs& attrs,
+                                const OpContext& ctx,
+                                const std::vector<TBlob>& inputs,
+                                const std::vector<OpReqType>& req,
+                                const std::vector<TBlob>& outputs,
+                                const int real_axis) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  using namespace mxnet_op;
+  const SplitParam& param = nnvm::get<SplitParam>(attrs.parsed);
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  TBlob input_grad = outputs[split_enum::kData];
+  size_t leading = 1, trailing = 1;
   CHECK_LT(real_axis, input_grad.ndim());
   size_t mid = input_grad.shape_[real_axis];
   for (int i = 0; i < real_axis; ++i) {
@@ -2907,6 +2926,26 @@ inline void SplitOpBackward(const nnvm::NodeAttrs& attrs,
   });
 }

+template<typename xpu>
+inline void SplitOpBackward(const nnvm::NodeAttrs& attrs,
+                            const OpContext& ctx,
+                            const std::vector<TBlob>& inputs,
+                            const std::vector<OpReqType>& req,
+                            const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  using namespace mxnet_op;
+  const SplitParam& param = nnvm::get<SplitParam>(attrs.parsed);
+  CHECK_EQ(inputs.size(), (param.sections > 0) ? param.sections : param.indices.ndim())
+      << "out grad vector size must match the output size";
+  CHECK_EQ(outputs.size(), 1U);
+  int real_axis = param.axis;
+  if (real_axis < 0) {
+    real_axis += outputs[split_enum::kData].ndim();
+  }
+  SplitOpBackwardImpl<xpu>(attrs, ctx, inputs, req, outputs, real_axis);
+}
+
 inline uint32_t SplitNumOutputs(const NodeAttrs& attrs) {
   const SplitParam& param = nnvm::get<SplitParam>(attrs.parsed);
   return (param.sections > 0) ? param.sections : param.indices.ndim();
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 224125eb64f3..7a77533577ee 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -1320,6 +1320,57 @@ def get_indices(axis_size):
             assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)


+@with_seed()
+@use_np
+def test_np_hsplit():
+    class TestHSplit(HybridBlock):
+        def __init__(self, indices_or_sections):
+            super(TestHSplit, self).__init__()
+            self._indices_or_sections = indices_or_sections
+
+        def hybrid_forward(self, F, a, *args, **kwargs):
+            return F.np.hsplit(a, indices_or_sections=self._indices_or_sections)
+
+    shapes = [
+        (10,),
+        (3, 8, 5),
+        (3, 0, 5),
+        (3, 8, 5, 6),
+        (3, 0, 5, 6),
+    ]
+    indices_or_sections_num = [
+        (2, 4),
+        (3, 3),
+        (3,),
+        (1,),
+        2,
+    ]
+    for hybridize in [True, False]:
+        for shape in shapes:
+            for indices_or_sections in indices_or_sections_num:
+                # test gluon
+                test_hsplit = TestHSplit(indices_or_sections=indices_or_sections)
+                if hybridize:
+                    test_hsplit.hybridize()
+
+                a = mx.nd.random.uniform(-1.0, 1.0, shape=shape).as_np_ndarray()
+                a.attach_grad()
+                expected_ret = _np.hsplit(a.asnumpy(), indices_or_sections=indices_or_sections)
+                with mx.autograd.record():
+                    y = test_hsplit(a)
+                assert len(y) == len(expected_ret)
+                for mx_out, np_out in zip(y, expected_ret):
+                    assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+                mx.autograd.backward(y)
+                assert_almost_equal(a.grad.asnumpy(), _np.ones(a.shape), rtol=1e-3, atol=1e-5)
+
+                # test imperative
+                mx_outs = np.hsplit(a, indices_or_sections=indices_or_sections)
+                np_outs = _np.hsplit(a.asnumpy(), indices_or_sections=indices_or_sections)
+                for mx_out, np_out in zip(mx_outs, np_outs):
+                    assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
+
+
 @with_seed()
 @use_np
 def test_np_concat():
@@ -2861,6 +2912,103 @@ def test_np_builtin_op_signature():
         assert str(op.__signature__) == str(inspect.signature(getattr(_numpy_op_doc, op_name)))


+@with_seed()
+@use_np
+def test_np_moveaxis():
+    class TestMoveaxis(HybridBlock):
+        def __init__(self, source=None, destination=None):
+            super(TestMoveaxis, self).__init__()
+            self._source = source
+            self._destination = destination
+
+        def hybrid_forward(self, F, x):
+            return F.np.moveaxis(x, source=self._source, destination=self._destination)
+
+    dtypes = ['int32', 'int64', 'float16', 'float32', 'float64']
+    for hybridize in [False, True]:
+        for dtype in dtypes:
+            for ndim in [0, 1, 2, 3, 4, 5, 6]:
+                shape = rand_shape_nd(ndim, dim=5, allow_zero_size=True)
+                np_data = _np.random.uniform(low=-100, high=100, size=shape).astype(dtype)
+                mx_data = np.array(np_data, dtype=dtype)
+                axis = [i for i in range(ndim)]
+                random.shuffle(axis)
+                for i in range(ndim):
+                    source = random.sample(axis, i)
+                    destination = random.sample(axis, i)
+
+                    # test gluon
+                    test_moveaxis = TestMoveaxis(source, destination)
+                    if hybridize:
+                        test_moveaxis.hybridize()
+                    np_out = _np.moveaxis(np_data, source=source, destination=destination)
+                    mx_data.attach_grad()
+                    with mx.autograd.record():
+                        mx_out = test_moveaxis(mx_data)
+                    assert mx_out.shape == np_out.shape
+                    mx_out.backward()
+                    assert same(mx_data.grad.shape, mx_data.shape)
+                    assert same(mx_data.grad.asnumpy(), _np.ones(shape))
+                    # test imperative
+                    np_out = _np.moveaxis(np_data, source=source, destination=destination)
+                    mx_out = np.moveaxis(mx_data, source=source, destination=destination)
+                    assert np_out.dtype == mx_out.dtype
+                    assert same(mx_out.asnumpy(), np_out)
+
+
+@with_seed()
+@use_np
+def test_np_rot90():
+    class TestRot90(HybridBlock):
+        def __init__(self, k=1, axes=(0, 1)):
+            super(TestRot90, self).__init__()
+            self._k = k
+            self._axes = axes
+
+        def hybrid_forward(self, F, a, *args):
+            return F.np.rot90(a, self._k, self._axes)
+
+    configs = [
+        ((2, 3), 1, (0, 1)),
+        ((2, 3), 3, (0, 1)),
+        ((2, 3), 1, (1, 0)),
+        ((2, 3), 2, (1, 0)),
+        ((2, 3), 3, (1, 0)),
+        ((2, 3), 0, (1, 0)),
+        ((2, 3, 4, 5), 3, (1, 2)),
+        ((2, 3, 4, 5), -3, (2, 3)),
+        ((2, 3, 0, 5), -2, (2, 3)),
+        ((2, 0, 0, 5), -3, (2, 3)),
+        ((2, 3, 0, 5), 0, (2, 1)),
+    ]
+    dtypes = ['uint8', 'int8', 'int32', 'int64', 'float16', 'float32', 'float64']
+
+    for config in configs:
+        for dtype in dtypes:
+            for hybridize in [True, False]:
+                shape, k, axes = config[0], config[1], config[2]
+                x = rand_ndarray(shape=shape, dtype=dtype).as_np_ndarray()
+                net = TestRot90(k=k, axes=axes)
+                if hybridize:
+                    net.hybridize()
+
+                x.attach_grad()
+                np_out = _np.rot90(x.asnumpy(), k=k, axes=axes)
+                with mx.autograd.record():
+                    mx_out = net(x)
+                assert mx_out.shape == np_out.shape
+                assert same(mx_out.asnumpy(), np_out)
+                mx_out.backward()
+                np_backward = _np.ones(shape, dtype)
+
+                assert same(x.grad.asnumpy().shape, np_backward.shape)
+                assert same(x.grad.asnumpy(), np_backward)
+
+                np_out = _np.rot90(x.asnumpy(), k=k, axes=axes)
+                mx_out = np.rot90(x, k=k, axes=axes)
+                assert same(mx_out.asnumpy(), np_out)
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()