Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Export resize and support batch size (#14014)
Browse files Browse the repository at this point in the history
* add image resize operator and unit test

* refactor the resize operator and address lint issues

* address comment and add doc

* assert size is more than 2

* add test case of 4D input

* use ndarray datatype

* add inline to Shape

* add 4D input example

* refactor the duplicate code and separate the resize from image_random

* clean up the code

* add resize implementation

* delete the variable not used

* refactor the code with structure and enum to make code more understandable

* fix the lint

* address comments

* address comment 1. add description 2. refactor unit test and add dtype

* update data type check

* lint

* move the common utitlity to image_utils

* add default value for keep_ratio

* change the operator doc

* update the image utility function

* fix lint

* use Hang implementation to achieve image resize operator GPU

* update the check and doc

* refactor the caffe_gpu_interp2_kernel

* update doc and fix the cpu compile error

* update the comment

* fix lint

* add unit test for gpu

* address comments

* remove the crop and centercop utility function to make the PR clear

* fix the syntax error

* delete the warning

* add unit test with 4D

* fix typo

* add more unit test

* fix unit test

* set atol = 1

* fix missing numpy import

* fix the unit test

* delete test case

* fix unit test missing dependency

* fix error data type

* unify the style and add invalid interp

* update the doc
  • Loading branch information
stu1130 authored and sandeep-krishnamurthy committed Feb 1, 2019
1 parent 9a3e4a0 commit 2a4634b
Show file tree
Hide file tree
Showing 11 changed files with 744 additions and 113 deletions.
34 changes: 13 additions & 21 deletions python/mxnet/gluon/data/vision/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,8 +262,8 @@ def forward(self, x):
return image.center_crop(x, *self._args)[0]


class Resize(Block):
"""Resize an image to the given size.
class Resize(HybridBlock):
"""Resize an image or a batch of image NDArray to the given size.
Should be applied before `mxnet.gluon.data.vision.transforms.ToTensor`.
Parameters
Expand All @@ -276,44 +276,36 @@ class Resize(Block):
interpolation : int
Interpolation method for resizing. By default uses bilinear
interpolation. See OpenCV's resize function for available choices.
Note that the Resize on gpu use contrib.bilinearResize2D operator
which only support bilinear interpolation(1). The result would be slightly
different on gpu compared to cpu. OpenCV tend to align center while bilinearResize2D
use algorithm which aligns corner.
Inputs:
- **data**: input tensor with (Hi x Wi x C) shape.
- **data**: input tensor with (H x W x C) or (N x H x W x C) shape.
Outputs:
- **out**: output tensor with (H x W x C) shape.
- **out**: output tensor with (H x W x C) or (N x H x W x C) shape.
Examples
--------
>>> transformer = vision.transforms.Resize(size=(1000, 500))
>>> image = mx.nd.random.uniform(0, 255, (224, 224, 3)).astype(dtype=np.uint8)
>>> transformer(image)
<NDArray 500x1000x3 @cpu(0)>
>>> image = mx.nd.random.uniform(0, 255, (3, 224, 224, 3)).astype(dtype=np.uint8)
>>> transformer(image)
<NDArray 3x500x1000x3 @cpu(0)>
"""
def __init__(self, size, keep_ratio=False, interpolation=1):
super(Resize, self).__init__()
self._keep = keep_ratio
self._size = size
self._interpolation = interpolation

def forward(self, x):
if isinstance(self._size, numeric_types):
if not self._keep:
wsize = self._size
hsize = self._size
else:
h, w, _ = x.shape
if h > w:
wsize = self._size
hsize = int(h * wsize / w)
else:
hsize = self._size
wsize = int(w * hsize / h)
else:
wsize, hsize = self._size
return image.imresize(x, wsize, hsize, self._interpolation)

def hybrid_forward(self, F, x):
return F.image.resize(x, self._size, self._keep, self._interpolation)

class RandomFlipLeftRight(HybridBlock):
"""Randomly flip the input image left to right with a probability
Expand Down
14 changes: 2 additions & 12 deletions src/io/image_io.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#include <cstring>

#include "../operator/elemwise_op_common.h"
#include "../operator/image/resize-inl.h"

#if MXNET_USE_OPENCV
#include <opencv2/opencv.hpp>
Expand Down Expand Up @@ -285,19 +286,8 @@ inline void Imresize(const nnvm::NodeAttrs& attrs,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
#if MXNET_USE_OPENCV
CHECK_NE(inputs[0].type_flag_, mshadow::kFloat16) << "imresize doesn't support fp16";
const int DTYPE[] = {CV_32F, CV_64F, -1, CV_8U, CV_32S};
int cv_type = CV_MAKETYPE(DTYPE[inputs[0].type_flag_], inputs[0].shape_[2]);
const auto& param = nnvm::get<ResizeParam>(attrs.parsed);
cv::Mat buf(inputs[0].shape_[0], inputs[0].shape_[1], cv_type, inputs[0].dptr_);
cv::Mat dst(outputs[0].shape_[0], outputs[0].shape_[1], cv_type, outputs[0].dptr_);
cv::resize(buf, dst, cv::Size(param.w, param.h), 0, 0, param.interp);
CHECK(!dst.empty());
CHECK_EQ(static_cast<void*>(dst.ptr()), outputs[0].dptr_);
#else
LOG(FATAL) << "Build with USE_OPENCV=1 for image io.";
#endif // MXNET_USE_OPENCV
op::image::ResizeImpl(inputs, outputs, param.h, param.w, param.interp);
}


Expand Down
184 changes: 184 additions & 0 deletions src/operator/contrib/bilinear_resize-inl.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file bilinear_resize-inl.cuh
* \brief bilinear resize operator cuda implementation
* \author Hang Zhang, Jake Lee
*/

#ifndef MXNET_OPERATOR_CONTRIB_BILINEAR_RESIZE_CUH_
#define MXNET_OPERATOR_CONTRIB_BILINEAR_RESIZE_CUH_

#include <cuda_runtime_api.h>
#include <algorithm>

namespace mxnet {
namespace op {

using namespace mshadow;

enum ImageLayout {
HWC,
NHWC,
NCHW
};

template<typename In, typename Out>
struct ScalarConvert {
static __host__ __device__ __forceinline__ Out to(const In v) { return (Out) v; }
};

// The maximum number of threads in a block
static const unsigned MAX_BLOCK_SIZE = 512U;

// Number of threads in a block given an input size up to MAX_BLOCK_SIZE
static unsigned getNumThreads(int nElem, const bool smaller) {
unsigned threadSizes[5] = {32, 64, 128, 256, MAX_BLOCK_SIZE};
const int maxi = smaller ? 4 : 5;
for (int i = 0; i != maxi; ++i) {
if (static_cast<unsigned>(nElem) <= threadSizes[i]) {
return threadSizes[i];
}
}
return smaller ? (MAX_BLOCK_SIZE >> 1) : MAX_BLOCK_SIZE;
}

// caffe_gpu_interp2_kernel overloading with Tensor<xpu, 3, DType>
template<typename xpu, typename Dtype, typename Acctype>
__global__ void caffe_gpu_interp2_kernel(const int n,
const Acctype rheight, const Acctype rwidth,
const Tensor<xpu, 3, Dtype> data1,
Tensor<xpu, 3, Dtype> data2,
ImageLayout layout) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
const int channels = data1.size(2);
const int height1 = data1.size(0);
const int width1 = data1.size(1);
const int height2 = data2.size(0);
const int width2 = data2.size(1);

if (index < n) {
const int w2 = index % width2; // 0:width2-1
const int h2 = index / width2; // 0:height2-1
// special case: just copy
if (height1 == height2 && width1 == width2) {
const int h1 = h2;
const int w1 = w2;
for (int c = 0; c < channels; ++c) {
const Dtype val = data1[h1][w1][c];
data2[h2][w2][c] = val;
}
return;
}
//
const Acctype h1r = rheight * h2;
const int h1 = h1r;
const int h1p = (h1 < height1 - 1) ? 1 : 0;
const Acctype h1lambda = h1r - h1;
const Acctype h0lambda = Acctype(1) - h1lambda;
//
const Acctype w1r = rwidth * w2;
const int w1 = w1r;
const int w1p = (w1 < width1 - 1) ? 1 : 0;
const Acctype w1lambda = w1r - w1;
const Acctype w0lambda = Acctype(1) - w1lambda;
for (int c = 0; c < channels; ++c) {
const Acctype val = h0lambda * (w0lambda * data1[h1][w1][c]
+ w1lambda * data1[h1][w1+w1p][c])
+ h1lambda * (w0lambda * data1[h1+h1p][w1][c]
+ w1lambda * data1[h1+h1p][w1+w1p][c]);
data2[h2][w2][c] = ScalarConvert<Acctype, Dtype>::to(val);
}
}
}

// caffe_gpu_interp2_kernel overloading with Tensor<xpu, 4, DType>
template<typename xpu, typename Dtype, typename Acctype>
__global__ void caffe_gpu_interp2_kernel(const int n,
const Acctype rheight, const Acctype rwidth,
const Tensor<xpu, 4, Dtype> data1,
Tensor<xpu, 4, Dtype> data2,
ImageLayout layout) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
int batch_size = (layout == NHWC) ? data1.size(0) : data1.size(0);
int channels = (layout == NHWC) ? data1.size(3) : data1.size(1);
int height1 = (layout == NHWC) ? data1.size(1) : data1.size(2);
int width1 = (layout == NHWC) ? data1.size(2) : data1.size(3);
int height2 = (layout == NHWC) ? data2.size(1) : data2.size(2);
int width2 = (layout == NHWC) ? data2.size(2): data2.size(3);

if (index < n) {
const int w2 = index % width2; // 0:width2-1
const int h2 = index / width2; // 0:height2-1
// special case: just copy
if (height1 == height2 && width1 == width2) {
const int h1 = h2;
const int w1 = w2;

for (int n = 0; n < batch_size; ++n) {
for (int c = 0; c < channels; ++c) {
if (layout == NHWC) {
const Dtype val = data1[n][h1][w1][c];
data2[n][h2][w2][c] = val;
} else {
const Dtype val = data1[n][c][h1][w1];
data2[n][c][h2][w2] = val;
}
}
}
return;
}
//
const Acctype h1r = rheight * h2;
const int h1 = h1r;
const int h1p = (h1 < height1 - 1) ? 1 : 0;
const Acctype h1lambda = h1r - h1;
const Acctype h0lambda = Acctype(1) - h1lambda;
//
const Acctype w1r = rwidth * w2;
const int w1 = w1r;
const int w1p = (w1 < width1 - 1) ? 1 : 0;
const Acctype w1lambda = w1r - w1;
const Acctype w0lambda = Acctype(1) - w1lambda;

for (auto n = 0; n < batch_size; ++n) {
for (int c = 0; c < channels; ++c) {
if (layout == NHWC) {
const Acctype val = h0lambda * (w0lambda * data1[n][h1][w1][c]
+ w1lambda * data1[n][h1][w1+w1p][c])
+ h1lambda * (w0lambda * data1[n][h1+h1p][w1][c]
+ w1lambda * data1[n][h1+h1p][w1+w1p][c]);
data2[n][h2][w2][c] = ScalarConvert<Acctype, Dtype>::to(val);
} else {
const Acctype val = h0lambda * (w0lambda * data1[n][c][h1][w1]
+ w1lambda * data1[n][c][h1][w1+w1p])
+ h1lambda * (w0lambda * data1[n][c][h1+h1p][w1]
+ w1lambda * data1[n][c][h1+h1p][w1+w1p]);
data2[n][c][h2][w2] = ScalarConvert<Acctype, Dtype>::to(val);
}
}
}
}
}

} // namespace op
} // namespace mxnet

#endif // MXNET_OPERATOR_CONTRIB_BILINEAR_RESIZE_CUH_
79 changes: 3 additions & 76 deletions src/operator/contrib/bilinear_resize.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,86 +25,13 @@
#include <cuda_runtime_api.h>
#include <algorithm>
#include "bilinear_resize-inl.h"
#include "bilinear_resize-inl.cuh"

namespace mxnet {
namespace op {

using namespace mshadow;

template<typename In, typename Out>
struct ScalarConvert {
static __host__ __device__ __forceinline__ Out to(const In v) { return (Out) v; }
};


// The maximum number of threads in a block
static const unsigned MAX_BLOCK_SIZE = 512U;

// Number of threads in a block given an input size up to MAX_BLOCK_SIZE
static unsigned getNumThreads(int nElem, const bool smaller) {
unsigned threadSizes[5] = {32, 64, 128, 256, MAX_BLOCK_SIZE};
const int maxi = smaller ? 4 : 5;
for (int i = 0; i != maxi; ++i) {
if (static_cast<unsigned>(nElem) <= threadSizes[i]) {
return threadSizes[i];
}
}
return smaller ? (MAX_BLOCK_SIZE >> 1) : MAX_BLOCK_SIZE;
}

template<typename xpu, typename Dtype, typename Acctype>
__global__ void caffe_gpu_interp2_kernel(const int n,
const Acctype rheight, const Acctype rwidth,
const Tensor<xpu, 4, Dtype> data1,
Tensor<xpu, 4, Dtype> data2) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
const int batchsize = data1.size(0);
const int channels = data1.size(1);
const int height1 = data1.size(2);
const int width1 = data1.size(3);
const int height2 = data2.size(2);
const int width2 = data2.size(3);

if (index < n) {
const int w2 = index % width2; // 0:width2-1
const int h2 = index / width2; // 0:height2-1
// special case: just copy
if (height1 == height2 && width1 == width2) {
const int h1 = h2;
const int w1 = w2;
for (int n = 0; n < batchsize ; n++) {
for (int c = 0; c < channels; ++c) {
const Dtype val = data1[n][c][h1][w1];
data2[n][c][h2][w2] = val;
}
}
return;
}
//
const Acctype h1r = rheight * h2;
const int h1 = h1r;
const int h1p = (h1 < height1 - 1) ? 1 : 0;
const Acctype h1lambda = h1r - h1;
const Acctype h0lambda = Acctype(1) - h1lambda;
//
const Acctype w1r = rwidth * w2;
const int w1 = w1r;
const int w1p = (w1 < width1 - 1) ? 1 : 0;
const Acctype w1lambda = w1r - w1;
const Acctype w0lambda = Acctype(1) - w1lambda;
//
for (int n = 0; n < batchsize ; n++) {
for (int c = 0; c < channels; ++c) {
const Acctype val = h0lambda * (w0lambda * data1[n][c][h1][w1]
+ w1lambda * data1[n][c][h1][w1+w1p])
+ h1lambda * (w0lambda * data1[n][c][h1+h1p][w1]
+ w1lambda * data1[n][c][h1+h1p][w1+w1p]);
data2[n][c][h2][w2] = ScalarConvert<Acctype, Dtype>::to(val);
}
}
}
}

// Backward (adjoint) operation 1 <- 2 (accumulates)
template<typename xpu, typename Dtype, typename Acctype>
__global__ void caffe_gpu_interp2_kernel_backward(const int n,
Expand Down Expand Up @@ -181,9 +108,10 @@ void SpatialUpSamplingBilinearUpdateOutput(mshadow::Stream<gpu> *s,
dim3 blocks(static_cast<int>(num_kernels / num_threads) + 1);
dim3 threads(num_threads);
cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
ImageLayout layout = NCHW;
caffe_gpu_interp2_kernel<xpu, DType, AccReal>
<<<blocks, threads , 0, stream>>>(
num_kernels, rheight, rwidth, idata, odata);
num_kernels, rheight, rwidth, idata, odata, layout);
MSHADOW_CUDA_POST_KERNEL_CHECK(SpatialUpSamplingBilinearUpdateOutput);
}

Expand Down Expand Up @@ -215,6 +143,5 @@ NNVM_REGISTER_OP(_contrib_BilinearResize2D)

NNVM_REGISTER_OP(_backward_contrib_BilinearResize2D)
.set_attr<FCompute>("FCompute<gpu>", BilinearSampleOpBackward<gpu>);

} // namespace op
} // namespace mxnet
Loading

0 comments on commit 2a4634b

Please sign in to comment.