diff --git a/docs/api/python/ndarray/contrib.md b/docs/api/python/ndarray/contrib.md
index 3dcb6d18f95f..a9c12bd6198f 100644
--- a/docs/api/python/ndarray/contrib.md
+++ b/docs/api/python/ndarray/contrib.md
@@ -34,6 +34,7 @@ In the rest of this document, we list routines provided by the `ndarray.contrib`
 .. autosummary::
     :nosignatures:
 
+    BilinearResize2D
     CTCLoss
     DeformableConvolution
     DeformablePSROIPooling
diff --git a/docs/api/python/symbol/contrib.md b/docs/api/python/symbol/contrib.md
index 7f5cc4bb3ff7..dbf9eb5a3a1f 100644
--- a/docs/api/python/symbol/contrib.md
+++ b/docs/api/python/symbol/contrib.md
@@ -34,6 +34,7 @@ In the rest of this document, we list routines provided by the `symbol.contrib`
 .. autosummary::
     :nosignatures:
 
+    BilinearResize2D
     CTCLoss
     DeformableConvolution
     DeformablePSROIPooling
diff --git a/src/operator/contrib/bilinear_resize-inl.h b/src/operator/contrib/bilinear_resize-inl.h
new file mode 100644
index 000000000000..2d6385362734
--- /dev/null
+++ b/src/operator/contrib/bilinear_resize-inl.h
@@ -0,0 +1,180 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file bilinear_resize-inl.h
+ * \brief bilinear resize operator
+ * \author Hang Zhang
+*/
+#ifndef MXNET_OPERATOR_CONTRIB_BILINEAR_RESIZE_INL_H_
+#define MXNET_OPERATOR_CONTRIB_BILINEAR_RESIZE_INL_H_
+
+#include <dmlc/logging.h>
+#include <dmlc/parameter.h>
+#include <mxnet/operator.h>
+#include <mxnet/ndarray.h>
+#include <map>
+#include <vector>
+#include <string>
+#include <utility>
+/* contrib
+#include "../ndarray/ndarray_function.h"
+#include "./operator_common.h"
+#include "./mxnet_op.h"
+#include "./mshadow_op.h"
+*/
+#include "../../ndarray/ndarray_function.h"
+#include "../operator_common.h"
+#include "../mxnet_op.h"
+#include "../mshadow_op.h"
+#include "../tensor/init_op.h"
+
+namespace mxnet {
+namespace op {
+
+struct BilinearSampleParam : public dmlc::Parameter<BilinearSampleParam> {
+  int height;
+  int width;
+  DMLC_DECLARE_PARAMETER(BilinearSampleParam) {
+    DMLC_DECLARE_FIELD(height).set_range(1, 1000)
+    .describe("output height (required)");
+    DMLC_DECLARE_FIELD(width).set_range(1, 1000)
+    .describe("output width (required)");
+  }
+};
+
+static inline bool IsWriting(const OpReqType ort) {
+  return ort == kWriteTo || ort == kWriteInplace;
+}
+
+template<typename xpu, typename DType, typename AccReal>
+void SpatialUpSamplingBilinearUpdateOutput(mshadow::Stream<cpu> *s,
+                                           const std::vector<TBlob> &input,
+                                           const std::vector<TBlob> &output);
+
+template<typename xpu, typename DType, typename AccReal>
+void SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream<cpu> *s,
+                                              const std::vector<TBlob> &input,
+                                              const std::vector<TBlob> &output);
+
+#if MXNET_USE_CUDA
+template<typename xpu, typename DType, typename AccReal>
+void SpatialUpSamplingBilinearUpdateOutput(mshadow::Stream<gpu> *s,
+                                           const std::vector<TBlob> &input,
+                                           const std::vector<TBlob> &output);
+
+template<typename xpu, typename DType, typename AccReal>
+void SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream<gpu> *s,
+                                              const std::vector<TBlob> &input,
+                                              const std::vector<TBlob> &output);
+#endif  // MXNET_USE_CUDA
+
+template <typename xpu>
+inline void BilinearSampleOpForward(const nnvm::NodeAttrs& attrs,
+                                    const OpContext &ctx,
+                                    const std::vector<TBlob> &inputs,
+                                    const std::vector<OpReqType> &req,
+                                    const std::vector<TBlob> &outputs) {
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+  MSHADOW_REAL_TYPE_SWITCH_EX(inputs[0].type_flag_, DType, AccReal, {
+    SpatialUpSamplingBilinearUpdateOutput<xpu, DType, AccReal>(s, inputs, outputs);
+  });
+}
+
+
+template <typename xpu>
+inline void BilinearSampleOpBackward(const nnvm::NodeAttrs& attrs,
+                                     const OpContext &ctx,
+                                     const std::vector<TBlob> &inputs,
+                                     const std::vector<OpReqType> &req,
+                                     const std::vector<TBlob> &outputs) {
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+  if (IsWriting(req[0])) {
+    // zero grad before backwarding
+    MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+      Fill<false>(s, outputs[0], kWriteTo, 0);
+    })
+  }
+  MSHADOW_REAL_TYPE_SWITCH_EX(inputs[0].type_flag_, DType, AccReal, {
+    SpatialUpSamplingBilinearUpdateGradInput<xpu, DType, AccReal>(s, inputs, outputs);
+  });
+}
+
+
+static bool BilinearSampleOpInferShape(const nnvm::NodeAttrs& attrs,
+                                       std::vector<TShape> *in_shape,
+                                       std::vector<TShape> *out_shape) {
+  using namespace mshadow;
+  CHECK_EQ(in_shape->size(), 1U) << "Input:[data]";
+  CHECK_EQ(out_shape->size(), 1U) << "Output:[data]";
+  const BilinearSampleParam& param = nnvm::get<BilinearSampleParam>(attrs.parsed);
+  TShape dshape(in_shape->at(0));
+  if (dshape.ndim() == 0) return false;
+  dshape[2] = param.height;
+  dshape[3] = param.width;
+  out_shape->clear();
+  out_shape->push_back(dshape);
+  return true;
+}
+
+static bool BilinearSampleOpInferType(const nnvm::NodeAttrs& attrs,
+                                      std::vector<int> *in_type,
+                                      std::vector<int> *out_type) {
+  using namespace mshadow;
+  CHECK_EQ(in_type->size(), 1U);
+  int dtype = (*in_type)[0];
+  CHECK_NE(dtype, -1) << "First input must have specified type";
+  // For float16 input the output is stored in float32 (the accumulation type);
+  // for other input types the output has the same type as the input.
+  int dtype_param = 0;
+  MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
+      dtype_param = mshadow::DataType<AccRealX>::kFlag; });
+  out_type->clear();
+  out_type->push_back(dtype_param);
+  return true;
+}
+
+static inline bool BilinearSampleOpStorageType(const nnvm::NodeAttrs &attrs,
+                                               const int dev_mask,
+                                               DispatchMode *dispatch_mode,
+                                               std::vector<int> *in_attrs,
+                                               std::vector<int> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1);
+  CHECK_EQ(out_attrs->size(), 1);
+  *dispatch_mode = DispatchMode::kFCompute;
+  for (int& v : *in_attrs) {
+    if (v == -1) v = kDefaultStorage;
+  }
+  for (size_t i = 0; i < out_attrs->size(); i++) {
+    (*out_attrs)[i] = kDefaultStorage;
+  }
+  return true;
+}
+
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_CONTRIB_BILINEAR_RESIZE_INL_H_
diff --git a/src/operator/contrib/bilinear_resize.cc b/src/operator/contrib/bilinear_resize.cc
new file mode 100644
index 000000000000..6d2b350c28c1
--- /dev/null
+++ b/src/operator/contrib/bilinear_resize.cc
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file bilinear_resize.cc
+ * \brief bilinear resize operator
+ * \author Hang Zhang
+*/
+#include "bilinear_resize-inl.h"
+// #include "elemwise_op_common.h"
+#include "../elemwise_op_common.h"
+
+namespace mxnet {
+namespace op {
+
+using namespace mshadow;
+
+template<typename xpu, typename DType, typename AccReal>
+void SpatialUpSamplingBilinearUpdateOutput(mshadow::Stream<cpu> *s,
+                                           const std::vector<TBlob> &input,
+                                           const std::vector<TBlob> &output) {
+  Tensor<xpu, 4, DType> itensor = input[0].get<xpu, 4, DType>(s);
+  Tensor<xpu, 4, DType> otensor = output[0].get<xpu, 4, DType>(s);
+  int nbatch = otensor.size(0);
+  int channels = otensor.size(1);
+  int outputHeight = otensor.size(2);
+  int outputWidth = otensor.size(3);
+  int inputHeight = itensor.size(2);
+  int inputWidth = itensor.size(3);
+
+  DType *idata = itensor.dptr_;
+  DType *odata = otensor.dptr_;
+  channels = nbatch * channels;
+  // special case: just copy
+  if (inputHeight == outputHeight && inputWidth == outputWidth) {
+    for (int h2 = 0; h2 < outputHeight; ++h2) {
+      const int h1 = h2;
+      for (int w2 = 0; w2 < outputWidth; ++w2) {
+        const int w1 = w2;
+        const DType* pos1 = &idata[h1 * inputWidth + w1];
+        DType* pos2 = &odata[h2 * outputWidth + w2];
+        for (int c = 0; c < channels; ++c) {
+          pos2[0] = pos1[0];
+          pos1 += inputWidth * inputHeight;
+          pos2 += outputWidth * outputHeight;
+        }
+      }
+    }
+    return;
+  }
+  const float rheight = (outputHeight > 1) ? static_cast<float>(inputHeight - 1) /
+                        (outputHeight - 1) : 0.f;
+  const float rwidth = (outputWidth > 1) ? static_cast<float>(inputWidth - 1) /
+                       (outputWidth - 1) : 0.f;
+  for (int h2 = 0; h2 < outputHeight; ++h2) {
+    const float h1r = rheight * h2;
+    const int h1 = h1r;
+    const int h1p = (h1 < inputHeight - 1) ? 1 : 0;
+    const DType h1lambda = h1r - h1;
+    const DType h0lambda = (DType)1. - h1lambda;
+    for (int w2 = 0; w2 < outputWidth; ++w2) {
+      const float w1r = rwidth * w2;
+      const int w1 = w1r;
+      const int w1p = (w1 < inputWidth - 1) ? 1 : 0;
+      const DType w1lambda = w1r - w1;
+      const DType w0lambda = (DType)1. - w1lambda;
+      const DType* pos1 = &idata[h1 * inputWidth + w1];
+      DType* pos2 = &odata[h2 * outputWidth + w2];
+      for (int c = 0; c < channels; ++c) {
+        pos2[0] = h0lambda * (w0lambda * pos1[0] + w1lambda * pos1[w1p])
+                  + h1lambda * (w0lambda * pos1[h1p * inputWidth]
+                  + w1lambda * pos1[h1p * inputWidth + w1p]);
+        pos1 += inputWidth * inputHeight;
+        pos2 += outputWidth * outputHeight;
+      }
+    }
+  }
+}
+
+
+template<typename xpu, typename DType, typename AccReal>
+void SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream<cpu> *s,
+                                              const std::vector<TBlob> &input,
+                                              const std::vector<TBlob> &output) {
+  Tensor<xpu, 4, DType> gradOutput = input[0].get<xpu, 4, DType>(s);
+  Tensor<xpu, 4, DType> gradInput = output[0].get<xpu, 4, DType>(s);
+
+  int nbatch = gradInput.size(0);
+  int channels = gradInput.size(1);
+  int outputHeight = gradOutput.size(2);
+  int outputWidth = gradOutput.size(3);
+  int inputHeight = gradInput.size(2);
+  int inputWidth = gradInput.size(3);
+
+  DType *data1 = gradInput.dptr_;
+  DType *data2 = gradOutput.dptr_;
+  channels = nbatch * channels;
+
+  // special case: same-size matching grids
+  if (inputHeight == outputHeight && inputWidth == outputWidth) {
+    for (int h2 = 0; h2 < outputHeight; ++h2) {
+      const int h1 = h2;
+      for (int w2 = 0; w2 < outputWidth; ++w2) {
+        const int w1 = w2;
+        DType* pos1 = &data1[h1 * inputWidth + w1];
+        const DType* pos2 = &data2[h2 * outputWidth + w2];
+        for (int c = 0; c < channels; ++c) {
+          pos1[0] += pos2[0];
+          pos1 += inputWidth * inputHeight;
+          pos2 += outputWidth * outputHeight;
+        }
+      }
+    }
+    return;
+  }
+  const float rheight = (outputHeight > 1) ? static_cast<float>(inputHeight - 1) /
+                        (outputHeight - 1) : 0.f;
+  const float rwidth = (outputWidth > 1) ? static_cast<float>(inputWidth - 1) /
+                       (outputWidth - 1) : 0.f;
+  for (int h2 = 0; h2 < outputHeight; ++h2) {
+    const float h1r = rheight * h2;
+    const int h1 = h1r;
+    const int h1p = (h1 < inputHeight - 1) ? 1 : 0;
+    const DType h1lambda = h1r - h1;
+    const DType h0lambda = (DType)1. - h1lambda;
+    for (int w2 = 0; w2 < outputWidth; ++w2) {
+      const float w1r = rwidth * w2;
+      const int w1 = w1r;
+      const int w1p = (w1 < inputWidth - 1) ? 1 : 0;
+      const DType w1lambda = w1r - w1;
+      const DType w0lambda = (DType)1. - w1lambda;
+      DType* pos1 = &data1[h1 * inputWidth + w1];
+      const DType* pos2 = &data2[h2 * outputWidth + w2];
+      for (int c = 0; c < channels; ++c) {
+        pos1[0] += h0lambda * w0lambda * pos2[0];
+        pos1[w1p] += h0lambda * w1lambda * pos2[0];
+        pos1[h1p * inputWidth] += h1lambda * w0lambda * pos2[0];
+        pos1[h1p * inputWidth + w1p] += h1lambda * w1lambda * pos2[0];
+        pos1 += inputWidth * inputHeight;
+        pos2 += outputWidth * outputHeight;
+      }
+    }
+  }
+}
+
+
+DMLC_REGISTER_PARAMETER(BilinearSampleParam);
+
+NNVM_REGISTER_OP(_contrib_BilinearResize2D)
+.describe(R"code(
+Perform 2D resizing (upsampling or downsampling) for 4D input using bilinear interpolation.
+
+The expected input is a 4-dimensional NDArray in NCHW layout; the output
+has shape (N x C x height x width).
+The key idea of bilinear interpolation is to perform linear interpolation
+first in one direction, and then again in the other direction. See the Wikipedia article on
+`Bilinear interpolation <https://en.wikipedia.org/wiki/Bilinear_interpolation>`_
+for more details.
+)code" ADD_FILELINE)
+.set_attr_parser(ParamParser<BilinearSampleParam>)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<nnvm::FInferShape>("FInferShape", BilinearSampleOpInferShape)
+.set_attr<nnvm::FInferType>("FInferType", BilinearSampleOpInferType)
+.set_attr<FInferStorageType>("FInferStorageType", BilinearSampleOpStorageType)
+.set_attr<FCompute>("FCompute<cpu>", BilinearSampleOpForward<cpu>)
+.set_attr<nnvm::FGradient>("FGradient",
+  ElemwiseGradUseNone{"_backward_contrib_BilinearResize2D"})
+.add_argument("data", "NDArray-or-Symbol", "Input data")
+.add_arguments(BilinearSampleParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_backward_contrib_BilinearResize2D)
+.set_attr_parser(ParamParser<BilinearSampleParam>)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FInferStorageType>("FInferStorageType", BilinearSampleOpStorageType)
+.set_attr<FCompute>("FCompute<cpu>", BilinearSampleOpBackward<cpu>);
+
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/contrib/bilinear_resize.cu b/src/operator/contrib/bilinear_resize.cu
new file mode 100644
index 000000000000..2ad818aba5a1
--- /dev/null
+++ b/src/operator/contrib/bilinear_resize.cu
@@ -0,0 +1,221 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file bilinear_resize.cu
+ * \brief bilinear resize operator
+ * \author Hang Zhang
+*/
+#include <cuda_runtime_api.h>
+#include <algorithm>
+#include "bilinear_resize-inl.h"
+
+namespace mxnet {
+namespace op {
+
+using namespace mshadow;
+
+template<typename In, typename Out>
+struct ScalarConvert {
+  static __host__ __device__ __forceinline__ Out to(const In v) { return (Out) v; }
+};
+
+
+// The maximum number of threads in a block
+static const unsigned MAX_BLOCK_SIZE = 512U;
+
+// Number of threads in a block given an input size up to MAX_BLOCK_SIZE
+static unsigned getNumThreads(int nElem, const bool smaller) {
+  unsigned threadSizes[5] = {32, 64, 128, 256, MAX_BLOCK_SIZE};
+  const int maxi = smaller ? 4 : 5;
+  for (int i = 0; i != maxi; ++i) {
+    if (static_cast<unsigned>(nElem) <= threadSizes[i]) {
+      return threadSizes[i];
+    }
+  }
+  return smaller ? (MAX_BLOCK_SIZE >> 1) : MAX_BLOCK_SIZE;
+}
+
+template<typename xpu, typename Dtype, typename Acctype>
+__global__ void caffe_gpu_interp2_kernel(const int n,
+    const Acctype rheight, const Acctype rwidth,
+    const Tensor<xpu, 4, Dtype> data1,
+    Tensor<xpu, 4, Dtype> data2) {
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  const int batchsize = data1.size(0);
+  const int channels = data1.size(1);
+  const int height1 = data1.size(2);
+  const int width1 = data1.size(3);
+  const int height2 = data2.size(2);
+  const int width2 = data2.size(3);
+
+  if (index < n) {
+    const int w2 = index % width2;  // 0:width2-1
+    const int h2 = index / width2;  // 0:height2-1
+    // special case: just copy
+    if (height1 == height2 && width1 == width2) {
+      const int h1 = h2;
+      const int w1 = w2;
+      for (int n = 0; n < batchsize ; n++) {
+        for (int c = 0; c < channels; ++c) {
+          const Dtype val = data1[n][c][h1][w1];
+          data2[n][c][h2][w2] = val;
+        }
+      }
+      return;
+    }
+    //
+    const Acctype h1r = rheight * h2;
+    const int h1 = h1r;
+    const int h1p = (h1 < height1 - 1) ? 1 : 0;
+    const Acctype h1lambda = h1r - h1;
+    const Acctype h0lambda = Acctype(1) - h1lambda;
+    //
+    const Acctype w1r = rwidth * w2;
+    const int w1 = w1r;
+    const int w1p = (w1 < width1 - 1) ? 1 : 0;
+    const Acctype w1lambda = w1r - w1;
+    const Acctype w0lambda = Acctype(1) - w1lambda;
+    //
+    for (int n = 0; n < batchsize ; n++) {
+      for (int c = 0; c < channels; ++c) {
+        const Acctype val = h0lambda * (w0lambda * data1[n][c][h1][w1]
+                            + w1lambda * data1[n][c][h1][w1+w1p])
+                            + h1lambda * (w0lambda * data1[n][c][h1+h1p][w1]
+                            + w1lambda * data1[n][c][h1+h1p][w1+w1p]);
+        data2[n][c][h2][w2] = ScalarConvert<Acctype, Dtype>::to(val);
+      }
+    }
+  }
+}
+
+// Backward (adjoint) operation 1 <- 2 (accumulates)
+template<typename xpu, typename Dtype, typename Acctype>
+__global__ void caffe_gpu_interp2_kernel_backward(const int n,
+    const Acctype rheight, const Acctype rwidth,
+    Tensor<xpu, 4, Dtype> data1, const Tensor<xpu, 4, Dtype> data2) {
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  const int batchsize = data1.size(0);
+  const int channels = data1.size(1);
+  const int height1 = data1.size(2);
+  const int width1 = data1.size(3);
+  const int height2 = data2.size(2);
+  const int width2 = data2.size(3);
+  if (index < n) {
+    const int w2 = index % width2;  // 0:width2-1
+    const int h2 = index / width2;  // 0:height2-1
+    // special case: just copy
+    if (height1 == height2 && width1 == width2) {
+      const int h1 = h2;
+      const int w1 = w2;
+      for (int n = 0; n < batchsize ; n++) {
+        for (int c = 0; c < channels; ++c) {
+          const Dtype val = data2[n][c][h1][w1];
+          data1[n][c][h2][w2] += val;
+        }
+      }
+      return;
+    }
+    //
+    const Acctype h1r = rheight * h2;
+    const int h1 = h1r;
+    const int h1p = (h1 < height1 - 1) ? 1 : 0;
+    const Acctype h1lambda = h1r - h1;
+    const Acctype h0lambda = Acctype(1) - h1lambda;
+    //
+    const Acctype w1r = rwidth * w2;
+    const int w1 = w1r;
+    const int w1p = (w1 < width1 - 1) ? 1 : 0;
+    const Acctype w1lambda = w1r - w1;
+    const Acctype w0lambda = Acctype(1) - w1lambda;
+    //
+    for (int n = 0; n < batchsize ; n++) {
+      for (int c = 0; c < channels; ++c) {
+        const Dtype d2val = data2[n][c][h2][w2];
+        atomicAdd(&data1[n][c][h1][w1],
+                  ScalarConvert<Acctype, Dtype>::to(h0lambda * w0lambda * d2val));
+        atomicAdd(&data1[n][c][h1][w1+w1p],
+                  ScalarConvert<Acctype, Dtype>::to(h0lambda * w1lambda * d2val));
+        atomicAdd(&data1[n][c][h1+h1p][w1],
+                  ScalarConvert<Acctype, Dtype>::to(h1lambda * w0lambda * d2val));
+        atomicAdd(&data1[n][c][h1+h1p][w1+w1p],
+                  ScalarConvert<Acctype, Dtype>::to(h1lambda * w1lambda * d2val));
+      }
+    }
+  }
+}
+
+template<typename xpu, typename DType, typename AccReal>
+void SpatialUpSamplingBilinearUpdateOutput(mshadow::Stream<gpu> *s,
+                                           const std::vector<TBlob> &input,
+                                           const std::vector<TBlob> &output) {
+  Tensor<xpu, 4, DType> idata = input[0].get<xpu, 4, DType>(s);
+  Tensor<xpu, 4, DType> odata = output[0].get<xpu, 4, DType>(s);
+  int outputHeight = odata.size(2);
+  int outputWidth = odata.size(3);
+  int inputHeight = idata.size(2);
+  int inputWidth = idata.size(3);
+
+  const AccReal rheight = (outputHeight > 1) ? (AccReal)(inputHeight - 1) /
+                          (outputHeight - 1) : AccReal(0);
+  const AccReal rwidth = (outputWidth > 1) ? (AccReal)(inputWidth - 1) /
+                         (outputWidth - 1) : AccReal(0);
+  const int num_kernels = outputHeight * outputWidth;
+  const int num_threads = getNumThreads(inputHeight*inputWidth, false);
+  dim3 blocks(static_cast<int>(num_kernels / num_threads) + 1);
+  dim3 threads(num_threads);
+  cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
+  caffe_gpu_interp2_kernel<xpu, DType, AccReal>
+  <<<blocks, threads, 0, stream>>>(
+    num_kernels, rheight, rwidth, idata, odata);
+  MSHADOW_CUDA_POST_KERNEL_CHECK(SpatialUpSamplingBilinearUpdateOutput);
+}
+
+template<typename xpu, typename DType, typename AccReal>
+void SpatialUpSamplingBilinearUpdateGradInput(mshadow::Stream<gpu> *s,
+                                              const std::vector<TBlob> &input,
+                                              const std::vector<TBlob> &output) {
+  Tensor<xpu, 4, DType> data1 = output[0].get<xpu, 4, DType>(s);
+  Tensor<xpu, 4, DType> data2 = input[0].get<xpu, 4, DType>(s);
+  int height1 = data1.size(2);
+  int width1 = data1.size(3);
+  int height2 = data2.size(2);
+  int width2 = data2.size(3);
+  const AccReal rheight = (height2 > 1) ? (AccReal)(height1 - 1) / (height2 - 1) : AccReal(0);
+  const AccReal rwidth = (width2 > 1) ? (AccReal)(width1 - 1) / (width2 - 1) : AccReal(0);
+  const int num_kernels = height2 * width2;
+  const int num_threads = getNumThreads(height1*width1, false);
+  dim3 blocks(static_cast<int>(num_kernels / num_threads) + 1);
+  dim3 threads(num_threads);
+  cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
+  caffe_gpu_interp2_kernel_backward<xpu, DType, AccReal>
+  <<<blocks, threads, 0, stream>>>(
+    num_kernels, rheight, rwidth, data1, data2);
+  MSHADOW_CUDA_POST_KERNEL_CHECK(SpatialUpSamplingBilinearUpdateGradInput);
+}
+
+NNVM_REGISTER_OP(_contrib_BilinearResize2D)
+.set_attr<FCompute>("FCompute<gpu>", BilinearSampleOpForward<gpu>);
+
+NNVM_REGISTER_OP(_backward_contrib_BilinearResize2D)
+.set_attr<FCompute>("FCompute<gpu>", BilinearSampleOpBackward<gpu>);
+
+}  // namespace op
+}  // namespace mxnet
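
Usage sketch (not part of the patch): once this change is built into MXNet, the operator registered above should be callable from both the NDArray and Symbol contrib namespaces roughly as shown below. The argument names (data, height, width) come from BilinearSampleParam; the module paths follow the docs entries added at the top of the patch.

    import mxnet as mx

    x = mx.nd.arange(16).reshape((1, 1, 4, 4))                 # NCHW input
    y = mx.nd.contrib.BilinearResize2D(x, height=8, width=8)   # upsample 4x4 -> 8x8
    print(y.shape)                                              # (1, 1, 8, 8)

    # The Symbol API mirrors the NDArray call:
    data = mx.sym.Variable('data')
    out = mx.sym.contrib.BilinearResize2D(data, height=8, width=8)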
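For readers checking the interpolation math, here is a minimal NumPy sketch of the forward pass that mirrors the CPU kernel in bilinear_resize.cc: the scale factors use (input_size - 1) / (output_size - 1), so the corner pixels of input and output are aligned, and each output pixel is a weighted sum of its four neighbouring input pixels with weights derived from h1lambda and w1lambda. This is an illustrative reference only, not code from the patch.

    import numpy as np

    def bilinear_resize_nchw(x, out_h, out_w):
        """Reference forward pass mirroring SpatialUpSamplingBilinearUpdateOutput."""
        n, c, in_h, in_w = x.shape
        out = np.empty((n, c, out_h, out_w), dtype=x.dtype)
        rheight = (in_h - 1) / (out_h - 1) if out_h > 1 else 0.0
        rwidth = (in_w - 1) / (out_w - 1) if out_w > 1 else 0.0
        for h2 in range(out_h):
            h1r = rheight * h2
            h1 = int(h1r)
            h1p = 1 if h1 < in_h - 1 else 0   # step to next row, clamped at the border
            h1l = h1r - h1                    # h1lambda
            for w2 in range(out_w):
                w1r = rwidth * w2
                w1 = int(w1r)
                w1p = 1 if w1 < in_w - 1 else 0
                w1l = w1r - w1                # w1lambda
                out[:, :, h2, w2] = (
                    (1 - h1l) * ((1 - w1l) * x[:, :, h1, w1] + w1l * x[:, :, h1, w1 + w1p])
                    + h1l * ((1 - w1l) * x[:, :, h1 + h1p, w1] + w1l * x[:, :, h1 + h1p, w1 + w1p]))
        return out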