diff --git a/python/mxnet/gluon/data/vision/transforms.py b/python/mxnet/gluon/data/vision/transforms.py index 2f557f591f60..aa4a3e3d8957 100644 --- a/python/mxnet/gluon/data/vision/transforms.py +++ b/python/mxnet/gluon/data/vision/transforms.py @@ -262,8 +262,8 @@ def forward(self, x): return image.center_crop(x, *self._args)[0] -class Resize(Block): - """Resize an image to the given size. +class Resize(HybridBlock): + """Resize an image or a batch of images (NDArray) to the given size. Should be applied before `mxnet.gluon.data.vision.transforms.ToTensor`. Parameters @@ -276,13 +276,17 @@ class Resize(Block): interpolation : int Interpolation method for resizing. By default uses bilinear interpolation. See OpenCV's resize function for available choices. + Note that Resize on GPU uses the contrib.BilinearResize2D operator, + which only supports bilinear interpolation (1). The result may be slightly + different on GPU compared to CPU: OpenCV tends to align centers while BilinearResize2D + uses an algorithm that aligns corners. Inputs: - - **data**: input tensor with (Hi x Wi x C) shape. + - **data**: input tensor with (H x W x C) or (N x H x W x C) shape. Outputs: - - **out**: output tensor with (H x W x C) shape. + - **out**: output tensor with (H x W x C) or (N x H x W x C) shape. Examples -------- @@ -290,6 +294,9 @@ class Resize(Block): >>> image = mx.nd.random.uniform(0, 255, (224, 224, 3)).astype(dtype=np.uint8) >>> transformer(image) + >>> image = mx.nd.random.uniform(0, 255, (3, 224, 224, 3)).astype(dtype=np.uint8) + >>> transformer(image) + """ def __init__(self, size, keep_ratio=False, interpolation=1): super(Resize, self).__init__() @@ -297,23 +304,8 @@ def __init__(self, size, keep_ratio=False, interpolation=1): self._size = size self._interpolation = interpolation - def forward(self, x): - if isinstance(self._size, numeric_types): - if not self._keep: - wsize = self._size - hsize = self._size - else: - h, w, _ = x.shape - if h > w: - wsize = self._size - hsize = int(h * wsize / w) - else: - hsize = self._size - wsize = int(w * hsize / h) - else: - wsize, hsize = self._size - return image.imresize(x, wsize, hsize, self._interpolation) - + def hybrid_forward(self, F, x): + return F.image.resize(x, self._size, self._keep, self._interpolation) class RandomFlipLeftRight(HybridBlock): """Randomly flip the input image left to right with a probability diff --git a/src/io/image_io.cc b/src/io/image_io.cc index b3f7c40b2b1a..44fcdb8321de 100644 --- a/src/io/image_io.cc +++ b/src/io/image_io.cc @@ -38,6 +38,7 @@ #include #include "../operator/elemwise_op_common.h" +#include "../operator/image/resize-inl.h" #if MXNET_USE_OPENCV #include @@ -285,19 +286,8 @@ inline void Imresize(const nnvm::NodeAttrs& attrs, const std::vector &inputs, const std::vector &req, const std::vector &outputs) { -#if MXNET_USE_OPENCV - CHECK_NE(inputs[0].type_flag_, mshadow::kFloat16) << "imresize doesn't support fp16"; - const int DTYPE[] = {CV_32F, CV_64F, -1, CV_8U, CV_32S}; - int cv_type = CV_MAKETYPE(DTYPE[inputs[0].type_flag_], inputs[0].shape_[2]); const auto& param = nnvm::get(attrs.parsed); - cv::Mat buf(inputs[0].shape_[0], inputs[0].shape_[1], cv_type, inputs[0].dptr_); - cv::Mat dst(outputs[0].shape_[0], outputs[0].shape_[1], cv_type, outputs[0].dptr_); - cv::resize(buf, dst, cv::Size(param.w, param.h), 0, 0, param.interp); - CHECK(!dst.empty()); - CHECK_EQ(static_cast(dst.ptr()), outputs[0].dptr_); -#else - LOG(FATAL) << "Build with USE_OPENCV=1 for image io."; -#endif // MXNET_USE_OPENCV +
op::image::ResizeImpl(inputs, outputs, param.h, param.w, param.interp); } diff --git a/src/operator/contrib/bilinear_resize-inl.cuh b/src/operator/contrib/bilinear_resize-inl.cuh new file mode 100644 index 000000000000..b8dacb1c4f31 --- /dev/null +++ b/src/operator/contrib/bilinear_resize-inl.cuh @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * Copyright (c) 2019 by Contributors + * \file bilinear_resize-inl.cuh + * \brief bilinear resize operator cuda implementation + * \author Hang Zhang, Jake Lee +*/ + +#ifndef MXNET_OPERATOR_CONTRIB_BILINEAR_RESIZE_CUH_ +#define MXNET_OPERATOR_CONTRIB_BILINEAR_RESIZE_CUH_ + +#include +#include + +namespace mxnet { +namespace op { + +using namespace mshadow; + +enum ImageLayout { + HWC, + NHWC, + NCHW +}; + +template +struct ScalarConvert { + static __host__ __device__ __forceinline__ Out to(const In v) { return (Out) v; } +}; + +// The maximum number of threads in a block +static const unsigned MAX_BLOCK_SIZE = 512U; + +// Number of threads in a block given an input size up to MAX_BLOCK_SIZE +static unsigned getNumThreads(int nElem, const bool smaller) { + unsigned threadSizes[5] = {32, 64, 128, 256, MAX_BLOCK_SIZE}; + const int maxi = smaller ? 4 : 5; + for (int i = 0; i != maxi; ++i) { + if (static_cast(nElem) <= threadSizes[i]) { + return threadSizes[i]; + } + } + return smaller ? (MAX_BLOCK_SIZE >> 1) : MAX_BLOCK_SIZE; +} + +// caffe_gpu_interp2_kernel overloading with Tensor +template +__global__ void caffe_gpu_interp2_kernel(const int n, + const Acctype rheight, const Acctype rwidth, + const Tensor data1, + Tensor data2, + ImageLayout layout) { + int index = threadIdx.x + blockIdx.x * blockDim.x; + const int channels = data1.size(2); + const int height1 = data1.size(0); + const int width1 = data1.size(1); + const int height2 = data2.size(0); + const int width2 = data2.size(1); + + if (index < n) { + const int w2 = index % width2; // 0:width2-1 + const int h2 = index / width2; // 0:height2-1 + // special case: just copy + if (height1 == height2 && width1 == width2) { + const int h1 = h2; + const int w1 = w2; + for (int c = 0; c < channels; ++c) { + const Dtype val = data1[h1][w1][c]; + data2[h2][w2][c] = val; + } + return; + } + // + const Acctype h1r = rheight * h2; + const int h1 = h1r; + const int h1p = (h1 < height1 - 1) ? 1 : 0; + const Acctype h1lambda = h1r - h1; + const Acctype h0lambda = Acctype(1) - h1lambda; + // + const Acctype w1r = rwidth * w2; + const int w1 = w1r; + const int w1p = (w1 < width1 - 1) ? 
1 : 0; + const Acctype w1lambda = w1r - w1; + const Acctype w0lambda = Acctype(1) - w1lambda; + for (int c = 0; c < channels; ++c) { + const Acctype val = h0lambda * (w0lambda * data1[h1][w1][c] + + w1lambda * data1[h1][w1+w1p][c]) + + h1lambda * (w0lambda * data1[h1+h1p][w1][c] + + w1lambda * data1[h1+h1p][w1+w1p][c]); + data2[h2][w2][c] = ScalarConvert::to(val); + } + } +} + +// caffe_gpu_interp2_kernel overloading with Tensor +template +__global__ void caffe_gpu_interp2_kernel(const int n, + const Acctype rheight, const Acctype rwidth, + const Tensor data1, + Tensor data2, + ImageLayout layout) { + int index = threadIdx.x + blockIdx.x * blockDim.x; + int batch_size = (layout == NHWC) ? data1.size(0) : data1.size(0); + int channels = (layout == NHWC) ? data1.size(3) : data1.size(1); + int height1 = (layout == NHWC) ? data1.size(1) : data1.size(2); + int width1 = (layout == NHWC) ? data1.size(2) : data1.size(3); + int height2 = (layout == NHWC) ? data2.size(1) : data2.size(2); + int width2 = (layout == NHWC) ? data2.size(2): data2.size(3); + + if (index < n) { + const int w2 = index % width2; // 0:width2-1 + const int h2 = index / width2; // 0:height2-1 + // special case: just copy + if (height1 == height2 && width1 == width2) { + const int h1 = h2; + const int w1 = w2; + + for (int n = 0; n < batch_size; ++n) { + for (int c = 0; c < channels; ++c) { + if (layout == NHWC) { + const Dtype val = data1[n][h1][w1][c]; + data2[n][h2][w2][c] = val; + } else { + const Dtype val = data1[n][c][h1][w1]; + data2[n][c][h2][w2] = val; + } + } + } + return; + } + // + const Acctype h1r = rheight * h2; + const int h1 = h1r; + const int h1p = (h1 < height1 - 1) ? 1 : 0; + const Acctype h1lambda = h1r - h1; + const Acctype h0lambda = Acctype(1) - h1lambda; + // + const Acctype w1r = rwidth * w2; + const int w1 = w1r; + const int w1p = (w1 < width1 - 1) ? 
1 : 0; + const Acctype w1lambda = w1r - w1; + const Acctype w0lambda = Acctype(1) - w1lambda; + + for (auto n = 0; n < batch_size; ++n) { + for (int c = 0; c < channels; ++c) { + if (layout == NHWC) { + const Acctype val = h0lambda * (w0lambda * data1[n][h1][w1][c] + + w1lambda * data1[n][h1][w1+w1p][c]) + + h1lambda * (w0lambda * data1[n][h1+h1p][w1][c] + + w1lambda * data1[n][h1+h1p][w1+w1p][c]); + data2[n][h2][w2][c] = ScalarConvert::to(val); + } else { + const Acctype val = h0lambda * (w0lambda * data1[n][c][h1][w1] + + w1lambda * data1[n][c][h1][w1+w1p]) + + h1lambda * (w0lambda * data1[n][c][h1+h1p][w1] + + w1lambda * data1[n][c][h1+h1p][w1+w1p]); + data2[n][c][h2][w2] = ScalarConvert::to(val); + } + } + } + } +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_CONTRIB_BILINEAR_RESIZE_CUH_ \ No newline at end of file diff --git a/src/operator/contrib/bilinear_resize.cu b/src/operator/contrib/bilinear_resize.cu index f01c9c2fa132..b0a4c4b316d9 100644 --- a/src/operator/contrib/bilinear_resize.cu +++ b/src/operator/contrib/bilinear_resize.cu @@ -25,86 +25,13 @@ #include #include #include "bilinear_resize-inl.h" +#include "bilinear_resize-inl.cuh" namespace mxnet { namespace op { using namespace mshadow; -template -struct ScalarConvert { - static __host__ __device__ __forceinline__ Out to(const In v) { return (Out) v; } -}; - - -// The maximum number of threads in a block -static const unsigned MAX_BLOCK_SIZE = 512U; - -// Number of threads in a block given an input size up to MAX_BLOCK_SIZE -static unsigned getNumThreads(int nElem, const bool smaller) { - unsigned threadSizes[5] = {32, 64, 128, 256, MAX_BLOCK_SIZE}; - const int maxi = smaller ? 4 : 5; - for (int i = 0; i != maxi; ++i) { - if (static_cast(nElem) <= threadSizes[i]) { - return threadSizes[i]; - } - } - return smaller ? (MAX_BLOCK_SIZE >> 1) : MAX_BLOCK_SIZE; -} - -template -__global__ void caffe_gpu_interp2_kernel(const int n, - const Acctype rheight, const Acctype rwidth, - const Tensor data1, - Tensor data2) { - int index = threadIdx.x + blockIdx.x * blockDim.x; - const int batchsize = data1.size(0); - const int channels = data1.size(1); - const int height1 = data1.size(2); - const int width1 = data1.size(3); - const int height2 = data2.size(2); - const int width2 = data2.size(3); - - if (index < n) { - const int w2 = index % width2; // 0:width2-1 - const int h2 = index / width2; // 0:height2-1 - // special case: just copy - if (height1 == height2 && width1 == width2) { - const int h1 = h2; - const int w1 = w2; - for (int n = 0; n < batchsize ; n++) { - for (int c = 0; c < channels; ++c) { - const Dtype val = data1[n][c][h1][w1]; - data2[n][c][h2][w2] = val; - } - } - return; - } - // - const Acctype h1r = rheight * h2; - const int h1 = h1r; - const int h1p = (h1 < height1 - 1) ? 1 : 0; - const Acctype h1lambda = h1r - h1; - const Acctype h0lambda = Acctype(1) - h1lambda; - // - const Acctype w1r = rwidth * w2; - const int w1 = w1r; - const int w1p = (w1 < width1 - 1) ? 
1 : 0; - const Acctype w1lambda = w1r - w1; - const Acctype w0lambda = Acctype(1) - w1lambda; - // - for (int n = 0; n < batchsize ; n++) { - for (int c = 0; c < channels; ++c) { - const Acctype val = h0lambda * (w0lambda * data1[n][c][h1][w1] - + w1lambda * data1[n][c][h1][w1+w1p]) - + h1lambda * (w0lambda * data1[n][c][h1+h1p][w1] - + w1lambda * data1[n][c][h1+h1p][w1+w1p]); - data2[n][c][h2][w2] = ScalarConvert::to(val); - } - } - } -} - // Backward (adjoint) operation 1 <- 2 (accumulates) template __global__ void caffe_gpu_interp2_kernel_backward(const int n, @@ -181,9 +108,10 @@ void SpatialUpSamplingBilinearUpdateOutput(mshadow::Stream *s, dim3 blocks(static_cast(num_kernels / num_threads) + 1); dim3 threads(num_threads); cudaStream_t stream = mshadow::Stream::GetStream(s); + ImageLayout layout = NCHW; caffe_gpu_interp2_kernel <<>>( - num_kernels, rheight, rwidth, idata, odata); + num_kernels, rheight, rwidth, idata, odata, layout); MSHADOW_CUDA_POST_KERNEL_CHECK(SpatialUpSamplingBilinearUpdateOutput); } @@ -215,6 +143,5 @@ NNVM_REGISTER_OP(_contrib_BilinearResize2D) NNVM_REGISTER_OP(_backward_contrib_BilinearResize2D) .set_attr("FCompute", BilinearSampleOpBackward); - } // namespace op } // namespace mxnet diff --git a/src/operator/image/image_random-inl.h b/src/operator/image/image_random-inl.h index 74807b9b681e..aeea0bcf9fec 100644 --- a/src/operator/image/image_random-inl.h +++ b/src/operator/image/image_random-inl.h @@ -26,14 +26,18 @@ #define MXNET_OPERATOR_IMAGE_IMAGE_RANDOM_INL_H_ -#include #include -#include #include #include +#include #include +#include +#include "mxnet/base.h" #include "../mxnet_op.h" #include "../operator_common.h" +#if MXNET_USE_OPENCV + #include +#endif // MXNET_USE_OPENCV namespace mxnet { namespace op { diff --git a/src/operator/image/image_utils.h b/src/operator/image/image_utils.h new file mode 100644 index 000000000000..a7155345c967 --- /dev/null +++ b/src/operator/image/image_utils.h @@ -0,0 +1,59 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. +*/ + +/*! 
+ * Copyright (c) 2019 by Contributors + * \file image_utils.h + * \brief utility definitions shared by the image operators + * \author Jake Lee + */ + +#ifndef MXNET_OPERATOR_IMAGE_IMAGE_UTILS_H_ +#define MXNET_OPERATOR_IMAGE_IMAGE_UTILS_H_ + +#include +#if MXNET_USE_OPENCV + #include +#endif // MXNET_USE_OPENCV + +namespace mxnet { +namespace op { +namespace image { + +enum ImageLayout {H, W, C}; +enum BatchImageLayout {N, kH, kW, kC}; + +struct SizeParam { + int height; + int width; + SizeParam() { + height = 0; + width = 0; + } + SizeParam(int height_, int width_) { + height = height_; + width = width_; + } +}; // struct SizeParam + +} // namespace image +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_IMAGE_IMAGE_UTILS_H_ diff --git a/src/operator/image/resize-inl.h b/src/operator/image/resize-inl.h new file mode 100644 index 000000000000..3e1310068073 --- /dev/null +++ b/src/operator/image/resize-inl.h @@ -0,0 +1,218 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. +*/ +/*! +* \file resize-inl.h +* \brief image resize operator; uses OpenCV on CPU and supports only bilinear resize on GPU +* \author Jake Lee +*/ +#ifndef MXNET_OPERATOR_IMAGE_RESIZE_INL_H_ +#define MXNET_OPERATOR_IMAGE_RESIZE_INL_H_ + +#include +#include + +#include "../mxnet_op.h" +#include "../operator_common.h" +#include "image_utils.h" + +#if MXNET_USE_OPENCV + #include +#endif // MXNET_USE_OPENCV + +namespace mxnet { +namespace op { +namespace image { + +using namespace mshadow; + +#if MXNET_USE_CUDA +template +void ResizeImplCUDA(Stream *s, + const T input, + const T output); +#endif // MXNET_USE_CUDA + +struct ResizeParam : public dmlc::Parameter { + nnvm::Tuple size; + bool keep_ratio; + int interp; + DMLC_DECLARE_PARAMETER(ResizeParam) { + DMLC_DECLARE_FIELD(size) + .set_default(nnvm::Tuple()) + .describe("Size of new image. Could be (width, height) or (size)"); + DMLC_DECLARE_FIELD(keep_ratio) + .describe("Whether to resize the short edge or both edges to `size`, " + "if size is given as an integer.") + .set_default(false); + DMLC_DECLARE_FIELD(interp) + .set_default(1) + .describe("Interpolation method for resizing. By default uses bilinear interpolation. " + "Options are INTER_NEAREST - a nearest-neighbor interpolation, " + "INTER_LINEAR - a bilinear interpolation, " + "INTER_AREA - resampling using pixel area relation, " + "INTER_CUBIC - a bicubic interpolation over 4x4 pixel neighborhood, " + "INTER_LANCZOS4 - a Lanczos interpolation over 8x8 pixel neighborhood. " + "Note that the GPU version only supports bilinear interpolation (1)," + " and the result on CPU may be slightly different from that on GPU."
+ "It uses opencv resize function which tend to align center on cpu" + "while using contrib.bilinearResize2D which aligns corner on gpu"); + } +}; +// handle the keep ratio param +inline SizeParam GetHeightAndWidth(int data_h, + int data_w, + const ResizeParam& param) { + CHECK((param.size.ndim() == 1) || (param.size.ndim() == 2)) + << "Input size dimension must be 1 or 2, but got " + << param.size.ndim(); + int resized_h; + int resized_w; + if (param.size.ndim() == 1) { + CHECK_GT(param.size[0], 0) + << "Input size should be greater than 0, but got " + << param.size[0]; + if (!param.keep_ratio) { + resized_h = param.size[0]; + resized_w = param.size[0]; + } else { + if (data_h > data_w) { + resized_w = param.size[0]; + resized_h = static_cast(data_h * resized_w / data_w); + } else { + resized_h = param.size[0]; + resized_w = static_cast(data_w * resized_h / data_h); + } + } + } else { + CHECK_GT(param.size[0], 0) + << "Input width should be greater than 0, but got " + << param.size[0]; + CHECK_GT(param.size[1], 0) + << "Input height should be greater than 0, but got " + << param.size[1]; + resized_h = param.size[1]; + resized_w = param.size[0]; + } + return SizeParam(resized_h, resized_w); +} + +inline bool ResizeShape(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + // input attrs should only be (h, w, c) or (n, h, w, c) + CHECK((in_attrs->at(0).ndim() == 3U) || (in_attrs->at(0).ndim() == 4U)) + << "Input image dimension should be 3 or 4 but got " + << in_attrs->at(0).ndim(); + const auto& ishape = (*in_attrs)[0]; + const ResizeParam& param = nnvm::get(attrs.parsed); + SizeParam size; + if (ishape.ndim() == 3) { + size = GetHeightAndWidth(ishape[H], ishape[W], param); + SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape({size.height, size.width, ishape[C]})); + } else { + size = GetHeightAndWidth(ishape[kH], ishape[kW], param); + SHAPE_ASSIGN_CHECK(*out_attrs, 0, + TShape({ishape[N], size.height, size.width, ishape[kC]})); + } + return true; +} + +inline void ResizeImpl(const std::vector &inputs, + const std::vector &outputs, + const int height, + const int width, + const int interp, + const int input_index = 0, + const int output_index = 0) { +#if MXNET_USE_OPENCV + CHECK_NE(inputs[0].type_flag_, mshadow::kFloat16) << "opencv image mat doesn't support fp16"; + CHECK((inputs[0].type_flag_ != mshadow::kInt32) || (inputs[0].type_flag_ != mshadow::kInt64)) + << "opencv resize doesn't support int32, int64"; + // mapping to opencv matrix element type according to channel + const int DTYPE[] = {CV_32F, CV_64F, -1, CV_8U, CV_32S}; + if (inputs[0].ndim() == 3) { + const int cv_type = CV_MAKETYPE(DTYPE[inputs[0].type_flag_], inputs[0].shape_[C]); + cv::Mat buf(inputs[0].shape_[H], inputs[0].shape_[W], cv_type, inputs[0].dptr_); + cv::Mat dst(outputs[0].shape_[H], outputs[0].shape_[W], cv_type, outputs[0].dptr_); + cv::resize(buf, dst, cv::Size(width, height), 0, 0, interp); + CHECK(!dst.empty()); + CHECK_EQ(static_cast(dst.ptr()), outputs[0].dptr_); + } else { + const int cv_type = CV_MAKETYPE(DTYPE[inputs[0].type_flag_], inputs[0].shape_[kC]); + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { + cv::Mat buf(inputs[0].shape_[kH], inputs[0].shape_[kW], cv_type, + inputs[0].dptr() + input_index); + cv::Mat dst(outputs[0].shape_[kH], outputs[0].shape_[kW], cv_type, + outputs[0].dptr() + output_index); + cv::resize(buf, dst, cv::Size(width, height), 0, 0, interp); + CHECK(!dst.empty()); + CHECK_EQ(static_cast(dst.ptr()), outputs[0].dptr() + output_index); + }); + } +#else 
+ LOG(FATAL) << "Build with USE_OPENCV=1 for image resize operator."; +#endif // MXNET_USE_OPENCV +} + +template +inline void Resize(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + CHECK_EQ(outputs.size(), 1U); + const ResizeParam& param = nnvm::get(attrs.parsed); + SizeParam size; + if (std::is_same::value) { +#if MXNET_USE_CUDA + CHECK(param.interp == 1) << "interp should be 1 for using Resize on GPU."; + mshadow::Stream *s = ctx.get_stream(); + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, { + if (inputs[0].ndim() == 3) { + Tensor input = inputs[0].get(s); + Tensor output = outputs[0].get(s); + ResizeImplCUDA, float>(s, input, output); + } else { + Tensor input = inputs[0].get(s); + Tensor output = outputs[0].get(s); + ResizeImplCUDA, float>(s, input, output); + } + }); +#endif // MXNET_USE_CUDA + } else if (inputs[0].ndim() == 3) { + size = GetHeightAndWidth(inputs[0].shape_[H], inputs[0].shape_[W], param); + ResizeImpl(inputs, outputs, size.height, size.width, param.interp); + } else { + size = GetHeightAndWidth(inputs[0].shape_[kH], inputs[0].shape_[kW], param); + const auto batch_size = inputs[0].shape_[N]; + const auto input_step = inputs[0].shape_[kH] * inputs[0].shape_[kW] * inputs[0].shape_[kC]; + const auto output_step = size.height * size.width * inputs[0].shape_[kC]; + #pragma omp parallel for + for (auto i = 0; i < batch_size; ++i) { + ResizeImpl(inputs, outputs, size.height, size.width, + param.interp, i * input_step, i * output_step); + } + } +} + +} // namespace image +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_IMAGE_RESIZE_INL_H_ diff --git a/src/operator/image/resize.cc b/src/operator/image/resize.cc new file mode 100644 index 000000000000..d3b28f08008f --- /dev/null +++ b/src/operator/image/resize.cc @@ -0,0 +1,83 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. +*/ +/*! + * Copyright (c) 2019 by Contributors + * \file resize.cc + * \brief resize operator cpu + * \author Jake Lee +*/ +#include +#include "./resize-inl.h" +#include "../operator_common.h" +#include "../elemwise_op_common.h" + +namespace mxnet { +namespace op { +namespace image { + +DMLC_REGISTER_PARAMETER(ResizeParam); + +NNVM_REGISTER_OP(_image_resize) +.describe(R"code(Resize an image NDArray of shape (H x W x C) or (N x H x W x C) +to the given size +Example: + .. 
code-block:: python + image = mx.nd.random.uniform(0, 255, (4, 2, 3)).astype(dtype=np.uint8) + mx.nd.image.resize(image, (3, 3)) + [[[124 111 197] + [158 80 155] + [193 50 112]] + + [[110 100 113] + [134 165 148] + [157 231 182]] + + [[202 176 134] + [174 191 149] + [147 207 164]]] + + image = mx.nd.random.uniform(0, 255, (2, 4, 2, 3)).astype(dtype=np.uint8) + mx.nd.image.resize(image, (2, 2)) + [[[[ 59 133 80] + [187 114 153]] + + [[ 38 142 39] + [207 131 124]]] + + + [[[117 125 136] + [191 166 150]] + + [[129 63 113] + [182 109 48]]]] + +)code" ADD_FILELINE) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", ResizeShape) +.set_attr("FInferType", ElemwiseType<1, 1>) +.set_attr("FCompute", Resize) +.set_attr("FGradient", ElemwiseGradUseNone{ "_copy" }) +.add_argument("data", "NDArray-or-Symbol", "The input.") +.add_arguments(ResizeParam::__FIELDS__()); + +} // namespace image +} // namespace op +} // namespace mxnet diff --git a/src/operator/image/resize.cu b/src/operator/image/resize.cu new file mode 100644 index 000000000000..f045f3b238ea --- /dev/null +++ b/src/operator/image/resize.cu @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * Copyright (c) 2019 by Contributors + * \file bilinear_resize.cu + * \brief bilinear resize operator + * \author Hang Zhang, Jake Lee +*/ +#include +#include "./resize-inl.h" +#include "../contrib/bilinear_resize-inl.cuh" + +namespace mxnet { +namespace op { +namespace image { + +using namespace mshadow; + +template +void ResizeImplCUDA(mshadow::Stream *s, + const T input, + const T output) { + int outputHeight; + int outputWidth; + int inputHeight; + int inputWidth; + mxnet::op::ImageLayout layout; + if (std::is_same>::value) { + layout = HWC; + outputHeight = output.size(0); + outputWidth = output.size(1); + inputHeight = input.size(0); + inputWidth = input.size(1); + } else { + layout = NHWC; + outputHeight = output.size(1); + outputWidth = output.size(2); + inputHeight = input.size(1); + inputWidth = input.size(2); + } + const AccReal rheight = (outputHeight > 1) ? (AccReal)(inputHeight - 1)/ + (outputHeight - 1) : AccReal(0); + const AccReal rwidth = (outputWidth > 1) ? 
(AccReal)(inputWidth - 1)/ + (outputWidth - 1) : AccReal(0); + const int num_kernels = outputHeight * outputWidth; + const int num_threads = getNumThreads(inputHeight * inputWidth, false); + dim3 blocks(static_cast(num_kernels / num_threads) + 1); + dim3 threads(num_threads); + cudaStream_t stream = mshadow::Stream::GetStream(s); + caffe_gpu_interp2_kernel + <<>>( + num_kernels, rheight, rwidth, input, output, layout); + MSHADOW_CUDA_POST_KERNEL_CHECK(caffe_gpu_interp2_kernel); +} + +NNVM_REGISTER_OP(_image_resize) +.set_attr("FCompute", Resize); + +} // namespace image +} // namespace op +} // namespace mxnet diff --git a/tests/python/gpu/test_gluon_transforms.py b/tests/python/gpu/test_gluon_transforms.py index c7afc762bd80..4a1017b538ac 100644 --- a/tests/python/gpu/test_gluon_transforms.py +++ b/tests/python/gpu/test_gluon_transforms.py @@ -69,4 +69,63 @@ def test_normalize(): # Invalid Input - Channel neither 1 or 3 invalid_data_in = nd.random.uniform(0, 1, (5, 4, 300, 300)) normalize_transformer = transforms.Normalize(mean=(0, 1, 2), std=(3, 2, 1)) - assertRaises(MXNetError, normalize_transformer, invalid_data_in) \ No newline at end of file + assertRaises(MXNetError, normalize_transformer, invalid_data_in) + + +@with_seed() +def test_resize(): + # Test with normal case 3D input float type + data_in_3d = nd.random.uniform(0, 255, (300, 300, 3)) + out_nd_3d = transforms.Resize((100, 100))(data_in_3d) + data_in_4d_nchw = nd.moveaxis(nd.expand_dims(data_in_3d, axis=0), 3, 1) + data_expected_3d = (nd.moveaxis(nd.contrib.BilinearResize2D(data_in_4d_nchw, 100, 100), 1, 3))[0] + assert_almost_equal(out_nd_3d.asnumpy(), data_expected_3d.asnumpy()) + + # Test with normal case 4D input float type + data_in_4d = nd.random.uniform(0, 255, (2, 300, 300, 3)) + out_nd_4d = transforms.Resize((100, 100))(data_in_4d) + data_in_4d_nchw = nd.moveaxis(data_in_4d, 3, 1) + data_expected_4d = nd.moveaxis(nd.contrib.BilinearResize2D(data_in_4d_nchw, 100, 100), 1, 3) + assert_almost_equal(out_nd_4d.asnumpy(), data_expected_4d.asnumpy()) + + # Test invalid interp + data_in_3d = nd.random.uniform(0, 255, (300, 300, 3)) + invalid_transform = transforms.Resize(-150, keep_ratio=False, interpolation=2) + assertRaises(MXNetError, invalid_transform, data_in_3d) + + # Credited to Hang Zhang + def py_bilinear_resize_nhwc(x, outputHeight, outputWidth): + batch, inputHeight, inputWidth, channel = x.shape + if outputHeight == inputHeight and outputWidth == inputWidth: + return x + y = np.empty([batch, outputHeight, outputWidth, channel]).astype('uint8') + rheight = 1.0 * (inputHeight - 1) / (outputHeight - 1) if outputHeight > 1 else 0.0 + rwidth = 1.0 * (inputWidth - 1) / (outputWidth - 1) if outputWidth > 1 else 0.0 + for h2 in range(outputHeight): + h1r = 1.0 * h2 * rheight + h1 = int(np.floor(h1r)) + h1lambda = h1r - h1 + h1p = 1 if h1 < (inputHeight - 1) else 0 + for w2 in range(outputWidth): + w1r = 1.0 * w2 * rwidth + w1 = int(np.floor(w1r)) + w1lambda = w1r - w1 + w1p = 1 if w1 < (inputWidth - 1) else 0 + for b in range(batch): + for c in range(channel): + y[b][h2][w2][c] = (1-h1lambda)*((1-w1lambda)*x[b][h1][w1][c] + \ + w1lambda*x[b][h1][w1+w1p][c]) + \ + h1lambda*((1-w1lambda)*x[b][h1+h1p][w1][c] + \ + w1lambda*x[b][h1+h1p][w1+w1p][c]) + return y + + # Test with normal case 3D input uint8 type + data_in_4d = nd.random.uniform(0, 255, (1, 300, 300, 3)).astype('uint8') + out_nd_3d = transforms.Resize((100, 100))(data_in_4d[0]) + assert_almost_equal(out_nd_3d.asnumpy(),
py_bilinear_resize_nhwc(data_in_4d.asnumpy(), 100, 100)[0], atol=1.0) + + # Test with normal case 4D input uint8 type + data_in_4d = nd.random.uniform(0, 255, (2, 300, 300, 3)).astype('uint8') + out_nd_4d = transforms.Resize((100, 100))(data_in_4d) + assert_almost_equal(out_nd_4d.asnumpy(), py_bilinear_resize_nhwc(data_in_4d.asnumpy(), 100, 100), atol=1.0) + diff --git a/tests/python/unittest/test_gluon_data_vision.py b/tests/python/unittest/test_gluon_data_vision.py index c83778fefc65..f10f0ae4fe19 100644 --- a/tests/python/unittest/test_gluon_data_vision.py +++ b/tests/python/unittest/test_gluon_data_vision.py @@ -17,7 +17,7 @@ from __future__ import print_function import mxnet as mx import mxnet.ndarray as nd -import numpy as np +from mxnet.base import MXNetError from mxnet import gluon from mxnet.base import MXNetError from mxnet.gluon.data.vision import transforms @@ -25,6 +25,7 @@ from mxnet.test_utils import almost_equal from common import assertRaises, setup_module, with_seed, teardown +import numpy as np @with_seed() def test_to_tensor(): @@ -68,6 +69,43 @@ def test_normalize(): assertRaises(MXNetError, normalize_transformer, invalid_data_in) +@with_seed() +def test_resize(): + def _test_resize_with_diff_type(dtype): + # test normal case + data_in = nd.random.uniform(0, 255, (300, 200, 3)).astype(dtype) + out_nd = transforms.Resize(200)(data_in) + data_expected = mx.image.imresize(data_in, 200, 200, 1) + assert_almost_equal(out_nd.asnumpy(), data_expected.asnumpy()) + # test 4D input + data_batch_in = nd.random.uniform(0, 255, (3, 300, 200, 3)).astype(dtype) + out_batch_nd = transforms.Resize(200)(data_batch_in) + for i in range(len(out_batch_nd)): + assert_almost_equal(mx.image.imresize(data_batch_in[i], 200, 200, 1).asnumpy(), + out_batch_nd[i].asnumpy()) + # test interp = 2 + out_nd = transforms.Resize(200, interpolation=2)(data_in) + data_expected = mx.image.imresize(data_in, 200, 200, 2) + assert_almost_equal(out_nd.asnumpy(), data_expected.asnumpy()) + # test height not equal to width + out_nd = transforms.Resize((200, 100))(data_in) + data_expected = mx.image.imresize(data_in, 200, 100, 1) + assert_almost_equal(out_nd.asnumpy(), data_expected.asnumpy()) + # test keep_ratio + out_nd = transforms.Resize(150, keep_ratio=True)(data_in) + data_expected = mx.image.imresize(data_in, 150, 225, 1) + assert_almost_equal(out_nd.asnumpy(), data_expected.asnumpy()) + # test size below zero + invalid_transform = transforms.Resize(-150, keep_ratio=True) + assertRaises(MXNetError, invalid_transform, data_in) + # test size tuple with more than two elements + invalid_transform = transforms.Resize((100, 100, 100), keep_ratio=True) + assertRaises(MXNetError, invalid_transform, data_in) + + for dtype in ['uint8', 'float32', 'float64']: + _test_resize_with_diff_type(dtype) + + @with_seed() def test_flip_left_right(): data_in = np.random.uniform(0, 255, (300, 300, 3)).astype(dtype=np.uint8)
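A minimal usage sketch of the new behavior follows. It is illustrative only and not part of the patch; it assumes an MXNet build that includes these changes, so that transforms.Resize is a HybridBlock that accepts both a single (H x W x C) image and an (N x H x W x C) batch.

import numpy as np
import mxnet as mx
from mxnet.gluon.data.vision import transforms

# Resize is now a HybridBlock, so it can be hybridized like any other Gluon block
# and applied to either a single HWC image or an NHWC batch in one call.
transformer = transforms.Resize((100, 100))
transformer.hybridize()

single = mx.nd.random.uniform(0, 255, (300, 300, 3)).astype(np.uint8)
batch = mx.nd.random.uniform(0, 255, (8, 300, 300, 3)).astype(np.uint8)

print(transformer(single).shape)  # expected: (100, 100, 3)
print(transformer(batch).shape)   # expected: (8, 100, 100, 3)

On CPU this dispatches to the OpenCV-backed resize, while on GPU it uses the bilinear kernel added in this patch, so small numerical differences between the two backends are expected.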