diff --git a/python/mxnet/_numpy_op_doc.py b/python/mxnet/_numpy_op_doc.py
index 5543ebc8e8c9..61c80d53c4bc 100644
--- a/python/mxnet/_numpy_op_doc.py
+++ b/python/mxnet/_numpy_op_doc.py
@@ -52,3 +52,67 @@ def _np_zeros_like(a):
         Array of zeros with the same shape and type as `a`.
     """
     pass
+
+
+def _np_roll(a, shift, axis=None):
+    """Roll array elements along a given axis.
+
+    Elements that roll beyond the last position are re-introduced at
+    the first.
+
+    Parameters
+    ----------
+    a : ndarray
+        Input array.
+    shift : int or tuple of ints
+        The number of places by which elements are shifted. If a tuple,
+        then `axis` must be a tuple of the same size, and each of the
+        given axes is shifted by the corresponding number. If an int
+        while `axis` is a tuple of ints, then the same value is used for
+        all given axes.
+    axis : int or tuple of ints, optional
+        Axis or axes along which elements are shifted. By default, the
+        array is flattened before shifting, after which the original
+        shape is restored.
+
+    Returns
+    -------
+    res : ndarray
+        Output array, with the same shape as `a`.
+
+    Notes
+    -----
+    Supports rolling over multiple dimensions simultaneously.
+
+    Examples
+    --------
+    >>> x = np.arange(10)
+    >>> np.roll(x, 2)
+    array([8., 9., 0., 1., 2., 3., 4., 5., 6., 7.])
+    >>> np.roll(x, -2)
+    array([2., 3., 4., 5., 6., 7., 8., 9., 0., 1.])
+
+    >>> x2 = np.reshape(x, (2, 5))
+    >>> x2
+    array([[0., 1., 2., 3., 4.],
+           [5., 6., 7., 8., 9.]])
+    >>> np.roll(x2, 1)
+    array([[9., 0., 1., 2., 3.],
+           [4., 5., 6., 7., 8.]])
+    >>> np.roll(x2, -1)
+    array([[1., 2., 3., 4., 5.],
+           [6., 7., 8., 9., 0.]])
+    >>> np.roll(x2, 1, axis=0)
+    array([[5., 6., 7., 8., 9.],
+           [0., 1., 2., 3., 4.]])
+    >>> np.roll(x2, -1, axis=0)
+    array([[5., 6., 7., 8., 9.],
+           [0., 1., 2., 3., 4.]])
+    >>> np.roll(x2, 1, axis=1)
+    array([[4., 0., 1., 2., 3.],
+           [9., 5., 6., 7., 8.]])
+    >>> np.roll(x2, -1, axis=1)
+    array([[1., 2., 3., 4., 0.],
+           [6., 7., 8., 9., 5.]])
+    """
+    pass
diff --git a/src/operator/numpy/np_matrix_op-inl.h b/src/operator/numpy/np_matrix_op-inl.h
index 6d3d9ea5ec85..23afdc4065b3 100644
--- a/src/operator/numpy/np_matrix_op-inl.h
+++ b/src/operator/numpy/np_matrix_op-inl.h
@@ -26,6 +26,7 @@
 #define MXNET_OPERATOR_NUMPY_NP_MATRIX_OP_INL_H_
 
 #include <mxnet/operator_util.h>
+#include <vector>
 #include "../tensor/matrix_op-inl.h"
 #include "../nn/concat-inl.h"
 
@@ -60,6 +61,161 @@ void NumpyTranspose(const nnvm::NodeAttrs& attrs,
   }
 }
 
+struct NumpyRollParam : public dmlc::Parameter<NumpyRollParam> {
+  dmlc::optional<mxnet::TShape> shift;
+  dmlc::optional<mxnet::TShape> axis;
+  DMLC_DECLARE_PARAMETER(NumpyRollParam) {
+    DMLC_DECLARE_FIELD(shift)
+      .set_default(dmlc::optional<mxnet::TShape>())
+      .describe("The number of places by which elements are shifted. If a tuple, "
+                "then `axis` must be a tuple of the same size, and each of the given "
+                "axes is shifted by the corresponding number. If an int while `axis` "
+                "is a tuple of ints, then the same value is used for all given axes.");
+    DMLC_DECLARE_FIELD(axis)
+      .set_default(dmlc::optional<mxnet::TShape>())
+      .describe("Axis or axes along which elements are shifted. By default, the "
+                "array is flattened before shifting, after which the original "
+                "shape is restored.");
+  }
+};
+
+template<int req>
+struct RollAxisNone_forward {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* in_data,
+                                  const int size, const int shift) {
+    // Output element i is read from input element (i - shift) mod size.
+    int new_index = i - shift < 0 ? i - shift + size : i - shift;
+    KERNEL_ASSIGN(out_data[i], req, in_data[new_index]);
+  }
+};
+
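+// When `axis` is given, the source offset of every output element is
+// precomputed on the host (see RollDfs below) and copied to the device,
+// so the device-side kernel is a single gather per element.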
+template<int req>
+struct RollAxis_forward {
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* in_data,
+                                  const size_t* new_index) {
+    KERNEL_ASSIGN(out_data[i], req, in_data[new_index[i]]);
+  }
+};
+
+// Depth-first walk over all output coordinates: new_axes[d][j] is the source
+// coordinate along axis d for output coordinate j, and value[ndim - 1 - d]
+// is the row-major stride of axis d, so the recursion appends the flattened
+// source offset of every output element to new_index in row-major order.
+inline void RollDfs(const std::vector<std::vector<size_t>>& new_axes,
+                    const std::vector<size_t>& value,
+                    std::vector<size_t>* new_index,
+                    int index, int ndim, int mid) {
+  for (int a : new_axes[index]) {
+    if (index == ndim - 1) {
+      std::vector<size_t>& out = (*new_index);
+      out.push_back(mid + a);
+    } else {
+      mid += a * value[ndim - 1 - index];
+      RollDfs(new_axes, value, new_index, index + 1, ndim, mid);
+      mid -= a * value[ndim - 1 - index];
+    }
+  }
+}
+
+template<typename xpu>
+void NumpyRollCompute(const nnvm::NodeAttrs& attrs,
+                      const OpContext& ctx,
+                      const std::vector<TBlob>& inputs,
+                      const std::vector<OpReqType>& req,
+                      const std::vector<TBlob>& outputs) {
+  using namespace mxnet_op;
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  if (inputs[0].Size() == 0U) return;
+  const NumpyRollParam& param = nnvm::get<NumpyRollParam>(attrs.parsed);
+  const index_t ndim(inputs[0].shape_.ndim());
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  std::vector<int> shifts(ndim, 0);
+  index_t input_size = inputs[0].Size();
+  if (!param.axis.has_value()) {
+    // No axis: roll the flattened array by a single scalar shift.
+    int shift = param.shift.value()[0];
+    shift = shift % input_size;
+    if (shift < 0) {
+      shift += input_size;
+    }
+    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+      MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
+        Kernel<RollAxisNone_forward<req_type>, xpu>::Launch(
+            s, outputs[0].Size(), outputs[0].dptr<DType>(), inputs[0].dptr<DType>(),
+            inputs[0].Size(), shift);
+      });
+    });
+  } else {
+    mxnet::TShape axes(param.axis.value());
+    for (int i = 0; i < axes.ndim(); ++i) {
+      if (axes[i] < 0) {
+        axes[i] += ndim;
+      }
+    }
+    for (int i = 0; i < axes.ndim(); ++i) {
+      CHECK_LT(axes[i], ndim)
+        << "axis " << axes[i] << " exceeds input dimensions " << inputs[0].shape_;
+      CHECK_GE(axes[i], 0)
+        << "axis " << param.axis.value() << " exceeds input dimensions " << inputs[0].shape_;
+    }
+    if (param.shift.value().ndim() == 1) {
+      // A scalar shift applies to every given axis.
+      for (int i = 0; i < axes.ndim(); ++i) {
+        shifts[axes[i]] = param.shift.value()[0];
+      }
+    } else {
+      if (param.shift.value().ndim() != axes.ndim()) {
+        LOG(FATAL) << "shift and axis must be tuples of the same size.";
+      }
+      for (int i = 0; i < axes.ndim(); ++i) {
+        shifts[axes[i]] = param.shift.value()[i];
+      }
+    }
+    // keep each shift in [0, size) along its axis
+    for (int i = 0; i < ndim; ++i) {
+      int trans_shift = shifts[i] % inputs[0].shape_[i];
+      if (trans_shift < 0) {
+        trans_shift += inputs[0].shape_[i];
+      }
+      shifts[i] = trans_shift;
+    }
+    // Build the per-axis source-coordinate tables after the shift: entry j
+    // of the table for axis d is (j + size_d - shift_d) % size_d.
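+    // For example, an axis of size 5 rolled by 2 yields the table
+    // [3, 4, 0, 1, 2], matching np.roll([0, 1, 2, 3, 4], 2). RollDfs then
+    // combines one such table per axis with the row-major strides into a
+    // single flattened gather index per output element.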
+    std::vector<std::vector<size_t>> new_axes;
+    std::vector<size_t> new_index;
+    std::vector<size_t> temp;
+    std::vector<size_t> value(ndim, 0);
+    int mid_val = 1;
+    for (int i = 0; i < ndim; ++i) {
+      if (shifts[i] != 0) {
+        for (int j = 0; j < inputs[0].shape_[i]; ++j) {
+          int new_axis = (j + inputs[0].shape_[i] - shifts[i]) % inputs[0].shape_[i];
+          temp.push_back(new_axis);
+        }
+      } else {
+        for (int j = 0; j < inputs[0].shape_[i]; ++j) {
+          temp.push_back(j);
+        }
+      }
+      new_axes.push_back(temp);
+      temp.clear();
+      // value[ndim - 1 - i] ends up holding the row-major stride of axis i.
+      value[i] = mid_val;
+      mid_val *= inputs[0].shape_[ndim - 1 - i];
+    }
+    RollDfs(new_axes, value, &new_index, 0, ndim, 0);
+    // Stage the host-built index table in device scratch space.
+    size_t workspace_size = new_index.size() * sizeof(size_t);
+    Tensor<xpu, 1, char> workspace =
+        ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
+    Tensor<cpu, 1, size_t> index_cpu_tensor(new_index.data(), Shape1(new_index.size()));
+    Tensor<xpu, 1, size_t> index_xpu_tensor(
+        reinterpret_cast<size_t*>(workspace.dptr_), Shape1(new_index.size()));
+    mshadow::Copy(index_xpu_tensor, index_cpu_tensor, s);
+    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+      MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
+        Kernel<RollAxis_forward<req_type>, xpu>::Launch(
+            s, outputs[0].Size(), outputs[0].dptr<DType>(), inputs[0].dptr<DType>(),
+            index_xpu_tensor.dptr_);
+      });
+    });
+  }
+}
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/numpy/np_matrix_op.cc b/src/operator/numpy/np_matrix_op.cc
index 5ad6c8908017..e69d89b1149e 100644
--- a/src/operator/numpy/np_matrix_op.cc
+++ b/src/operator/numpy/np_matrix_op.cc
@@ -30,6 +30,7 @@ namespace mxnet {
 namespace op {
 
 DMLC_REGISTER_PARAMETER(NumpyTransposeParam);
+DMLC_REGISTER_PARAMETER(NumpyRollParam);
 
 bool NumpyTransposeShape(const nnvm::NodeAttrs& attrs,
                          mxnet::ShapeVector *in_attrs,
@@ -345,5 +346,73 @@ Examples::
 .add_argument("data", "NDArray-or-Symbol[]", "List of arrays to stack")
 .add_arguments(StackParam::__FIELDS__());
 
+inline bool NumpyRollShape(const nnvm::NodeAttrs& attrs,
+                           mxnet::ShapeVector *in_attrs,
+                           mxnet::ShapeVector *out_attrs) {
+  using namespace mshadow;
+  const NumpyRollParam& param = nnvm::get<NumpyRollParam>(attrs.parsed);
+
+  if (!param.shift.has_value()) {
+    LOG(FATAL) << "roll missing 1 required positional argument: 'shift'.";
+  }
+  if (param.shift.value().ndim() > 1 &&
+      param.axis.has_value() &&
+      param.axis.value().ndim() != param.shift.value().ndim()) {
+    LOG(FATAL) << "shift and axis must be tuples of the same size.";
+  }
+  if (!param.axis.has_value() && param.shift.has_value() && param.shift.value().ndim() > 1) {
+    LOG(FATAL) << "shift must be an int when axis is not specified.";
+  }
+  if (param.axis.has_value()) {
+    mxnet::TShape axes(param.axis.value());
+    const index_t ndim = (*in_attrs)[0].ndim();
+    for (index_t i = 0; i < axes.ndim(); i++) {
+      if (axes[i] < 0) {
+        axes[i] += ndim;
+      }
+    }
+    std::sort(axes.begin(), axes.end());
+    for (index_t i = 1; i < axes.ndim(); i++) {
+      CHECK_LT(axes[i - 1], axes[i])
+        << "axes have duplicates " << axes;
+    }
+    CHECK_LT(axes[axes.ndim() - 1], ndim)
+      << "axis " << axes[axes.ndim() - 1]
+      << " exceeds input dimensions " << (*in_attrs)[0];
+    CHECK_GE(axes[0], 0)
+      << "axis " << param.axis.value()
+      << " exceeds input dimensions " << (*in_attrs)[0];
+  }
+  return ElemwiseShape<1, 1>(attrs, in_attrs, out_attrs);
+}
+
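+// The registration below requests temporary workspace (FResourceRequest)
+// because the axis path of NumpyRollCompute stages its host-built gather
+// index table in device scratch memory before launching the kernel.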
+NNVM_REGISTER_OP(_np_roll)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<NumpyRollParam>)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"data"};
+})
+.set_attr<mxnet::FInferShape>("FInferShape", NumpyRollShape)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
+.set_attr<FCompute>("FCompute<cpu>", NumpyRollCompute<cpu>)
+.set_attr<nnvm::FGradient>("FGradient",
+  [](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
+    const NumpyRollParam& param = nnvm::get<NumpyRollParam>(n->attrs.parsed);
+    CHECK(param.shift.has_value()) << "roll missing required argument: 'shift'";
+    // The gradient of roll is a roll of the output gradient by the
+    // negated shift along the same axes.
+    mxnet::TShape shifts(param.shift.value());
+    for (int i = 0; i < shifts.ndim(); ++i) {
+      shifts[i] = -shifts[i];
+    }
+    std::ostringstream os1;
+    os1 << dmlc::optional<mxnet::TShape>(shifts);
+    std::ostringstream os2;
+    os2 << param.axis;
+    return MakeNonlossGradNode("_np_roll", n, ograds, {},
+                               {{"shift", os1.str()}, {"axis", os2.str()}});
+})
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& n) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+})
+.add_argument("data", "NDArray-or-Symbol", "Input ndarray")
+.add_arguments(NumpyRollParam::__FIELDS__());
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/numpy/np_matrix_op.cu b/src/operator/numpy/np_matrix_op.cu
index 4ba527deca09..0cf9d37f66a6 100644
--- a/src/operator/numpy/np_matrix_op.cu
+++ b/src/operator/numpy/np_matrix_op.cu
@@ -46,5 +46,8 @@ NNVM_REGISTER_OP(_backward_np_concat)
 NNVM_REGISTER_OP(_npi_stack)
 .set_attr<FCompute>("FCompute<gpu>", StackOpForward<gpu>);
 
+NNVM_REGISTER_OP(_np_roll)
+.set_attr<FCompute>("FCompute<gpu>", NumpyRollCompute<gpu>);
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 399cdead6177..475f97747d0c 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -1080,6 +1080,54 @@ def hybrid_forward(self, F, a, *args):
         assert same(mx_out.asnumpy(), np_out)
 
 
+@with_seed()
+@use_np
+def test_np_roll():
+    class TestRoll(HybridBlock):
+        def __init__(self, shift=None, axis=None):
+            super(TestRoll, self).__init__()
+            self._shift = shift
+            self._axis = axis
+
+        def hybrid_forward(self, F, x):
+            return F.np.roll(x, shift=self._shift, axis=self._axis)
+
+    dtypes = ['int32', 'int64', 'float16', 'float32', 'float64']
+    configs = [
+        ((), (3,), None),
+        ((1,), (-3,), None),
+        ((20,), (-3,), None),
+        ((3,), (2,), 0),
+        ((2, 3, 4), (12,), (1,)),
+        ((2, 3, 4), (10, -10), (0, 1)),
+        ((2, 3, 4, 5), (0, 1), (-1, 2)),
+        ((2, 3, 0, 1), (0, 1), (-1, 2)),
+        ((2, 3, 4, 5), 10, (0, 2)),
+    ]
+    for dtype in dtypes:
+        for config in configs:
+            for hybridize in [False, True]:
+                shape, shift, axis = config[0], config[1], config[2]
+                x = rand_ndarray(shape=shape, dtype=dtype).as_np_ndarray()
+                net = TestRoll(shift=shift, axis=axis)
+                np_out = _np.roll(x.asnumpy(), shift=shift, axis=axis)
+                if hybridize:
+                    net.hybridize()
+                x.attach_grad()
+                with mx.autograd.record():
+                    mx_out = net(x)
+                assert mx_out.shape == np_out.shape
+                mx_out.backward()
+                assert same(mx_out.asnumpy(), np_out)
+                # roll is a permutation, so the gradient of a sum-like
+                # head is all ones regardless of shift
+                assert same(x.grad.shape, x.shape)
+                assert same(x.grad.asnumpy(), _np.ones(shape))
+
+                # test imperative
+                np_out = _np.roll(x.asnumpy(), shift=shift, axis=axis)
+                mx_out = np.roll(x, shift=shift, axis=axis)
+                assert same(mx_out.asnumpy(), np_out)
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
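
Note on the multi-axis path: the gather table built by RollDfs and consumed by
RollAxis_forward can be cross-checked against NumPy. The sketch below is
illustrative only; the helper name `roll_gather_index` is invented here, and
`shifts` is assumed to carry one entry per axis (zero for axes that do not
move). It computes the same flattened mapping, out.flat[i] = in.flat[index[i]]:

    import numpy as np

    def roll_gather_index(shape, shifts):
        # Per-axis source coordinate: output position j along an axis of
        # size n reads input position (j - shift) % n, as in the kernel.
        per_axis = [(np.arange(n) - s) % n for n, s in zip(shape, shifts)]
        strides = np.cumprod((1,) + shape[:0:-1])[::-1]  # row-major strides
        flat = np.zeros((), dtype=np.int64)
        for d in range(len(shape)):
            view = [1] * len(shape)
            view[d] = shape[d]
            flat = flat + per_axis[d].reshape(view) * strides[d]
        return flat.reshape(-1)

    x = np.arange(24).reshape(2, 3, 4)
    index = roll_gather_index(x.shape, (1, 0, -2))
    out = x.ravel()[index].reshape(x.shape)
    assert (out == np.roll(x, (1, -2), axis=(0, 2))).all()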