From cf3984bf5c67cb7d1feeb5b3cb55a41ca995e5c8 Mon Sep 17 00:00:00 2001 From: Yiyan66 <57363390+Yiyan66@users.noreply.github.com> Date: Wed, 10 Jun 2020 05:56:13 +0800 Subject: [PATCH] [numpy] fix op repeat with list input (#18371) * except .h * except storage * repeat * change fwd * delete * codecov Co-authored-by: Ubuntu --- python/mxnet/ndarray/numpy/_op.py | 8 +- python/mxnet/symbol/numpy/_symbol.py | 8 +- src/api/operator/numpy/np_matrix_op.cc | 22 -- src/api/operator/numpy/np_repeat_op.cc | 52 +++++ src/operator/numpy/np_repeat_op-inl.h | 221 ++++++++++++++++++ src/operator/numpy/np_repeat_op.cc | 50 ++++ src/operator/numpy/np_repeat_op.cu | 35 +++ .../unittest/test_numpy_interoperability.py | 8 +- tests/python/unittest/test_numpy_op.py | 2 + 9 files changed, 378 insertions(+), 28 deletions(-) create mode 100644 src/api/operator/numpy/np_repeat_op.cc create mode 100644 src/operator/numpy/np_repeat_op-inl.h create mode 100644 src/operator/numpy/np_repeat_op.cc create mode 100644 src/operator/numpy/np_repeat_op.cu diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index e7b092126c6c..c51d14f860b1 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -4076,7 +4076,13 @@ def repeat(a, repeats, axis=None): [3, 4], [3, 4]]) """ - return _api_internal.repeat(a, repeats, axis) + if isinstance(repeats, numeric_types): + repeats = [repeats] + if axis is not None: + tmp = swapaxes(a, 0, axis) + res = _api_internal.repeats(tmp, repeats, 0) + return swapaxes(res, 0, axis) + return _api_internal.repeats(a, repeats, axis) # pylint: disable=redefined-outer-name diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index 54c155006b04..d3521cad1274 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -2448,7 +2448,13 @@ def repeat(a, repeats, axis=None): [3, 4], [3, 4]]) """ - return _npi.repeat(a, repeats=repeats, axis=axis) + if 
isinstance(repeats, numeric_types): + repeats = [repeats] + if axis is not None: + tmp = swapaxes(a, 0, axis) + res = _npi.repeats(tmp, repeats=repeats, axis=0) + return swapaxes(res, 0, axis) + return _npi.repeats(a, repeats=repeats, axis=axis) def _unary_func_helper(x, fn_array, fn_scalar, out=None, **kwargs): diff --git a/src/api/operator/numpy/np_matrix_op.cc b/src/api/operator/numpy/np_matrix_op.cc index f9a575d8c8fc..95b9cb573904 100644 --- a/src/api/operator/numpy/np_matrix_op.cc +++ b/src/api/operator/numpy/np_matrix_op.cc @@ -470,28 +470,6 @@ MXNET_REGISTER_API("_npi.diag_indices_from") *ret = ndoutputs[0]; }); -MXNET_REGISTER_API("_npi.repeat") -.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { - using namespace runtime; - const nnvm::Op* op = Op::Get("_npi_repeat"); - nnvm::NodeAttrs attrs; - op::RepeatParam param; - param.repeats = args[1].operator int(); - if (args[2].type_code() == kNull) { - param.axis = dmlc::optional(); - } else { - param.axis = args[2].operator int64_t(); - } - int num_inputs = 1; - int num_outputs = 0; - attrs.parsed = std::move(param); - attrs.op = op; - SetAttrDict(&attrs); - NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; - auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); - *ret = ndoutputs[0]; -}); - MXNET_REGISTER_API("_npi.diagflat") .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; diff --git a/src/api/operator/numpy/np_repeat_op.cc b/src/api/operator/numpy/np_repeat_op.cc new file mode 100644 index 000000000000..c79fb8bbe03c --- /dev/null +++ b/src/api/operator/numpy/np_repeat_op.cc @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file np_repeat_op.cc + * \brief Implementation of the API of functions in src/operator/numpy/np_repeat_op.cc + */ +#include +#include "../utils.h" +#include "../../../operator/numpy/np_repeat_op-inl.h" + +namespace mxnet { + +MXNET_REGISTER_API("_npi.repeats") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_repeats"); + nnvm::NodeAttrs attrs; + op::RepeatsParam param; + param.repeats = Tuple(args[1].operator ObjectRef());; + if (args[2].type_code() == kNull) { + param.axis = dmlc::optional(); + } else { + param.axis = args[2].operator int64_t(); + } + int num_inputs = 1; + int num_outputs = 0; + attrs.parsed = std::move(param); + attrs.op = op; + SetAttrDict(&attrs); + NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = ndoutputs[0]; +}); + +} // namespace mxnet diff --git a/src/operator/numpy/np_repeat_op-inl.h b/src/operator/numpy/np_repeat_op-inl.h new file mode 100644 index 000000000000..638f1dee921a --- /dev/null +++ b/src/operator/numpy/np_repeat_op-inl.h @@ -0,0 +1,221 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_repeat_op-inl.h + * \brief Function definition of the repeat op + */ + +#ifndef MXNET_OPERATOR_NUMPY_NP_REPEAT_OP_INL_H_ +#define MXNET_OPERATOR_NUMPY_NP_REPEAT_OP_INL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include "../mshadow_op.h" +#include "../elemwise_op_common.h" +#include "../channel_op_common.h" +#include "../mxnet_op.h" +#include "../../common/static_array.h" + +namespace mxnet { +namespace op { + +struct RepeatsParam : public dmlc::Parameter { + dmlc::optional> repeats; + dmlc::optional axis; + DMLC_DECLARE_PARAMETER(RepeatsParam) { + DMLC_DECLARE_FIELD(repeats) + .describe("The number of repetitions for each element."); + DMLC_DECLARE_FIELD(axis) + .set_default(dmlc::optional()) + .describe("The axis along which to repeat values." + " The negative numbers are interpreted counting from the backward." 
+ " By default, use the flattened input array," + " and return a flat output array."); + } + void SetAttrDict(std::unordered_map* dict) { + std::ostringstream repeats_s, axis_s; + repeats_s << repeats; + axis_s << axis; + (*dict)["repeats"] = repeats_s.str(); + (*dict)["axis"] = axis_s.str(); + } +}; + +inline void GetRepeatsParams(const RepeatsParam& param, const mxnet::TShape& ishape, + int* repeats, dmlc::optional* axisOpt, int* axis) { + *repeats = 0; + const mxnet::Tuple &repts = param.repeats.value(); + for (int i=0; i < repts.ndim(); i++) { + CHECK_GE(repts[i], 0) << "repeats cannot be a negative number"; + *repeats += repts[i]; + } + *axisOpt = param.axis; + if (static_cast(*axisOpt)) { + int ndims = ishape.ndim(); + *axis = axisOpt->value(); + if (*axis < 0) { + *axis += ndims; + } + CHECK(*axis >= 0 && *axis < ndims) << "axis = " << axisOpt->value() << " out of bounds"; + } +} + +inline bool RepeatsOpShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector *in_attrs, + mxnet::ShapeVector *out_attrs) { + const RepeatsParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + const mxnet::TShape& ishape = (*in_attrs)[0]; + int repeats = 0; + dmlc::optional axisOpt; + int axis = -1; + GetRepeatsParams(param, ishape, &repeats, &axisOpt, &axis); + // If 0 repeats, return an empty 1-dim, 0-size array + if (0 == repeats) { + SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape(1, 0)); + return true; + } + + // If repeats > 0, multiply the size of the corresponding axis by repeats + if (static_cast(axisOpt)) { + mxnet::TShape shape(ishape.ndim(), -1); + for (int i = 0; i < ishape.ndim(); ++i) { + if (i == axis) { + shape[i] = param.repeats.value().ndim() == 1 ? repeats * ishape[i] : repeats; + } else { + shape[i] = ishape[i]; + } + } + SHAPE_ASSIGN_CHECK(*out_attrs, 0, shape); + } else { // If axis is not input by user, return a flat 1D array of size = repeats + repeats = param.repeats.value().ndim() == 1 ? 
ishape.Size() * repeats : repeats; + mxnet::TShape shape(1, repeats); + SHAPE_ASSIGN_CHECK(*out_attrs, 0, shape); + } + return shape_is_known(out_attrs->at(0)); +} + +struct repeat_noaxis_fwd { + template + MSHADOW_XINLINE static void Map(index_t i, OType* out, IType* input, + const int* indx) { + using namespace mxnet_op; + int ind = 0; + while (i >= indx[ind]) ind++; + out[i] = input[ind]; + } +}; + +struct repeat_axis_fwd { + template + MSHADOW_XINLINE static void Map(index_t i, OType* out, IType* input, + const int* indx, int stride) { + using namespace mxnet_op; + int ind_row = i / stride, ind_col = i % stride; + int ind = 0; + while (ind_row >= indx[ind]) ind++; + out[i] = input[ind * stride + ind_col]; + } +}; + +template +void NumpyRepeatsOpForward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + const TBlob& iTBlob = inputs[0]; + const mxnet::TShape& ishape = iTBlob.shape_; + if (!shape_is_known(ishape)) return; + Stream *s = ctx.get_stream(); + + int repeats = 0; + dmlc::optional axisOpt; + int axis = -1; + const RepeatsParam& param = nnvm::get(attrs.parsed); + GetRepeatsParams(param, ishape, &repeats, &axisOpt, &axis); + if (0 == repeats) return; + mxnet::Tuple repts = param.repeats.value(); + if (repts.ndim() == 1) { + int len = static_cast(axisOpt) ? 
ishape[axis] : ishape.Size(); + std::vector temp(len, repeats); + repts = mxnet::Tuple(temp); + } + for (int i=1; i < repts.ndim(); i++) { + repts[i] += repts[i-1]; + } + size_t total_temp_size = repts.ndim() * sizeof(int); + Tensor temp_space = + ctx.requested[0].get_space_typed(Shape1(total_temp_size), s); + int* ind = reinterpret_cast(temp_space.dptr_); + + if (ctx.run_ctx.ctx.dev_mask() == gpu::kDevMask) { + #if MXNET_USE_CUDA + cudaMemcpyAsync(ind, repts.begin(), repts.ndim() * sizeof(int), + cudaMemcpyHostToDevice, Stream::GetStream(ctx.get_stream())); + #else + LOG(FATAL) << "Illegal attempt to use GPU in a CPU-only build"; + #endif + } else { + std::memcpy(ind, repts.begin(), repts.ndim() * sizeof(int)); + } + + if (!param.axis.has_value()) { + mshadow::Stream *s = ctx.get_stream(); + const TBlob& in_data = inputs[0]; + const TBlob& out_data = outputs[0]; + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, { + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, { + mxnet_op::Kernel::Launch( + s, out_data.Size(), out_data.dptr(), + in_data.dptr(), ind); + }); + }); + } else { + mshadow::Stream *s = ctx.get_stream(); + const TBlob& in_data = inputs[0]; + const TBlob& out_data = outputs[0]; + int stride = 1; + for (int i = 1; i < ishape.ndim(); i++) { + stride *= ishape[i]; + } + + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, { + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, { + mxnet_op::Kernel::Launch( + s, out_data.Size(), out_data.dptr(), + in_data.dptr(), ind, stride); + }); + }); + } +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_NUMPY_NP_REPEAT_OP_INL_H_ diff --git a/src/operator/numpy/np_repeat_op.cc b/src/operator/numpy/np_repeat_op.cc new file mode 100644 index 000000000000..ad8803ba02d1 --- /dev/null +++ b/src/operator/numpy/np_repeat_op.cc @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! +* Copyright (c) 2019 by Contributors +* \file np_repeat_op.cc +* \brief CPU implementation of numpy repeat operator +*/ + +#include "./np_repeat_op-inl.h" +#include "../tensor/matrix_op-inl.h" + +namespace mxnet { +namespace op { + +DMLC_REGISTER_PARAMETER(RepeatsParam); + +NNVM_REGISTER_OP(_npi_repeats) +.set_attr_parser(ParamParser) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr("FInferShape", RepeatsOpShape) +.set_attr("FInferType", RepeatOpType) +.set_attr("FResourceRequest", + [](const NodeAttrs& n) { + return std::vector{ResourceRequest::kTempSpace}; +}) +.set_attr("FCompute", NumpyRepeatsOpForward) +.set_attr("FGradient", MakeZeroGradNodes) +.add_argument("data", "NDArray-or-Symbol", "Input data array") +.add_arguments(RepeatsParam::__FIELDS__()); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/np_repeat_op.cu b/src/operator/numpy/np_repeat_op.cu new file mode 100644 index 000000000000..4f57278df38b --- /dev/null +++ b/src/operator/numpy/np_repeat_op.cu @@ -0,0 +1,35 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. 
The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. +*/ + +/*! +* Copyright (c) 2015 by Contributors +* \file np_repeat_op.cu +* \brief GPU Implementation of numpy-compatible repeat operator +*/ +#include +#include "./np_repeat_op-inl.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(_npi_repeats) +.set_attr("FCompute", NumpyRepeatsOpForward); + +} // namespace op +} // namespace mxnet diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py index 0060b73c06ce..de363d2ee69f 100644 --- a/tests/python/unittest/test_numpy_interoperability.py +++ b/tests/python/unittest/test_numpy_interoperability.py @@ -1153,11 +1153,11 @@ def _add_workload_repeat(array_pool): m = _np.array([1, 2, 3, 4, 5, 6]) m_rect = m.reshape((2, 3)) - # OpArgMngr.add_workload('repeat', np.array(m), [1, 3, 2, 1, 1, 2]) # Argument "repeats" only supports int + OpArgMngr.add_workload('repeat', np.array(m), [1, 3, 2, 1, 1, 2]) # "repeats" given as a list OpArgMngr.add_workload('repeat', np.array(m), 2) B = np.array(m_rect) - # OpArgMngr.add_workload('repeat', B, [2, 1], axis=0) # Argument "repeats" only supports int - # OpArgMngr.add_workload('repeat', B, [1, 3, 2], axis=1) # Argument "repeats" only supports int + OpArgMngr.add_workload('repeat', B, [2, 1], axis=0) # "repeats" given as a list + OpArgMngr.add_workload('repeat', B, [1, 3, 2], axis=1) # "repeats" given as a list 
OpArgMngr.add_workload('repeat', B, 2, axis=0) OpArgMngr.add_workload('repeat', B, 2, axis=1) @@ -1165,7 +1165,7 @@ def _add_workload_repeat(array_pool): a = _np.arange(60).reshape(3, 4, 5) for axis in itertools.chain(range(-a.ndim, a.ndim), [None]): OpArgMngr.add_workload('repeat', np.array(a), 2, axis=axis) - # OpArgMngr.add_workload('repeat', np.array(a), [2], axis=axis) # Argument "repeats" only supports int + OpArgMngr.add_workload('repeat', np.array(a), [2], axis=axis) # "repeats" given as a list def _add_workload_reshape(): diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 550b6dd42d32..3ea6119e69a0 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -5309,6 +5309,8 @@ def test_np_repeat(): ((4, 2), 2, 0), ((4, 2), 2, 1), ((4, 2), 2, -1), + ((4, 2), [2,3] * 4, None), + ((4, 2), [1,2], 1), ] class TestRepeat(HybridBlock):