This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Export resize and support batch size #14014

Merged
Changes from all commits (46 commits)
998c7ef
add image resize operator and unit test
stu1130 Dec 11, 2018
4e63cad
refactor the resize operator and address lint issues
stu1130 Dec 11, 2018
feb661f
address comment and add doc
stu1130 Dec 11, 2018
f8ccd2f
assert size is more than 2
stu1130 Dec 11, 2018
20b9c0d
add test case of 4D input
stu1130 Dec 11, 2018
529121d
use ndarray datatype
stu1130 Dec 11, 2018
043584f
add inline to Shape
stu1130 Dec 12, 2018
ce0a447
add 4D input example
stu1130 Dec 12, 2018
26287b4
refactor the duplicate code and separate the resize from image_random
stu1130 Dec 12, 2018
df8bb81
clean up the code
stu1130 Dec 12, 2018
d807326
add resize implementation
stu1130 Dec 12, 2018
20ed3a3
delete the variable not used
stu1130 Dec 19, 2018
1db74be
refactor the code with structure and enum to make code more understan…
stu1130 Dec 19, 2018
20f288f
fix the lint
stu1130 Dec 19, 2018
e365f9f
address comments
stu1130 Dec 21, 2018
546a81f
address comment 1. add description 2. refactor unit test and add dtype
stu1130 Dec 21, 2018
c3ec485
update data type check
stu1130 Dec 21, 2018
22583ae
lint
stu1130 Dec 21, 2018
46b2e4d
move the common utility to image_utils
stu1130 Dec 21, 2018
e5464f1
add default value for keep_ratio
stu1130 Dec 21, 2018
d805b2e
change the operator doc
stu1130 Jan 11, 2019
c2f8d61
update the image utility function
stu1130 Jan 12, 2019
ecb3be6
fix lint
stu1130 Jan 12, 2019
1c3295d
use Hang's implementation for the image resize operator on GPU
stu1130 Jan 24, 2019
c9b2e80
update the check and doc
stu1130 Jan 29, 2019
d019090
refactor the caffe_gpu_interp2_kernel
stu1130 Jan 29, 2019
ddf8d4e
update doc and fix the cpu compile error
stu1130 Jan 29, 2019
daa2b4b
update the comment
stu1130 Jan 29, 2019
9672091
fix lint
stu1130 Jan 29, 2019
b31d19f
add unit test for gpu
stu1130 Jan 29, 2019
a5d55b3
address comments
stu1130 Jan 29, 2019
7a8fd4e
remove the crop and centercrop utility functions to make the PR clear
stu1130 Jan 29, 2019
acc88e2
fix the syntax error
stu1130 Jan 29, 2019
d07b272
delete the warning
stu1130 Jan 29, 2019
1d6a201
add unit test with 4D
stu1130 Jan 29, 2019
cd6d481
fix typo
stu1130 Jan 29, 2019
fa4674e
add more unit test
stu1130 Jan 30, 2019
2bbedba
fix unit test
stu1130 Jan 30, 2019
1b604c2
set atol = 1
stu1130 Jan 30, 2019
fd018b5
fix missing numpy import
stu1130 Jan 30, 2019
e5b1754
fix the unit test
stu1130 Jan 30, 2019
53c5ea1
delete test case
stu1130 Jan 30, 2019
009490b
fix unit test missing dependency
stu1130 Jan 30, 2019
c558359
fix error data type
stu1130 Jan 30, 2019
c9935f8
unify the style and add invalid interp
stu1130 Jan 30, 2019
f4b9d23
update the doc
stu1130 Jan 30, 2019
34 changes: 13 additions & 21 deletions python/mxnet/gluon/data/vision/transforms.py
@@ -262,8 +262,8 @@ def forward(self, x):
return image.center_crop(x, *self._args)[0]


class Resize(Block):
"""Resize an image to the given size.
class Resize(HybridBlock):
"""Resize an image or a batch of image NDArray to the given size.
Should be applied before `mxnet.gluon.data.vision.transforms.ToTensor`.

Parameters
@@ -276,44 +276,36 @@ class Resize(Block):
interpolation : int
Interpolation method for resizing. By default uses bilinear
interpolation. See OpenCV's resize function for available choices.
Note that on GPU, Resize uses the contrib.bilinearResize2D operator,
which only supports bilinear interpolation (interp=1). The result can
therefore differ slightly between GPU and CPU: OpenCV aligns pixel
centers, while bilinearResize2D uses an algorithm that aligns corners.


Inputs:
- **data**: input tensor with (Hi x Wi x C) shape.
- **data**: input tensor with (H x W x C) or (N x H x W x C) shape.

Outputs:
- **out**: output tensor with (H x W x C) shape.
- **out**: output tensor with (H x W x C) or (N x H x W x C) shape.

Examples
--------
>>> transformer = vision.transforms.Resize(size=(1000, 500))
>>> image = mx.nd.random.uniform(0, 255, (224, 224, 3)).astype(dtype=np.uint8)
>>> transformer(image)
<NDArray 500x1000x3 @cpu(0)>
>>> image = mx.nd.random.uniform(0, 255, (3, 224, 224, 3)).astype(dtype=np.uint8)
>>> transformer(image)
<NDArray 3x500x1000x3 @cpu(0)>
"""
def __init__(self, size, keep_ratio=False, interpolation=1):
super(Resize, self).__init__()
self._keep = keep_ratio
self._size = size
self._interpolation = interpolation

def forward(self, x):
if isinstance(self._size, numeric_types):
if not self._keep:
wsize = self._size
hsize = self._size
else:
h, w, _ = x.shape
if h > w:
wsize = self._size
hsize = int(h * wsize / w)
else:
hsize = self._size
wsize = int(w * hsize / h)
else:
wsize, hsize = self._size
return image.imresize(x, wsize, hsize, self._interpolation)

def hybrid_forward(self, F, x):
return F.image.resize(x, self._size, self._keep, self._interpolation)

class RandomFlipLeftRight(HybridBlock):
"""Randomly flip the input image left to right with a probability
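Editorial note: since `Resize` is now a `HybridBlock` wrapping the `image.resize` operator, it can be hybridized and applied to batched NHWC input directly. A minimal sketch of the new usage (shapes follow the docstring above; assumes an MXNet build that contains this PR):

```python
import numpy as np
import mxnet as mx
from mxnet.gluon.data.vision import transforms

# size is (width, height), matching the docstring example
transformer = transforms.Resize(size=(100, 50))
transformer.hybridize()  # legal now that Resize is a HybridBlock

# single image, H x W x C
img = mx.nd.random.uniform(0, 255, (224, 224, 3)).astype(np.uint8)
print(transformer(img).shape)    # (50, 100, 3)

# batch of images, N x H x W x C
batch = mx.nd.random.uniform(0, 255, (4, 224, 224, 3)).astype(np.uint8)
print(transformer(batch).shape)  # (4, 50, 100, 3)
```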
14 changes: 2 additions & 12 deletions src/io/image_io.cc
@@ -38,6 +38,7 @@
#include <cstring>

#include "../operator/elemwise_op_common.h"
#include "../operator/image/resize-inl.h"

#if MXNET_USE_OPENCV
#include <opencv2/opencv.hpp>
@@ -285,19 +286,8 @@ inline void Imresize(const nnvm::NodeAttrs& attrs,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
#if MXNET_USE_OPENCV
CHECK_NE(inputs[0].type_flag_, mshadow::kFloat16) << "imresize doesn't support fp16";
const int DTYPE[] = {CV_32F, CV_64F, -1, CV_8U, CV_32S};
int cv_type = CV_MAKETYPE(DTYPE[inputs[0].type_flag_], inputs[0].shape_[2]);
const auto& param = nnvm::get<ResizeParam>(attrs.parsed);
cv::Mat buf(inputs[0].shape_[0], inputs[0].shape_[1], cv_type, inputs[0].dptr_);
cv::Mat dst(outputs[0].shape_[0], outputs[0].shape_[1], cv_type, outputs[0].dptr_);
cv::resize(buf, dst, cv::Size(param.w, param.h), 0, 0, param.interp);
CHECK(!dst.empty());
CHECK_EQ(static_cast<void*>(dst.ptr()), outputs[0].dptr_);
#else
LOG(FATAL) << "Build with USE_OPENCV=1 for image io.";
#endif // MXNET_USE_OPENCV
op::image::ResizeImpl(inputs, outputs, param.h, param.w, param.interp);
}


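Editorial note: the OpenCV-specific body above is replaced by a call into the shared `op::image::ResizeImpl`, so the io-level `imresize` and the new `image.resize` operator now run the same code path on CPU. A quick consistency check as a sketch (it assumes the operator's NDArray frontend is exposed as `mx.nd.image.resize`, inferred from the `F.image.resize` call in transforms.py, with the `size`/`keep_ratio`/`interp` parameters this PR defines):

```python
import numpy as np
import mxnet as mx

img = mx.nd.random.uniform(0, 255, (48, 64, 3)).astype(np.uint8)

a = mx.image.imresize(img, 32, 24, interp=1)          # io-level API (this file)
b = mx.nd.image.resize(img, size=(32, 24), interp=1)  # new operator frontend
np.testing.assert_allclose(a.asnumpy(), b.asnumpy())  # same ResizeImpl underneath
```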
184 changes: 184 additions & 0 deletions src/operator/contrib/bilinear_resize-inl.cuh
@@ -0,0 +1,184 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2019 by Contributors
* \file bilinear_resize-inl.cuh
* \brief bilinear resize operator cuda implementation
* \author Hang Zhang, Jake Lee
*/

#ifndef MXNET_OPERATOR_CONTRIB_BILINEAR_RESIZE_CUH_
#define MXNET_OPERATOR_CONTRIB_BILINEAR_RESIZE_CUH_

#include <cuda_runtime_api.h>
#include <algorithm>

namespace mxnet {
namespace op {

using namespace mshadow;

enum ImageLayout {
HWC,
NHWC,
NCHW
};

template<typename In, typename Out>
struct ScalarConvert {
static __host__ __device__ __forceinline__ Out to(const In v) { return (Out) v; }
};

// The maximum number of threads in a block
static const unsigned MAX_BLOCK_SIZE = 512U;

// Number of threads in a block given an input size up to MAX_BLOCK_SIZE
static unsigned getNumThreads(int nElem, const bool smaller) {
unsigned threadSizes[5] = {32, 64, 128, 256, MAX_BLOCK_SIZE};
const int maxi = smaller ? 4 : 5;
for (int i = 0; i != maxi; ++i) {
if (static_cast<unsigned>(nElem) <= threadSizes[i]) {
return threadSizes[i];
}
}
return smaller ? (MAX_BLOCK_SIZE >> 1) : MAX_BLOCK_SIZE;
}
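Editorial note: `getNumThreads` rounds the work size up to the next power-of-two block size, capped at `MAX_BLOCK_SIZE`. The kernels below assign one thread per output pixel (the batch and channel loops run inside each thread), so a launch configuration looks like this sketch, which mirrors the launch site in `bilinear_resize.cu` (variable names are illustrative):

```cpp
// One thread per output pixel; batch/channel loops run inside the kernel.
const int num_kernels = height2 * width2;
const unsigned num_threads = getNumThreads(num_kernels, false);
dim3 blocks(static_cast<int>(num_kernels / num_threads) + 1);
dim3 threads(num_threads);
cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
caffe_gpu_interp2_kernel<gpu, DType, AccReal>
    <<<blocks, threads, 0, stream>>>(num_kernels, rheight, rwidth,
                                     idata, odata, layout);
```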

// caffe_gpu_interp2_kernel overload for Tensor<xpu, 3, DType>
template<typename xpu, typename Dtype, typename Acctype>
__global__ void caffe_gpu_interp2_kernel(const int n,
const Acctype rheight, const Acctype rwidth,
const Tensor<xpu, 3, Dtype> data1,
Tensor<xpu, 3, Dtype> data2,
ImageLayout layout) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
const int channels = data1.size(2);
const int height1 = data1.size(0);
const int width1 = data1.size(1);
const int height2 = data2.size(0);
const int width2 = data2.size(1);

if (index < n) {
const int w2 = index % width2; // 0:width2-1
const int h2 = index / width2; // 0:height2-1
// special case: just copy
if (height1 == height2 && width1 == width2) {
const int h1 = h2;
const int w1 = w2;
for (int c = 0; c < channels; ++c) {
const Dtype val = data1[h1][w1][c];
data2[h2][w2][c] = val;
}
return;
}
//
const Acctype h1r = rheight * h2;
const int h1 = h1r;
const int h1p = (h1 < height1 - 1) ? 1 : 0;
const Acctype h1lambda = h1r - h1;
const Acctype h0lambda = Acctype(1) - h1lambda;
//
const Acctype w1r = rwidth * w2;
const int w1 = w1r;
const int w1p = (w1 < width1 - 1) ? 1 : 0;
const Acctype w1lambda = w1r - w1;
const Acctype w0lambda = Acctype(1) - w1lambda;
for (int c = 0; c < channels; ++c) {
const Acctype val = h0lambda * (w0lambda * data1[h1][w1][c]
+ w1lambda * data1[h1][w1+w1p][c])
+ h1lambda * (w0lambda * data1[h1+h1p][w1][c]
+ w1lambda * data1[h1+h1p][w1+w1p][c]);
data2[h2][w2][c] = ScalarConvert<Acctype, Dtype>::to(val);
}
}
}
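Editorial note: in equation form, with scale factors $r_h$, $r_w$ and output pixel $(h_2, w_2)$, the kernel computes standard bilinear interpolation:

$$h_{1r} = r_h h_2, \qquad h_1 = \lfloor h_{1r} \rfloor, \qquad \lambda_h = h_{1r} - h_1,$$

and likewise $w_{1r}$, $w_1$, $\lambda_w$; then, for each channel $c$,

$$\mathrm{out}[h_2][w_2][c] = (1-\lambda_h)\big((1-\lambda_w)\,I[h_1][w_1][c] + \lambda_w\,I[h_1][w_1{+}p_w][c]\big) + \lambda_h\big((1-\lambda_w)\,I[h_1{+}p_h][w_1][c] + \lambda_w\,I[h_1{+}p_h][w_1{+}p_w][c]\big),$$

where $p_h, p_w \in \{0, 1\}$ are the `h1p`/`w1p` guards that clamp reads at the last row and column.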

// caffe_gpu_interp2_kernel overload for Tensor<xpu, 4, DType>
template<typename xpu, typename Dtype, typename Acctype>
__global__ void caffe_gpu_interp2_kernel(const int n,
const Acctype rheight, const Acctype rwidth,
const Tensor<xpu, 4, Dtype> data1,
Tensor<xpu, 4, Dtype> data2,
ImageLayout layout) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
int batch_size = data1.size(0);  // batch is dim 0 in both NHWC and NCHW
int channels = (layout == NHWC) ? data1.size(3) : data1.size(1);
int height1 = (layout == NHWC) ? data1.size(1) : data1.size(2);
int width1 = (layout == NHWC) ? data1.size(2) : data1.size(3);
int height2 = (layout == NHWC) ? data2.size(1) : data2.size(2);
int width2 = (layout == NHWC) ? data2.size(2) : data2.size(3);

if (index < n) {
const int w2 = index % width2; // 0:width2-1
const int h2 = index / width2; // 0:height2-1
// special case: just copy
if (height1 == height2 && width1 == width2) {
const int h1 = h2;
const int w1 = w2;

for (int n = 0; n < batch_size; ++n) {
for (int c = 0; c < channels; ++c) {
if (layout == NHWC) {
const Dtype val = data1[n][h1][w1][c];
data2[n][h2][w2][c] = val;
} else {
const Dtype val = data1[n][c][h1][w1];
data2[n][c][h2][w2] = val;
}
}
}
return;
}
//
const Acctype h1r = rheight * h2;
const int h1 = h1r;
const int h1p = (h1 < height1 - 1) ? 1 : 0;
const Acctype h1lambda = h1r - h1;
const Acctype h0lambda = Acctype(1) - h1lambda;
//
const Acctype w1r = rwidth * w2;
const int w1 = w1r;
const int w1p = (w1 < width1 - 1) ? 1 : 0;
const Acctype w1lambda = w1r - w1;
const Acctype w0lambda = Acctype(1) - w1lambda;

for (auto n = 0; n < batch_size; ++n) {
for (int c = 0; c < channels; ++c) {
if (layout == NHWC) {
const Acctype val = h0lambda * (w0lambda * data1[n][h1][w1][c]
+ w1lambda * data1[n][h1][w1+w1p][c])
+ h1lambda * (w0lambda * data1[n][h1+h1p][w1][c]
+ w1lambda * data1[n][h1+h1p][w1+w1p][c]);
data2[n][h2][w2][c] = ScalarConvert<Acctype, Dtype>::to(val);
} else {
const Acctype val = h0lambda * (w0lambda * data1[n][c][h1][w1]
+ w1lambda * data1[n][c][h1][w1+w1p])
+ h1lambda * (w0lambda * data1[n][c][h1+h1p][w1]
+ w1lambda * data1[n][c][h1+h1p][w1+w1p]);
data2[n][c][h2][w2] = ScalarConvert<Acctype, Dtype>::to(val);
}
}
}
}
}

} // namespace op
} // namespace mxnet

#endif // MXNET_OPERATOR_CONTRIB_BILINEAR_RESIZE_CUH_
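Editorial note: the scale factors `rheight`/`rwidth` are computed at the call sites rather than in this header. Since the mapping `h1r = rheight * h2` sends the last output row exactly onto the last input row only under the align-corners convention, they are presumably

$$r_h = \frac{H_{\mathrm{in}}-1}{H_{\mathrm{out}}-1}, \qquad r_w = \frac{W_{\mathrm{in}}-1}{W_{\mathrm{out}}-1} \qquad (H_{\mathrm{out}}, W_{\mathrm{out}} > 1),$$

which matches the Resize docstring's note that bilinearResize2D aligns corners. OpenCV instead samples at pixel centers with $r = H_{\mathrm{in}}/H_{\mathrm{out}}$ and a half-pixel offset, which explains the small CPU/GPU differences the docstring warns about and the `atol = 1` tolerance set in the unit tests.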
79 changes: 3 additions & 76 deletions src/operator/contrib/bilinear_resize.cu
@@ -25,86 +25,13 @@
#include <cuda_runtime_api.h>
#include <algorithm>
#include "bilinear_resize-inl.h"
#include "bilinear_resize-inl.cuh"

namespace mxnet {
namespace op {

using namespace mshadow;

template<typename In, typename Out>
struct ScalarConvert {
static __host__ __device__ __forceinline__ Out to(const In v) { return (Out) v; }
};


// The maximum number of threads in a block
static const unsigned MAX_BLOCK_SIZE = 512U;

// Number of threads in a block given an input size up to MAX_BLOCK_SIZE
static unsigned getNumThreads(int nElem, const bool smaller) {
unsigned threadSizes[5] = {32, 64, 128, 256, MAX_BLOCK_SIZE};
const int maxi = smaller ? 4 : 5;
for (int i = 0; i != maxi; ++i) {
if (static_cast<unsigned>(nElem) <= threadSizes[i]) {
return threadSizes[i];
}
}
return smaller ? (MAX_BLOCK_SIZE >> 1) : MAX_BLOCK_SIZE;
}

template<typename xpu, typename Dtype, typename Acctype>
__global__ void caffe_gpu_interp2_kernel(const int n,
const Acctype rheight, const Acctype rwidth,
const Tensor<xpu, 4, Dtype> data1,
Tensor<xpu, 4, Dtype> data2) {
int index = threadIdx.x + blockIdx.x * blockDim.x;
const int batchsize = data1.size(0);
const int channels = data1.size(1);
const int height1 = data1.size(2);
const int width1 = data1.size(3);
const int height2 = data2.size(2);
const int width2 = data2.size(3);

if (index < n) {
const int w2 = index % width2; // 0:width2-1
const int h2 = index / width2; // 0:height2-1
// special case: just copy
if (height1 == height2 && width1 == width2) {
const int h1 = h2;
const int w1 = w2;
for (int n = 0; n < batchsize ; n++) {
for (int c = 0; c < channels; ++c) {
const Dtype val = data1[n][c][h1][w1];
data2[n][c][h2][w2] = val;
}
}
return;
}
//
const Acctype h1r = rheight * h2;
const int h1 = h1r;
const int h1p = (h1 < height1 - 1) ? 1 : 0;
const Acctype h1lambda = h1r - h1;
const Acctype h0lambda = Acctype(1) - h1lambda;
//
const Acctype w1r = rwidth * w2;
const int w1 = w1r;
const int w1p = (w1 < width1 - 1) ? 1 : 0;
const Acctype w1lambda = w1r - w1;
const Acctype w0lambda = Acctype(1) - w1lambda;
//
for (int n = 0; n < batchsize ; n++) {
for (int c = 0; c < channels; ++c) {
const Acctype val = h0lambda * (w0lambda * data1[n][c][h1][w1]
+ w1lambda * data1[n][c][h1][w1+w1p])
+ h1lambda * (w0lambda * data1[n][c][h1+h1p][w1]
+ w1lambda * data1[n][c][h1+h1p][w1+w1p]);
data2[n][c][h2][w2] = ScalarConvert<Acctype, Dtype>::to(val);
}
}
}
}

// Backward (adjoint) operation 1 <- 2 (accumulates)
template<typename xpu, typename Dtype, typename Acctype>
__global__ void caffe_gpu_interp2_kernel_backward(const int n,
@@ -181,9 +108,10 @@ void SpatialUpSamplingBilinearUpdateOutput(mshadow::Stream<gpu> *s,
dim3 blocks(static_cast<int>(num_kernels / num_threads) + 1);
dim3 threads(num_threads);
cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
ImageLayout layout = NCHW;
caffe_gpu_interp2_kernel<xpu, DType, AccReal>
<<<blocks, threads , 0, stream>>>(
num_kernels, rheight, rwidth, idata, odata);
num_kernels, rheight, rwidth, idata, odata, layout);
MSHADOW_CUDA_POST_KERNEL_CHECK(SpatialUpSamplingBilinearUpdateOutput);
}
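Editorial note: the only change at this launch site is the new explicit `layout` argument. The existing NCHW path now routes through the shared kernel in `bilinear_resize-inl.cuh`, so the behavior of `_contrib_BilinearResize2D` is unchanged.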

@@ -215,6 +143,5 @@ NNVM_REGISTER_OP(_contrib_BilinearResize2D)

NNVM_REGISTER_OP(_backward_contrib_BilinearResize2D)
.set_attr<FCompute>("FCompute<gpu>", BilinearSampleOpBackward<gpu>);

} // namespace op
} // namespace mxnet