diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 99e35f37844d..7a5cf37403b4 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -34,7 +34,7 @@ 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign', - 'ravel', 'hanning', 'hamming', 'blackman'] + 'ravel', 'hanning', 'hamming', 'blackman', 'flip'] @set_module('mxnet.ndarray.numpy') @@ -2828,3 +2828,71 @@ def blackman(M, dtype=_np.float32, ctx=None): if ctx is None: ctx = current_context() return _npi.blackman(M, dtype=dtype, ctx=ctx) + + +@set_module('mxnet.ndarray.numpy') +def flip(m, axis=None, out=None): + r""" + flip(m, axis=None, out=None) + + Reverse the order of elements in an array along the given axis. + + The shape of the array is preserved, but the elements are reordered. + + Parameters + ---------- + m : ndarray or scalar + Input array. + axis : None or int or tuple of ints, optional + Axis or axes along which to flip over. The default, + axis=None, will flip over all of the axes of the input array. + If axis is negative it counts from the last to the first axis. + + If axis is a tuple of ints, flipping is performed on all of the axes + specified in the tuple. + out : ndarray or scalar, optional + Alternative output array in which to place the result. It must have + the same shape and type as the expected output. + + Returns + ------- + out : ndarray or scalar + A view of `m` with the entries of axis reversed. Since a view is + returned, this operation is done in constant time. + + Examples + -------- + >>> A = np.arange(8).reshape((2,2,2)) + >>> A + array([[[0, 1], + [2, 3]], + [[4, 5], + [6, 7]]]) + >>> np.flip(A, 0) + array([[[4, 5], + [6, 7]], + [[0, 1], + [2, 3]]]) + >>> np.flip(A, 1) + array([[[2, 3], + [0, 1]], + [[6, 7], + [4, 5]]]) + >>> np.flip(A) + array([[[7, 6], + [5, 4]], + [[3, 2], + [1, 0]]]) + >>> np.flip(A, (0, 2)) + array([[[5, 4], + [7, 6]], + [[1, 0], + [3, 2]]]) + """ + from ...numpy import ndarray + if isinstance(m, numeric_types): + return _np.flip(m, axis) + elif isinstance(m, ndarray): + return _npi.flip(m, axis, out=out) + else: + raise TypeError('type {} not supported'.format(str(type(m)))) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 801538557c5d..5b387601f59d 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -53,7 +53,7 @@ 'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', - 'copysign', 'ravel', 'hanning', 'hamming', 'blackman'] + 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip'] # Return code for dispatching indexing function call _NDARRAY_UNSUPPORTED_INDEXING = -1 @@ -4367,3 +4367,65 @@ def blackman(M, dtype=_np.float32, ctx=None): >>> plt.show() """ return _mx_nd_np.blackman(M, dtype=dtype, ctx=ctx) + + +@set_module('mxnet.numpy') +def flip(m, axis=None, out=None): + r""" + flip(m, axis=None, out=None) + + Reverse the order of elements in an array along the given axis. + + The shape of the array is preserved, but the elements are reordered. + + Parameters + ---------- + m : ndarray or scalar + Input array. + axis : None or int or tuple of ints, optional + Axis or axes along which to flip over. The default, + axis=None, will flip over all of the axes of the input array. + If axis is negative it counts from the last to the first axis. + + If axis is a tuple of ints, flipping is performed on all of the axes + specified in the tuple. + out : ndarray or scalar, optional + Alternative output array in which to place the result. It must have + the same shape and type as the expected output. + + Returns + ------- + out : ndarray or scalar + A view of `m` with the entries of axis reversed. Since a view is + returned, this operation is done in constant time. + + Examples + -------- + >>> A = np.arange(8).reshape((2,2,2)) + >>> A + array([[[0, 1], + [2, 3]], + [[4, 5], + [6, 7]]]) + >>> np.flip(A, 0) + array([[[4, 5], + [6, 7]], + [[0, 1], + [2, 3]]]) + >>> np.flip(A, 1) + array([[[2, 3], + [0, 1]], + [[6, 7], + [4, 5]]]) + >>> np.flip(A) + array([[[7, 6], + [5, 4]], + [[3, 2], + [1, 0]]]) + >>> np.flip(A, (0, 2)) + array([[[5, 4], + [7, 6]], + [[1, 0], + [3, 2]]]) + """ + return _mx_nd_np.flip(m, axis, out=out) diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index f39d201638cb..dc2c84da167d 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -36,7 +36,7 @@ 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign', - 'ravel', 'hanning', 'hamming', 'blackman'] + 'ravel', 'hanning', 'hamming', 'blackman', 'flip'] def _num_outputs(sym): @@ -3091,4 +3091,42 @@ def blackman(M, dtype=_np.float32, ctx=None): return _npi.blackman(M, dtype=dtype, ctx=ctx) +@set_module('mxnet.symbol.numpy') +def flip(m, axis=None, out=None): + r""" + flip(m, axis=None, out=None) + + Reverse the order of elements in an array along the given axis. + + The shape of the array is preserved, but the elements are reordered. + + Parameters + ---------- + m : _Symbol or scalar + Input array. + axis : None or int or tuple of ints, optional + Axis or axes along which to flip over. The default, + axis=None, will flip over all of the axes of the input array. + If axis is negative it counts from the last to the first axis. + + If axis is a tuple of ints, flipping is performed on all of the axes + specified in the tuple. + out : _Symbol or scalar, optional + Alternative output array in which to place the result. It must have + the same shape and type as the expected output. + + Returns + ------- + out : _Symbol or scalar + A view of `m` with the entries of axis reversed. Since a view is + returned, this operation is done in constant time. + """ + if isinstance(m, numeric_types): + return _np.flip(m, axis) + elif isinstance(m, _Symbol): + return _npi.flip(m, axis, out=out) + else: + raise TypeError('type {} not supported'.format(str(type(m)))) + + _set_np_symbol_class(_Symbol) diff --git a/src/operator/numpy/np_matrix_op-inl.h b/src/operator/numpy/np_matrix_op-inl.h index 631196e3cd1a..5e25192d9298 100644 --- a/src/operator/numpy/np_matrix_op-inl.h +++ b/src/operator/numpy/np_matrix_op-inl.h @@ -297,6 +297,72 @@ void NumpyRollCompute(const nnvm::NodeAttrs& attrs, } } +struct FlipParam : public dmlc::Parameter { + mxnet::Tuple axis; + DMLC_DECLARE_PARAMETER(FlipParam) { + DMLC_DECLARE_FIELD(axis) + .describe("The axis which to flip elements."); + } +}; + +#define FLIP_MAX_DIM 10 +#define FLIP_MIN_DIM -1 + +template +void NumpyFlipForwardImpl(const OpContext& ctx, + const std::vector& inputs, + const std::vector& outputs, + const std::vector& stride_, + const std::vector& trailing_, + const index_t& flip_index); + +template +void NumpyFlipForward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + const FlipParam& param = nnvm::get(attrs.parsed); + mxnet::Tuple axistemp; + CHECK_EQ(inputs[0].type_flag_, outputs[0].type_flag_); + CHECK_LT(param.axis.ndim(), FLIP_MAX_DIM); + CHECK_GE(param.axis.ndim(), FLIP_MIN_DIM); + if (param.axis.ndim() == FLIP_MIN_DIM) { + if (inputs[0].shape_.ndim() == 0) { + mshadow::Stream *s = ctx.get_stream(); + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, { + mshadow::Copy(outputs[0].FlatTo1D(s), inputs[0].FlatTo1D(s), s); + }); + return; + } + std::vector temp; + for (int i = 0; i < inputs[0].shape_.ndim(); i++) { + temp.push_back(i); + } + axistemp.assign(temp.begin(), temp.end()); + } else { + axistemp = param.axis; + } + + const mxnet::TShape& ishape = inputs[0].shape_; + if (ishape.ProdShape(0, ishape.ndim()) == 0) { + return; // zero shape + } + std::vector stride_(axistemp.ndim()); + std::vector trailing_(axistemp.ndim()); + index_t flip_index = 0; + for (int axis : axistemp) { + CHECK_LT(axis, ishape.ndim()); + stride_[flip_index] = ishape[axis]; + trailing_[flip_index] = 1; + for (int i2 = axis + 1; i2 < ishape.ndim(); ++i2) { + trailing_[flip_index] *= ishape[i2]; + } + flip_index++; + } + NumpyFlipForwardImpl(ctx, inputs, outputs, stride_, trailing_, flip_index); +} + } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_matrix_op.cc b/src/operator/numpy/np_matrix_op.cc index 45e4c3f2d5b4..96a10561be28 100644 --- a/src/operator/numpy/np_matrix_op.cc +++ b/src/operator/numpy/np_matrix_op.cc @@ -565,5 +565,52 @@ NNVM_REGISTER_OP(_np_roll) .add_argument("data", "NDArray-or-Symbol", "Input ndarray") .add_arguments(NumpyRollParam::__FIELDS__()); +template<> +void NumpyFlipForwardImpl(const OpContext& ctx, + const std::vector& inputs, + const std::vector& outputs, + const std::vector& stride_, + const std::vector& trailing_, + const index_t& flip_index) { + mshadow::Stream *s = ctx.get_stream(); + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { + mxnet_op::Kernel::Launch(s, inputs[0].Size(), flip_index, + inputs[0].dptr(), outputs[0].dptr(), + stride_.data(), trailing_.data()); + }); +} + +DMLC_REGISTER_PARAMETER(FlipParam); + +NNVM_REGISTER_OP(_npi_flip) +.set_num_outputs(1) +.set_num_inputs(1) +.set_attr_parser(ParamParser) +.set_attr("FListInputNames", +[](const NodeAttrs& attrs) { + return std::vector {"data"}; +}) +.set_attr("FResourceRequest", +[](const NodeAttrs& attrs) { + return std::vector {ResourceRequest::kTempSpace}; +}) +.set_attr("FInferShape", ElemwiseShape<1, 1>) +.set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FCompute", NumpyFlipForward) +.set_attr("FGradient", ElemwiseGradUseNone{"_backward_npi_flip"}) +.add_argument("data", "NDArray-or-Symbol", "Input data array") +.add_arguments(FlipParam::__FIELDS__()); + +NNVM_REGISTER_OP(_backward_npi_flip) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("TIsBackward", true) +.set_attr("FResourceRequest", +[](const NodeAttrs& attrs) { + return std::vector {ResourceRequest::kTempSpace}; +}) +.set_attr("FCompute", NumpyFlipForward); + } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_matrix_op.cu b/src/operator/numpy/np_matrix_op.cu index ba49d320575c..caab4108b40e 100644 --- a/src/operator/numpy/np_matrix_op.cu +++ b/src/operator/numpy/np_matrix_op.cu @@ -56,5 +56,39 @@ NNVM_REGISTER_OP(_backward_np_vstack) NNVM_REGISTER_OP(_np_roll) .set_attr("FCompute", NumpyRollCompute); +template<> +void NumpyFlipForwardImpl(const OpContext& ctx, + const std::vector& inputs, + const std::vector& outputs, + const std::vector& stride_, + const std::vector& trailing_, + const index_t& flip_index) { + mshadow::Stream *s = ctx.get_stream(); + mshadow::Tensor workspace = + ctx.requested[0].get_space_typed( + mshadow::Shape1(flip_index * sizeof(index_t) * 2), s); + + auto stride_workspace = workspace.dptr_; + auto trailing_workspace = workspace.dptr_ + flip_index * sizeof(index_t); + + cudaMemcpyAsync(stride_workspace, thrust::raw_pointer_cast(stride_.data()), + stride_.size() * sizeof(index_t), + cudaMemcpyHostToDevice, mshadow::Stream::GetStream(s)); + cudaMemcpyAsync(trailing_workspace, thrust::raw_pointer_cast(trailing_.data()), + trailing_.size() * sizeof(index_t), + cudaMemcpyHostToDevice, mshadow::Stream::GetStream(s)); + + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { + mxnet_op::Kernel::Launch(s, inputs[0].Size(), flip_index, + inputs[0].dptr(), outputs[0].dptr(), + reinterpret_cast(stride_workspace), reinterpret_cast(trailing_workspace)); + }); +} + +NNVM_REGISTER_OP(_npi_flip) +.set_attr("FCompute", NumpyFlipForward); + +NNVM_REGISTER_OP(_backward_npi_flip) +.set_attr("FCompute", NumpyFlipForward); } // namespace op } // namespace mxnet diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index ad877bab5b7d..95140a69ab45 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -30,6 +30,7 @@ from mxnet.test_utils import verify_generator, gen_buckets_probs_with_ppf import platform + @with_seed() @use_np def test_np_tensordot(): @@ -2307,6 +2308,46 @@ def hybrid_forward(self, F, x, *args, **kwargs): assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) +@with_seed() +@use_np +def test_np_flip(): + class TestFlip(HybridBlock): + def __init__(self, axis): + super(TestFlip, self).__init__() + self.axis = axis + + def hybrid_forward(self, F, x): + return F.np.flip(x, self.axis) + + shapes = [(1, 2, 3), (1, 0), ()] + types = ['int32', 'int64', 'float16', 'float32', 'float64'] + for hybridize in [True, False]: + for oneType in types: + rtol, atol=1e-3, 1e-5 + for shape in shapes: + axis = random.randint(-1, len(shape) - 1) + if axis is -1: + axis = None + test_flip = TestFlip(axis) + if hybridize: + test_flip.hybridize() + x = rand_ndarray(shape, dtype=oneType).as_np_ndarray() + x.attach_grad() + np_out = _np.flip(x.asnumpy(), axis) + with mx.autograd.record(): + mx_out = test_flip(x) + assert mx_out.shape == np_out.shape + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol) + mx_out.backward() + np_backward = _np.ones(np_out.shape) + assert_almost_equal(x.grad.asnumpy(), np_backward, rtol=rtol, atol=atol) + + # Test imperative once again + mx_out = np.flip(x, axis) + np_out = _np.flip(x.asnumpy(), axis) + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol) + + if __name__ == '__main__': import nose nose.runmodule()