Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
refactor moveaxis code
Browse files Browse the repository at this point in the history
  • Loading branch information
gyshi committed Oct 4, 2019
1 parent 9753eb1 commit 8ce4075
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 68 deletions.
52 changes: 30 additions & 22 deletions src/operator/numpy/np_matrix_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -375,41 +375,30 @@ struct NumpyMoveaxisParam : public dmlc::Parameter<NumpyMoveaxisParam> {
}
};

template<typename xpu>
void NumpyMoveaxisCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mshadow::expr;
inline mxnet::TShape NumpyMoveaxisShapeImpl(const nnvm::NodeAttrs& attrs,
const int& ndim) {
const NumpyMoveaxisParam& param = nnvm::get<NumpyMoveaxisParam>(attrs.parsed);
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req[0], kWriteTo) << "Moveaxis does not support inplace";
mxnet::TShape axes(inputs[0].ndim(), -1);
mxnet::TShape axes(ndim, -1);
std::vector<bool> state_axes(ndim, false);
mxnet::TShape real_src(param.source.ndim(), -1);
mxnet::TShape real_des(param.destination.ndim(), -1);
std::vector<bool> state_axes(inputs[0].ndim(), false);
CHECK_EQ(param.source.ndim(), param.destination.ndim())
<< "source and destination not equal.";
for (int i = 0; i < param.source.ndim(); ++i) {
if (param.source[i] >= 0) {
CHECK_LT(static_cast<size_t>(param.source[i]), inputs[0].ndim());
CHECK_LT(static_cast<size_t>(param.source[i]), ndim);
real_src[i] = param.source[i];
} else {
CHECK_LT(param.source[i] + inputs[0].ndim(), inputs[0].ndim());
real_src[i] = param.source[i] + inputs[0].ndim();
CHECK_LT(param.source[i] + ndim, ndim);
real_src[i] = param.source[i] + ndim;
}
if (param.destination[i] >= 0) {
CHECK_LT(static_cast<size_t>(param.destination[i]), inputs[0].ndim());
CHECK_LT(static_cast<size_t>(param.destination[i]), ndim);
real_des[i] = param.destination[i];
} else {
CHECK_LT(param.destination[i] + inputs[0].ndim(), inputs[0].ndim());
real_des[i] = param.destination[i] + inputs[0].ndim();
CHECK_LT(param.destination[i] + ndim, ndim);
real_des[i] = param.destination[i] + ndim;
}
}
if (inputs[0].ndim() > 1) {
if (ndim > 1) {
for (int i = 0; i < param.source.ndim() - 1; ++i) {
for (int j = i + 1; j < param.source.ndim(); ++j) {
CHECK_NE(real_src[i], real_src[j])
Expand All @@ -434,6 +423,25 @@ void NumpyMoveaxisCompute(const nnvm::NodeAttrs& attrs,
}
}
}
return axes;
}

template<typename xpu>
void NumpyMoveaxisCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
using namespace mshadow::expr;
const NumpyMoveaxisParam& param = nnvm::get<NumpyMoveaxisParam>(attrs.parsed);
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req[0], kWriteTo) << "Moveaxis does not support inplace";
CHECK_EQ(param.source.ndim(), param.destination.ndim())
<< "source and destination not equal.";
mxnet::TShape axes;
axes = NumpyMoveaxisShapeImpl(attrs, inputs[0].ndim());
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, Dtype, {
TransposeImpl<xpu>(ctx.run_ctx, inputs[0], outputs[0], axes);
})
Expand Down
48 changes: 2 additions & 46 deletions src/operator/numpy/np_matrix_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -625,51 +625,8 @@ bool NumpyMoveaxisShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(param.source.ndim(), param.destination.ndim())
<< "source and destination not equal.";
mxnet::TShape ret(shp.ndim(), -1);
mxnet::TShape axes(shp.ndim(), -1);
std::vector<bool> state_axes(shp.ndim(), false);
mxnet::TShape real_src(param.source.ndim(), -1);
mxnet::TShape real_des(param.destination.ndim(), -1);
for (int i = 0; i < param.source.ndim(); ++i) {
if (param.source[i] >= 0) {
CHECK_LT(static_cast<size_t>(param.source[i]), shp.ndim());
real_src[i] = param.source[i];
} else {
CHECK_LT(param.source[i] + shp.ndim(), shp.ndim());
real_src[i] = param.source[i] + shp.ndim();
}
if (param.destination[i] >= 0) {
CHECK_LT(static_cast<size_t>(param.destination[i]), shp.ndim());
real_des[i] = param.destination[i];
} else {
CHECK_LT(param.destination[i] + shp.ndim(), shp.ndim());
real_des[i] = param.destination[i] + shp.ndim();
}
}
if (shp.ndim() > 1) {
for (int i = 0; i < param.source.ndim() - 1; ++i) {
for (int j = i + 1; j < param.source.ndim(); ++j) {
CHECK_NE(real_src[i], real_src[j])
<< "repeated axis in `source` argument";
CHECK_NE(real_des[i], real_des[j])
<< "repeated axis in `destination` argument";
}
}
}
for (int i = 0; i < param.source.ndim(); ++i) {
axes[real_des[i]] = real_src[i];
state_axes[real_src[i]] = true;
}
for (int i = 0; i < axes.ndim(); ++i) {
if (axes[i] < 0) {
for (int j = 0; j < axes.ndim(); ++j) {
if (state_axes[j] == false) {
axes[i] = j;
state_axes[j] = true;
break;
}
}
}
}
mxnet::TShape axes;
axes = NumpyMoveaxisShapeImpl(attrs, shp.ndim());
for (int i = 0; i < shp.ndim(); ++i) {
CHECK(axes[i] < static_cast<int64_t>(shp.ndim()));
ret[i] = shp[axes[i]];
Expand Down Expand Up @@ -745,7 +702,6 @@ inline bool NumpyRot90Shape(const nnvm::NodeAttrs& attrs,
res[real_axes[0]] += res[real_axes[1]];
res[real_axes[1]] = res[real_axes[0]] - res[real_axes[1]];
res[real_axes[0]] -= res[real_axes[1]];

SHAPE_ASSIGN_CHECK(*out_attrs, 0, res);
return shape_is_known(res);
}
Expand Down

0 comments on commit 8ce4075

Please sign in to comment.