
[MXNet-1211] Factor and "Like" modes in BilinearResize2D operator #13226

Merged: 9 commits, May 5, 2019
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
@@ -218,6 +218,7 @@ List of Contributors
* [Dang Trung Kien](https://github.com/kiendang)
* [Zach Boldyga](https://github.com/zboldyga)
* [Gordon Reid](https://github.com/gordon1992)
* [Mikhail Lobanov](https://github.com/lobanov-m)
* [Ming Yang](http://ufoym.com)
* [Satya Krishna Gorti](https://github.com/satyakrishnagorti)
* [Neo Chien](https://github.com/cchung100m)
4 changes: 4 additions & 0 deletions python/mxnet/visualization.py
@@ -384,6 +384,10 @@ def looks_like_weight(name):
            continue
        else:
            inputs = node["inputs"]

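            # the optional "like" input only supplies the target shape
            # (see the operator's "like" mode), so skip it when drawing edges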
            if node['op'] == '_contrib_BilinearResize2D':
                inputs = [inputs[0]]

            for item in inputs:
                input_node = nodes[item[0]]
                input_name = input_node["name"]
173 changes: 154 additions & 19 deletions src/operator/contrib/bilinear_resize-inl.h
@@ -44,6 +44,12 @@
#include "../mxnet_op.h"
#include "../mshadow_op.h"

namespace bilinear_resize {
enum BilinearResizeOpMode{simple, odd_scale, like, to_even_down, to_even_up, to_odd_down,
to_odd_up};
} // namespace bilinear_resize


namespace mxnet {
namespace op {

@@ -52,15 +58,45 @@ struct BilinearSampleParam : public dmlc::Parameter<BilinearSampleParam> {
int width;
dmlc::optional<float> scale_height;
dmlc::optional<float> scale_width;
int mode;
DMLC_DECLARE_PARAMETER(BilinearSampleParam) {
DMLC_DECLARE_FIELD(height).set_default(1).set_range(1, 10000)
.describe("output height (required, but ignored if scale_height is defined)");
.describe("output height (required, but ignored if scale_height is defined or mode is not "
"\"size\")");
DMLC_DECLARE_FIELD(width).set_default(1).set_range(1, 10000)
.describe("output width (required, but ignored if scale_width is defined)");
.describe("output width (required, but ignored if scale_width is defined or mode is not "
"\"size\")");
DMLC_DECLARE_FIELD(scale_height).set_default(dmlc::optional<float>())
.describe("sampling scale of the height (optional, ignores height if defined)");
.describe("sampling scale of the height (optional, used in modes \"scale\" and \"odd_scale\")");
DMLC_DECLARE_FIELD(scale_width).set_default(dmlc::optional<float>())
.describe("sampling scale of the scale_width (optional, ignores width if defined)");
.describe("sampling scale of the width (optional, used in modes \"scale\" and \"odd_scale\")");
DMLC_DECLARE_FIELD(mode)
.add_enum("size", bilinear_resize::simple)
.add_enum("odd_scale", bilinear_resize::odd_scale)
.add_enum("like", bilinear_resize::like)
.add_enum("to_even_down", bilinear_resize::to_even_down)
.add_enum("to_even_up", bilinear_resize::to_even_up)
.add_enum("to_odd_down", bilinear_resize::to_odd_down)
.add_enum("to_odd_up", bilinear_resize::to_odd_up)
.set_default(bilinear_resize::simple)
.describe("resizing mode. \"simple\" - output height equals parameter \"height\" if "
"\"scale_height\" parameter is not defined or input height multiplied by "
"\"scale_height\" otherwise. Same for width;"
"\"odd_scale\" - if original height or width is odd, then result height is "
"calculated like result_h = (original_h - 1) * scale + 1; "
"for scale > 1 the result shape would be like if we did deconvolution with kernel "
"= (1, 1) and stride = (height_scale, width_scale); and for scale < 1 shape "
"would be like we did convolution with kernel = (1, 1) and "
"stride = (int(1 / height_scale), int( 1/ width_scale);"
"\"like\" - resize first input to the height and width of second input; "
"\"to_even_down\" - resize input to nearest lower even height and width "
"(if original height is odd then result height = original height - 1);"
"\"to_even_up\" - resize input to nearest bigger even height and width "
"(if original height is odd then result height = original height + 1);"
"\"to_odd_down\" - resize input to nearest odd height and width "
"(if original height is odd then result height = original height - 1);"
"\"to_odd_up\" - resize input to nearest odd height and width "
"(if original height is odd then result height = original height + 1);");
}
};
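For quick reference while reading the mode description above, here is a hedged Python sketch of the resulting (height, width) for each mode; the helper name `expected_hw` is hypothetical, and the arithmetic simply restates the documented rules:

```python
def expected_hw(h, w, mode, height=1, width=1, scale_h=None, scale_w=None,
                like_hw=None):
    """Sketch of the documented BilinearResize2D output-shape rules."""
    if mode == 'size':
        nh = int(scale_h * h) if scale_h is not None else height
        nw = int(scale_w * w) if scale_w is not None else width
    elif mode == 'odd_scale':
        nh = int(h * scale_h) if h % 2 == 0 else int((h - 1) * scale_h) + 1
        nw = int(w * scale_w) if w % 2 == 0 else int((w - 1) * scale_w) + 1
    elif mode == 'like':
        nh, nw = like_hw  # (H, W) of the second input
    elif mode == 'to_even_down':
        nh, nw = h - h % 2, w - w % 2
    elif mode == 'to_even_up':
        nh, nw = h + h % 2, w + w % 2
    elif mode == 'to_odd_down':
        nh, nw = h - (1 - h % 2), w - (1 - w % 2)
    elif mode == 'to_odd_up':
        nh, nw = h + (1 - h % 2), w + (1 - w % 2)
    else:
        raise ValueError('invalid mode %s' % mode)
    return nh, nw

# e.g. a 33 x 45 input with odd_scale 2.0 -> (65, 89): (33-1)*2+1, (45-1)*2+1
print(expected_hw(33, 45, 'odd_scale', scale_h=2.0, scale_w=2.0))
```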

@@ -76,7 +112,8 @@ void SpatialUpSamplingBilinearUpdateOutput(mshadow::Stream<cpu> *s,
template<typename xpu, typename DType, typename AccReal>
void SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream<cpu> *s,
const std::vector<TBlob> &input,
const std::vector<TBlob> &output);
const std::vector<TBlob> &output,
bool modeLike);

#if MXNET_USE_CUDA
template<typename xpu, typename DType, typename AccReal>
@@ -87,7 +124,8 @@ void SpatialUpSamplingBilinearUpdateOutput(mshadow::Stream<gpu> *s,
template<typename xpu, typename DType, typename AccReal>
void SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream<gpu> *s,
const std::vector<TBlob> &input,
const std::vector<TBlob> &output);
const std::vector<TBlob> &output,
bool modeLike);
#endif // MXNET_USE_CUDA

template <typename xpu>
@@ -96,7 +134,9 @@ inline void BilinearSampleOpForward(const nnvm::NodeAttrs& attrs,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
CHECK_EQ(inputs.size(), 1U);
const BilinearSampleParam& param = nnvm::get<BilinearSampleParam>(attrs.parsed);
size_t expected = param.mode == bilinear_resize::like ? 2 : 1;
CHECK_EQ(inputs.size(), expected);
CHECK_EQ(outputs.size(), 1U);
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
MSHADOW_REAL_TYPE_SWITCH_EX(inputs[0].type_flag_, DType, AccReal, {
@@ -111,8 +151,11 @@ inline void BilinearSampleOpBackward(const nnvm::NodeAttrs& attrs,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
const BilinearSampleParam& param = nnvm::get<BilinearSampleParam>(attrs.parsed);
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
bool modeLike = param.mode == bilinear_resize::like;
size_t expected = modeLike ? 2 : 1;
CHECK_EQ(outputs.size(), expected);
Review comment (Member):
So there will be 2 outputs in like mode? What is the second output? If it does have two outputs, it should be specified in the param description.

Reply from @lobanov-m (Contributor, Author), Nov 26, 2018:
There are two outputs of the backward function; the operator itself still has one output. In "like" mode we pass two input tensors, resizing the first to the size of the second, so the backward function has to return gradients for both tensors. The second tensor, from which we take the target size, gets zero gradient from this operator's output, because it is needed only for its shape. This is implemented in the .cc and .cu files.
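For illustration, a minimal autograd sketch of the behavior described above, assuming the Python frontend generated from this registration (operator exposed as `mx.nd.contrib.BilinearResize2D`; shapes chosen arbitrarily):

```python
from mxnet import autograd, nd

data = nd.random.uniform(shape=(1, 3, 16, 16))
like = nd.random.uniform(shape=(1, 3, 32, 32))
data.attach_grad()
like.attach_grad()

with autograd.record():
    out = nd.contrib.BilinearResize2D(data, like, mode='like')
out.backward()

print(out.shape)                   # single operator output: (1, 3, 32, 32)
print(like.grad.sum().asscalar())  # 0.0 -- "like" only supplies its shape
```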

mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
if (IsWriting(req[0])) {
// zero grad before backwarding
@@ -121,7 +164,7 @@ inline void BilinearSampleOpBackward(const nnvm::NodeAttrs& attrs,
})
}
MSHADOW_REAL_TYPE_SWITCH_EX(inputs[0].type_flag_, DType, AccReal, {
SpatialUpSamplingBilinearUpdateGradInput<xpu, DType, AccReal>(s, inputs, outputs);
SpatialUpSamplingBilinearUpdateGradInput<xpu, DType, AccReal>(s, inputs, outputs, modeLike);
});
}

@@ -130,28 +173,120 @@ static bool BilinearSampleOpInferShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_shape,
mxnet::ShapeVector *out_shape) {
using namespace mshadow;
CHECK_EQ(in_shape->size(), 1U) << "Input:[data]";
CHECK_EQ(out_shape->size(), 1U) << "Output:[data]";
const BilinearSampleParam& param = nnvm::get<BilinearSampleParam>(attrs.parsed);
size_t expected = param.mode == bilinear_resize::like ? 2 : 1;
CHECK_EQ(in_shape->size(), expected);
mxnet::TShape dshape(in_shape->at(0));
if (mxnet::op::shape_is_none(dshape)) return false;
if (param.scale_height.has_value()) {
dshape[2] = static_cast<int>(param.scale_height.value() * in_shape->at(0)[2]);
} else {
dshape[2] = param.height;
int new_height = -1;
int new_width = -1;
switch (param.mode) {
case bilinear_resize::simple:
{
if (param.scale_height.has_value()) {
new_height = static_cast<int>(param.scale_height.value() * in_shape->at(0)[2]);
} else {
new_height = param.height;
}
if (param.scale_width.has_value()) {
new_width = static_cast<int>(param.scale_width.value() * in_shape->at(0)[3]);
} else {
new_width = param.width;
}
break;
}
case bilinear_resize::odd_scale:
{
new_height = ((dshape[2] % 2) == 0) ? static_cast<int>(dshape[2] * param.scale_height.value()) :
static_cast<int>((dshape[2] - 1) * param.scale_height.value()) + 1;
new_width = ((dshape[3] % 2) == 0) ? static_cast<int>(dshape[3] * param.scale_width.value()) :
static_cast<int>((dshape[3] - 1) * param.scale_width.value()) + 1;
break;
}
case bilinear_resize::like:
{
TShape like_shape(in_shape->at(1));
if (mxnet::op::shape_is_none(like_shape)) return false;
new_height = like_shape[2];
new_width = like_shape[3];
break;
}
case bilinear_resize::to_even_down:
{
new_height = ((dshape[2] % 2) == 0) ? dshape[2] : dshape[2] - 1;
new_width = ((dshape[3] % 2) == 0) ? dshape[3] : dshape[3] - 1;
break;
}
case bilinear_resize::to_even_up:
{
new_height = ((dshape[2] % 2) == 0) ? dshape[2] : dshape[2] + 1;
new_width = ((dshape[3] % 2) == 0) ? dshape[3] : dshape[3] + 1;
break;
}
case bilinear_resize::to_odd_down:
{
new_height = ((dshape[2] % 2) == 1) ? dshape[2] : dshape[2] - 1;
new_width = ((dshape[3] % 2) == 1) ? dshape[3] : dshape[3] - 1;
break;
}
case bilinear_resize::to_odd_up:
{
new_height = ((dshape[2] % 2) == 1) ? dshape[2] : dshape[2] + 1;
new_width = ((dshape[3] % 2) == 1) ? dshape[3] : dshape[3] + 1;
break;
}
default:
{
LOG(FATAL) << "Invalid mode " << param.mode;
}
}

if (param.scale_height.has_value()) {
dshape[3] = static_cast<int>(param.scale_width.value() * in_shape->at(0)[3]);
} else {
dshape[3] = param.width;
}
dshape[2] = new_height;
dshape[3] = new_width;

out_shape->clear();
out_shape->push_back(dshape);
return true;
}


inline uint16_t BilinearSampleOpNumInputs(const NodeAttrs& attrs) {
auto& param = nnvm::get<BilinearSampleParam>(attrs.parsed);
if (param.mode == bilinear_resize::like) {
return 2;
} else {
return 1;
}
}

inline uint16_t BilinearSampleOpNumBackwardInputs(const NodeAttrs& attrs) {
auto& param = nnvm::get<BilinearSampleParam>(attrs.parsed);
if (param.mode == bilinear_resize::like) {
return 3;
} else {
return 1;
}
}

inline uint16_t BilinearSampleOpNumBackwardOutputs(const NodeAttrs& attrs) {
auto& param = nnvm::get<BilinearSampleParam>(attrs.parsed);
if (param.mode == bilinear_resize::like) {
return 2;
} else {
return 1;
}
}

inline std::vector<std::string> BilinearSampleOpInputNames(const NodeAttrs& attrs) {
auto& param = nnvm::get<BilinearSampleParam>(attrs.parsed);
if (param.mode == bilinear_resize::like) {
return std::vector<std::string>{"data", "like"};
} else {
return std::vector<std::string>{"data"};
}
}
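As a usage note: with `FListInputNames` wired this way, the symbol frontend lists the second argument only in "like" mode. A small sketch, assuming the generated contrib API:

```python
import mxnet as mx

data = mx.sym.Variable('data')
like = mx.sym.Variable('like')

two_inputs = mx.sym.contrib.BilinearResize2D(data=data, like=like, mode='like')
print(two_inputs.list_arguments())   # expected: ['data', 'like']

one_input = mx.sym.contrib.BilinearResize2D(data=data, height=8, width=8)
print(one_input.list_arguments())    # expected: ['data']
```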

} // namespace op
} // namespace mxnet

52 changes: 36 additions & 16 deletions src/operator/contrib/bilinear_resize.cc
Original file line number Diff line number Diff line change
@@ -97,7 +97,8 @@ void SpatialUpSamplingBilinearUpdateOutput(mshadow::Stream<cpu> *s,
template<typename xpu, typename DType, typename AccReal>
void SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream<cpu> *s,
const std::vector<TBlob> &input,
const std::vector<TBlob> &output) {
const std::vector<TBlob> &output,
bool modeLike) {
Tensor<xpu, 4, DType> gradOutput = input[0].get<xpu, 4, DType>(s);
Tensor<xpu, 4, DType> gradInput = output[0].get<xpu, 4, DType>(s);

@@ -108,8 +109,8 @@ void SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream<cpu> *s,
int inputHeight = gradInput.size(2);
int inputWidth = gradInput.size(3);

DType *data1 = gradInput.dptr_;
DType *data2 = gradOutput.dptr_;
DType *dataInput = gradInput.dptr_;
DType *dataOutput = gradOutput.dptr_;
channels = nbatch * channels;

// special case: same-size matching grids
@@ -118,8 +119,8 @@ void SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream<cpu> *s,
const int h1 = h2;
for (int w2 = 0; w2 < outputWidth; ++w2) {
const int w1 = w2;
DType* pos1 = &data1[h1 * inputWidth + w1];
const DType* pos2 = &data2[h2 * outputWidth + w2];
DType* pos1 = &dataInput[h1 * inputWidth + w1];
const DType* pos2 = &dataOutput[h2 * outputWidth + w2];
for (int c = 0; c < channels; ++c) {
pos1[0] += pos2[0];
pos1 += inputWidth * inputHeight;
@@ -145,15 +146,32 @@ void SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream<cpu> *s,
const int w1p = (w1 < inputWidth - 1) ? 1 : 0;
const DType w1lambda = w1r - w1;
const DType w0lambda = (DType)1. - w1lambda;
DType* pos1 = &data1[h1 * inputWidth + w1];
const DType* pos2 = &data2[h2 * outputWidth + w2];
DType* posInput = &dataInput[h1 * inputWidth + w1];
const DType* posOutput = &dataOutput[h2 * outputWidth + w2];
for (int c = 0; c < channels; ++c) {
pos1[0] += h0lambda * w0lambda * pos2[0];
pos1[w1p] += h0lambda * w1lambda * pos2[0];
pos1[h1p * inputWidth] += h1lambda * w0lambda * pos2[0];
pos1[h1p * inputWidth + w1p] += h1lambda * w1lambda * pos2[0];
pos1 += inputWidth * inputHeight;
pos2 += outputWidth * outputHeight;
posInput[0] += h0lambda * w0lambda * posOutput[0];
posInput[w1p] += h0lambda * w1lambda * posOutput[0];
posInput[h1p * inputWidth] += h1lambda * w0lambda * posOutput[0];
posInput[h1p * inputWidth + w1p] += h1lambda * w1lambda * posOutput[0];
posInput += inputWidth * inputHeight;
posOutput += outputWidth * outputHeight;
}
}
}

if (modeLike) {
Tensor<xpu, 4, DType> gradInputLike = output[1].get<xpu, 4, DType>(s);
int inputHeightLike = gradInputLike.size(2);
int inputWidthLike = gradInputLike.size(3);
DType *dataInputLike = gradInputLike.dptr_;
int channelsLike = nbatch * gradInputLike.size(1);
for (int h_like = 0; h_like < inputHeightLike; ++h_like) {
for (int w_like = 0; w_like < inputWidthLike; ++w_like) {
DType *posInput = &dataInputLike[h_like * inputWidthLike + w_like];
for (int c = 0; c < channelsLike; ++c) {
posInput[0] = 0;
posInput += inputWidthLike * inputHeightLike;
}
}
}
}
@@ -174,19 +192,21 @@ first in one direction, and then again in the other direction. See the wikipedia
for more details.
)code" ADD_FILELINE)
.set_attr_parser(ParamParser<BilinearSampleParam>)
.set_num_inputs(1)
.set_num_inputs(BilinearSampleOpNumInputs)
.set_num_outputs(1)
.set_attr<nnvm::FListInputNames>("FListInputNames", BilinearSampleOpInputNames)
.set_attr<mxnet::FInferShape>("FInferShape", BilinearSampleOpInferShape)
.set_attr<FCompute>("FCompute<cpu>", BilinearSampleOpForward<cpu>)
.set_attr<nnvm::FGradient>("FGradient",
ElemwiseGradUseNone{"_backward_contrib_BilinearResize2D"})
.add_argument("data", "NDArray-or-Symbol", "Input data")
.add_argument("like", "NDArray-or-Symbol", "Resize data to it's shape")
.add_arguments(BilinearSampleParam::__FIELDS__());
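To make the interpolation doc string above concrete, a minimal NumPy sketch of sampling one pixel (an illustrative helper, not the operator's actual kernel; the weights mirror h0lambda/h1lambda and w0lambda/w1lambda in the CPU code):

```python
import numpy as np

def bilinear_pixel(img, y, x):
    """Sample a (H, W) array at real-valued (y, x) with bilinear weights."""
    y0, x0 = int(np.floor(y)), int(np.floor(x))
    y1, x1 = min(y0 + 1, img.shape[0] - 1), min(x0 + 1, img.shape[1] - 1)
    wy1, wx1 = y - y0, x - x0          # h1lambda, w1lambda
    wy0, wx0 = 1.0 - wy1, 1.0 - wx1    # h0lambda, w0lambda
    top = wx0 * img[y0, x0] + wx1 * img[y0, x1]     # interpolate along x ...
    bottom = wx0 * img[y1, x0] + wx1 * img[y1, x1]
    return wy0 * top + wy1 * bottom                 # ... then along y

img = np.array([[0.0, 1.0], [2.0, 3.0]])
print(bilinear_pixel(img, 0.5, 0.5))  # 1.5, the mean of the four corners
```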

NNVM_REGISTER_OP(_backward_contrib_BilinearResize2D)
.set_attr_parser(ParamParser<BilinearSampleParam>)
.set_num_inputs(1)
.set_num_outputs(1)
.set_num_inputs(BilinearSampleOpNumBackwardInputs)
.set_num_outputs(BilinearSampleOpNumBackwardOutputs)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FCompute>("FCompute<cpu>", BilinearSampleOpBackward<cpu>);
