
[numpy] fix op repeat with list input (#18371)
* except .h

* except storage

* repeat

* change fwd

* delete

* codecov

Co-authored-by: Ubuntu <[email protected]>
Yiyan66 and Ubuntu committed Jun 9, 2020
1 parent 028d01d commit cf3984b
Showing 9 changed files with 378 additions and 28 deletions.
8 changes: 7 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
@@ -4076,7 +4076,13 @@ def repeat(a, repeats, axis=None):
[3, 4],
[3, 4]])
"""
-    return _api_internal.repeat(a, repeats, axis)
+    if isinstance(repeats, numeric_types):
+        repeats = [repeats]
+    if axis is not None:
+        tmp = swapaxes(a, 0, axis)
+        res = _api_internal.repeats(tmp, repeats, 0)
+        return swapaxes(res, 0, axis)
+    return _api_internal.repeats(a, repeats, axis)


# pylint: disable=redefined-outer-name
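To see the fix end to end, here is a hedged usage sketch (it assumes a build that includes this commit; expected values follow NumPy's documented repeat semantics):

from mxnet import np  # MXNet's NumPy-compatible frontend

a = np.array([[1, 2], [3, 4]])

# Scalar repeats keep working: the wrapper turns 2 into [2] before the FFI call.
np.repeat(a, 2, axis=0)       # [[1, 2], [1, 2], [3, 4], [3, 4]]

# List repeats, the case this commit fixes: repeat row 0 once, row 1 twice.
np.repeat(a, [1, 2], axis=0)  # [[1, 2], [3, 4], [3, 4]]

# With axis=None the input is flattened first, so repeats needs one entry
# per element (or a scalar).
np.repeat(a, [1, 0, 2, 1])    # [1, 3, 3, 4]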
8 changes: 7 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
@@ -2448,7 +2448,13 @@ def repeat(a, repeats, axis=None):
[3, 4],
[3, 4]])
"""
-    return _npi.repeat(a, repeats=repeats, axis=axis)
+    if isinstance(repeats, numeric_types):
+        repeats = [repeats]
+    if axis is not None:
+        tmp = swapaxes(a, 0, axis)
+        res = _npi.repeats(tmp, repeats=repeats, axis=0)
+        return swapaxes(res, 0, axis)
+    return _npi.repeats(a, repeats=repeats, axis=axis)


def _unary_func_helper(x, fn_array, fn_scalar, out=None, **kwargs):
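Both frontends reduce every axis to axis 0 before calling into the backend, so the C++ kernels only ever repeat along a leading axis. A plain-NumPy sketch of the identity the wrappers rely on (the function name here is illustrative):

import numpy as np

def repeat_via_axis0(a, repeats, axis):
    # Move the target axis to the front, repeat along axis 0, move it back.
    tmp = np.swapaxes(a, 0, axis)
    res = np.repeat(tmp, repeats, axis=0)
    return np.swapaxes(res, 0, axis)

a = np.arange(12).reshape(2, 2, 3)
assert (repeat_via_axis0(a, [2, 1, 3], 2) == np.repeat(a, [2, 1, 3], axis=2)).all()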
22 changes: 0 additions & 22 deletions src/api/operator/numpy/np_matrix_op.cc
@@ -470,28 +470,6 @@ MXNET_REGISTER_API("_npi.diag_indices_from")
*ret = ndoutputs[0];
});

MXNET_REGISTER_API("_npi.repeat")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_repeat");
nnvm::NodeAttrs attrs;
op::RepeatParam param;
param.repeats = args[1].operator int();
if (args[2].type_code() == kNull) {
param.axis = dmlc::optional<int>();
} else {
param.axis = args[2].operator int64_t();
}
int num_inputs = 1;
int num_outputs = 0;
attrs.parsed = std::move(param);
attrs.op = op;
SetAttrDict<op::RepeatParam>(&attrs);
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
*ret = ndoutputs[0];
});

MXNET_REGISTER_API("_npi.diagflat")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
52 changes: 52 additions & 0 deletions src/api/operator/numpy/np_repeat_op.cc
@@ -0,0 +1,52 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file np_repeat_op.cc
* \brief Implementation of the API of functions in src/operator/numpy/np_repeat_op.cc
*/
#include <mxnet/api_registry.h>
#include "../utils.h"
#include "../../../operator/numpy/np_repeat_op-inl.h"

namespace mxnet {

MXNET_REGISTER_API("_npi.repeats")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_repeats");
nnvm::NodeAttrs attrs;
op::RepeatsParam param;
param.repeats = Tuple<int>(args[1].operator ObjectRef());;
if (args[2].type_code() == kNull) {
param.axis = dmlc::optional<int>();
} else {
param.axis = args[2].operator int64_t();
}
int num_inputs = 1;
int num_outputs = 0;
attrs.parsed = std::move(param);
attrs.op = op;
SetAttrDict<op::RepeatsParam>(&attrs);
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
*ret = ndoutputs[0];
});

} // namespace mxnet
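The crux of the bug lives at this FFI boundary: the removed _npi.repeat entry point parsed its second argument with args[1].operator int(), which only a scalar survives, whereas _npi.repeats parses it as Tuple<int>. That is also why the Python wrappers above normalize scalar repeats into a one-element list before the call. A hedged before/after sketch:

from mxnet import np

a = np.array([1, 2, 3])
np.repeat(a, [2, 0, 1])  # [1, 1, 3] after this commit; previously the list
                         # could not be coerced to a single int at the FFI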
221 changes: 221 additions & 0 deletions src/operator/numpy/np_repeat_op-inl.h
@@ -0,0 +1,221 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file np_repeat_op-inl.h
* \brief Function definition of the repeat op
*/

#ifndef MXNET_OPERATOR_NUMPY_NP_REPEAT_OP_INL_H_
#define MXNET_OPERATOR_NUMPY_NP_REPEAT_OP_INL_H_

#include <mxnet/operator_util.h>
#include <vector>
#include <string>
#include <algorithm>
#include <utility>
#include <type_traits>
#include <unordered_map>
#include "../mshadow_op.h"
#include "../elemwise_op_common.h"
#include "../channel_op_common.h"
#include "../mxnet_op.h"
#include "../../common/static_array.h"

namespace mxnet {
namespace op {

struct RepeatsParam : public dmlc::Parameter<RepeatsParam> {
  dmlc::optional<mxnet::Tuple<int>> repeats;
  dmlc::optional<int> axis;
  DMLC_DECLARE_PARAMETER(RepeatsParam) {
    DMLC_DECLARE_FIELD(repeats)
      .describe("The number of repetitions for each element.");
    DMLC_DECLARE_FIELD(axis)
      .set_default(dmlc::optional<int>())
      .describe("The axis along which to repeat values."
                " Negative values are interpreted as counting from the end."
                " By default, use the flattened input array"
                " and return a flat output array.");
  }
  void SetAttrDict(std::unordered_map<std::string, std::string>* dict) {
    std::ostringstream repeats_s, axis_s;
    repeats_s << repeats;
    axis_s << axis;
    (*dict)["repeats"] = repeats_s.str();
    (*dict)["axis"] = axis_s.str();
  }
};

inline void GetRepeatsParams(const RepeatsParam& param, const mxnet::TShape& ishape,
                             int* repeats, dmlc::optional<int>* axisOpt, int* axis) {
  *repeats = 0;
  const mxnet::Tuple<int>& repts = param.repeats.value();
  for (int i = 0; i < repts.ndim(); i++) {
    CHECK_GE(repts[i], 0) << "repeats cannot be a negative number";
    *repeats += repts[i];
  }
  *axisOpt = param.axis;
  if (static_cast<bool>(*axisOpt)) {
    int ndims = ishape.ndim();
    *axis = axisOpt->value();
    if (*axis < 0) {
      *axis += ndims;
    }
    CHECK(*axis >= 0 && *axis < ndims) << "axis = " << axisOpt->value() << " out of bounds";
  }
}

inline bool RepeatsOpShape(const nnvm::NodeAttrs& attrs,
                           mxnet::ShapeVector* in_attrs,
                           mxnet::ShapeVector* out_attrs) {
  const RepeatsParam& param = nnvm::get<RepeatsParam>(attrs.parsed);
  CHECK_EQ(in_attrs->size(), 1U);
  CHECK_EQ(out_attrs->size(), 1U);
  const mxnet::TShape& ishape = (*in_attrs)[0];
  int repeats = 0;
  dmlc::optional<int> axisOpt;
  int axis = -1;
  GetRepeatsParams(param, ishape, &repeats, &axisOpt, &axis);

  // If 0 repeats, return an empty 1-dim, 0-size array
  if (0 == repeats) {
    SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape(1, 0));
    return true;
  }

  // A single scalar repeat scales the axis size; a per-element list
  // replaces it with the sum of all repeats.
  if (static_cast<bool>(axisOpt)) {
    mxnet::TShape shape(ishape.ndim(), -1);
    for (int i = 0; i < ishape.ndim(); ++i) {
      if (i == axis) {
        shape[i] = param.repeats.value().ndim() == 1 ? repeats * ishape[i] : repeats;
      } else {
        shape[i] = ishape[i];
      }
    }
    SHAPE_ASSIGN_CHECK(*out_attrs, 0, shape);
  } else {
    // If axis is not given, return a flat 1D array of size = repeats
    repeats = param.repeats.value().ndim() == 1 ? ishape.Size() * repeats : repeats;
    mxnet::TShape shape(1, repeats);
    SHAPE_ASSIGN_CHECK(*out_attrs, 0, shape);
  }
  return shape_is_known(out_attrs->at(0));
}

struct repeat_noaxis_fwd {
  template <typename IType, typename OType>
  MSHADOW_XINLINE static void Map(index_t i, OType* out, IType* input,
                                  const int* indx) {
    using namespace mxnet_op;
    // indx holds inclusive prefix sums of repeats; the first entry
    // strictly greater than i names the source element.
    int ind = 0;
    while (i >= indx[ind]) ind++;
    out[i] = input[ind];
  }
};

struct repeat_axis_fwd {
  template <typename IType, typename OType>
  MSHADOW_XINLINE static void Map(index_t i, OType* out, IType* input,
                                  const int* indx, int stride) {
    using namespace mxnet_op;
    // stride = number of elements in one slice along the repeat axis (axis 0).
    int ind_row = i / stride, ind_col = i % stride;
    int ind = 0;
    while (ind_row >= indx[ind]) ind++;
    out[i] = input[ind * stride + ind_col];
  }
};

template <typename xpu>
void NumpyRepeatsOpForward(const nnvm::NodeAttrs& attrs,
                           const OpContext& ctx,
                           const std::vector<TBlob>& inputs,
                           const std::vector<OpReqType>& req,
                           const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  const TBlob& iTBlob = inputs[0];
  const mxnet::TShape& ishape = iTBlob.shape_;
  if (!shape_is_known(ishape)) return;
  Stream<xpu>* s = ctx.get_stream<xpu>();

  int repeats = 0;
  dmlc::optional<int> axisOpt;
  int axis = -1;
  const RepeatsParam& param = nnvm::get<RepeatsParam>(attrs.parsed);
  GetRepeatsParams(param, ishape, &repeats, &axisOpt, &axis);
  if (0 == repeats) return;

  // A single scalar repeat is expanded to one entry per slice.
  mxnet::Tuple<int> repts = param.repeats.value();
  if (repts.ndim() == 1) {
    int len = static_cast<bool>(axisOpt) ? ishape[axis] : ishape.Size();
    std::vector<int> temp(len, repeats);
    repts = mxnet::Tuple<int>(temp);
  }
  // Turn the per-slice counts into inclusive prefix sums for the kernels.
  for (int i = 1; i < repts.ndim(); i++) {
    repts[i] += repts[i - 1];
  }
  size_t total_temp_size = repts.ndim() * sizeof(int);
  Tensor<xpu, 1, char> temp_space =
      ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(total_temp_size), s);
  int* ind = reinterpret_cast<int*>(temp_space.dptr_);

  if (ctx.run_ctx.ctx.dev_mask() == gpu::kDevMask) {
#if MXNET_USE_CUDA
    cudaMemcpyAsync(ind, repts.begin(), repts.ndim() * sizeof(int),
                    cudaMemcpyHostToDevice, Stream<gpu>::GetStream(ctx.get_stream<gpu>()));
#else
    LOG(FATAL) << "Illegal attempt to use GPU in a CPU-only build";
#endif
  } else {
    std::memcpy(ind, repts.begin(), repts.ndim() * sizeof(int));
  }

  const TBlob& in_data = inputs[0];
  const TBlob& out_data = outputs[0];
  if (!param.axis.has_value()) {
    MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, {
      MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, {
        mxnet_op::Kernel<repeat_noaxis_fwd, xpu>::Launch(
            s, out_data.Size(), out_data.dptr<OType>(),
            in_data.dptr<IType>(), ind);
      });
    });
  } else {
    // Number of elements in one slice along the (already front-most) axis.
    int stride = 1;
    for (int i = 1; i < ishape.ndim(); i++) {
      stride *= ishape[i];
    }
    MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, {
      MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, {
        mxnet_op::Kernel<repeat_axis_fwd, xpu>::Launch(
            s, out_data.Size(), out_data.dptr<OType>(),
            in_data.dptr<IType>(), ind, stride);
      });
    });
  }
}

} // namespace op
} // namespace mxnet

#endif // MXNET_OPERATOR_NUMPY_NP_REPEAT_OP_INL_H_
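Two details tie this file together. RepeatsOpShape sizes the repeat axis as sum(repeats) when a full list is given, or repeats * ishape[axis] for a single scalar. The forward pass then turns the per-slice counts into inclusive prefix sums (repts[i] += repts[i-1]), copies them to the device, and each output element locates its source by scanning that table (while (i >= indx[ind]) ind++). A pure-Python model of the mapping, assuming repeats has already been expanded to one entry per slice as the scalar path does:

import itertools
import numpy as np

def repeats_forward(a, repeats, axis=None):
    # Illustrative model of NumpyRepeatsOpForward, not the shipped kernel.
    if axis is None:
        a = a.reshape(-1)            # flattened input, one element per slice
        stride = 1
    else:
        a = np.swapaxes(a, 0, axis)  # the frontend already moved axis to 0
        stride = int(np.prod(a.shape[1:], dtype=int))
    ind = list(itertools.accumulate(repeats))  # inclusive prefix sums
    flat_in = a.reshape(-1)
    out = np.empty(ind[-1] * stride, dtype=a.dtype)
    for i in range(out.size):
        row, col = divmod(i, stride)  # output slice and offset within it
        src = 0
        while row >= ind[src]:        # linear search, as in repeat_axis_fwd
            src += 1
        out[i] = flat_in[src * stride + col]
    if axis is None:
        return out
    return np.swapaxes(out.reshape((ind[-1],) + a.shape[1:]), 0, axis)

x = np.arange(6).reshape(3, 2)
assert (repeats_forward(x, [2, 0, 1], axis=0) == np.repeat(x, [2, 0, 1], axis=0)).all()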
50 changes: 50 additions & 0 deletions src/operator/numpy/np_repeat_op.cc
@@ -0,0 +1,50 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file np_repeat_op.cc
* \brief CPU implementation of numpy repeat operator
*/

#include "./np_repeat_op-inl.h"
#include "../tensor/matrix_op-inl.h"

namespace mxnet {
namespace op {

DMLC_REGISTER_PARAMETER(RepeatsParam);

NNVM_REGISTER_OP(_npi_repeats)
.set_attr_parser(ParamParser<RepeatsParam>)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<mxnet::FInferShape>("FInferShape", RepeatsOpShape)
.set_attr<nnvm::FInferType>("FInferType", RepeatOpType)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", NumpyRepeatsOpForward<cpu>)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.add_argument("data", "NDArray-or-Symbol", "Input data array")
.add_arguments(RepeatsParam::__FIELDS__());

} // namespace op
} // namespace mxnet
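One behavioral consequence of this registration: FGradient is MakeZeroGradNodes, so, if that helper does what its name suggests, no true gradient flows through _npi_repeats. A hedged autograd probe (the analytic gradient of this sum would be [1, 2, 0]):

from mxnet import np, autograd

x = np.array([1.0, 2.0, 3.0])
x.attach_grad()
with autograd.record():
    y = np.repeat(x, [1, 2, 0]).sum()
y.backward()
print(x.grad)  # expected all zeros under MakeZeroGradNodes, not [1, 2, 0]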