
Performance improvement in ToTensor GPU Kernel #14099

Merged
100 changes: 70 additions & 30 deletions src/operator/image/image_random-inl.h
@@ -43,8 +43,18 @@ namespace mxnet {
namespace op {
namespace image {

// There are no parameters for this operator.
// Hence, no parameter registration.
using namespace mshadow;

#if MXNET_USE_CUDA
// NOTE: The generic per-element kernel launch/map was extremely costly on GPU.
// Hence, we use dedicated CUDA kernels for these operators.
template<typename DType, typename T1, typename T2>
void ToTensorImplCUDA(mshadow::Stream<gpu> *s,
const T1 input,
const T2 output,
const int req,
const float normalize_factor);
#endif // MXNET_USE_CUDA

// Shape and Type inference for image to tensor operator
inline bool ToTensorShape(const nnvm::NodeAttrs& attrs,
@@ -78,37 +88,39 @@ inline bool ToTensorType(const nnvm::NodeAttrs& attrs,
}

// Operator Implementation

template<int req>
struct totensor_forward {
template<typename DType>
MSHADOW_XINLINE static void Map(uint32_t c, float* out_data, const DType* in_data,
const int length, const int channel, const int step,
const float normalize_factor = 255.0f) {
#pragma omp parallel for
template<typename DType, int req>
inline void ToTensor(float* out_data, const DType* in_data,
const int length,
const int channels,
const float normalize_factor,
const int step) {
// Microsoft Visual C++ compiler does not support omp collapse
#ifdef _MSC_VER
#pragma omp parallel for
#else
#pragma omp parallel for collapse(2)
#endif // _MSC_VER
for (int c = 0; c < channels; ++c) {
for (int i = 0; i < length; ++i) {
KERNEL_ASSIGN(out_data[step + c*length + i], req,
(in_data[step + i*channel + c]) / normalize_factor);
(in_data[step + i*channels + c]) / normalize_factor);
}
}
};

template<typename xpu>
void ToTensorImpl(const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<TBlob> &outputs,
const std::vector<OpReqType> &req,
const int length,
const uint32_t channel,
const int step = 0) {
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
}
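The `_MSC_VER` guard above is needed because MSVC's OpenMP support (historically OpenMP 2.0) lacks the `collapse` clause. Below is a standalone sketch, not part of this PR, showing what `collapse(2)` does for this loop nest: it fuses the channel and pixel loops into one iteration space so OpenMP can balance work across threads even when `channels` is as small as 3.

```cpp
// Standalone sketch (not from the PR): collapse(2) fuses the c and i loops
// into a single channels*length iteration space for OpenMP scheduling.
#include <cstdio>
#include <vector>

int main() {
  const int channels = 3, length = 8;
  const float normalize_factor = 255.0f;
  std::vector<float> in(channels * length, 255.0f);   // flattened HWC input
  std::vector<float> out(channels * length, 0.0f);    // flattened CHW output

#if defined(_MSC_VER)
  #pragma omp parallel for                 // MSVC: no collapse clause
#else
  #pragma omp parallel for collapse(2)
#endif
  for (int c = 0; c < channels; ++c) {
    for (int i = 0; i < length; ++i) {
      out[c * length + i] = in[i * channels + c] / normalize_factor;
    }
  }
  std::printf("out[0] = %f\n", out[0]);    // prints 1.000000
  return 0;
}
```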

inline void ToTensorImpl(const std::vector<TBlob> &inputs,
const std::vector<TBlob> &outputs,
const std::vector<OpReqType> &req,
const int length,
const int channel,
const float normalize_factor,
const int step) {
MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
float* output = outputs[0].dptr<float>();
DType* input = inputs[0].dptr<DType>();
mxnet_op::Kernel<totensor_forward<req_type>, xpu>::Launch(
s, channel, output, input, length, channel, step);
ToTensor<DType, req_type>(output, input, length, channel,
normalize_factor, step);
});
});
}
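`ToTensorImpl` resolves the input dtype and the write request at compile time via `MSHADOW_TYPE_SWITCH` and `MXNET_ASSIGN_REQ_SWITCH`, and `KERNEL_ASSIGN` then performs the store according to the request. Below is a conceptual, self-contained sketch of that last step (paraphrased, not the actual macro from mxnet_op.h; `to_tensor` itself only accepts `kWriteTo`, as checked in the forward function).

```cpp
// Conceptual sketch: the request type selects between skipping, overwriting,
// and accumulating into the output. The enum mirrors mxnet::OpReqType.
#include <cstdio>

enum OpReq { kNullOp, kWriteTo, kWriteInplace, kAddTo };

template<int req>
inline void AssignLike(float* out, float val) {
  if (req == kNullOp) {
    return;               // leave the output untouched
  } else if (req == kAddTo) {
    *out += val;          // accumulate into the output
  } else {
    *out = val;           // kWriteTo / kWriteInplace overwrite
  }
}

int main() {
  float y = 1.0f;
  AssignLike<kAddTo>(&y, 0.5f);    // y == 1.5f
  AssignLike<kWriteTo>(&y, 2.0f);  // y == 2.0f
  std::printf("%f\n", y);
  return 0;
}
```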
@@ -123,24 +135,52 @@ void ToTensorOpForward(const nnvm::NodeAttrs &attrs,
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);

// We do not use a temporary buffer when performing the operation.
// Hence, this check is necessary.
CHECK_EQ(req[0], kWriteTo)
<< "`to_tensor` does not support inplace updates";

// 3D Input - (h, w, c)
if (inputs[0].ndim() == 3) {
const float normalize_factor = 255.0f;

if (std::is_same<xpu, gpu>::value) {
#if MXNET_USE_CUDA
mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
if (inputs[0].ndim() == 3) {
Tensor<gpu, 3, DType> input = inputs[0].get<gpu, 3, DType>(s);
Tensor<gpu, 3, float> output = outputs[0].get<gpu, 3, float>(s);
ToTensorImplCUDA<DType, Tensor<gpu, 3, DType>, Tensor<gpu, 3, float>>
(s, input, output, req_type, normalize_factor);
} else {
Tensor<gpu, 4, DType> input = inputs[0].get<gpu, 4, DType>(s);
Tensor<gpu, 4, float> output = outputs[0].get<gpu, 4, float>(s);
ToTensorImplCUDA<DType, Tensor<gpu, 4, DType>, Tensor<gpu, 4, float>>
(s, input, output, req_type, normalize_factor);
}
});
});
#else
LOG(FATAL) << "Compile with USE_CUDA=1 to use ToTensor operator on GPU.";
#endif // MXNET_USE_CUDA
} else if (inputs[0].ndim() == 3) {
// 3D Input - (h, w, c)
const int length = inputs[0].shape_[0] * inputs[0].shape_[1];
const uint32_t channel = inputs[0].shape_[2];
ToTensorImpl<xpu>(ctx, inputs, outputs, req, length, channel);
const int channel = static_cast<int>(inputs[0].shape_[2]);
const int step = 0;
ToTensorImpl(inputs, outputs, req, length,
channel, normalize_factor, step);
} else if (inputs[0].ndim() == 4) {
// 4D input (n, h, w, c)
const int batch_size = inputs[0].shape_[0];
const int length = inputs[0].shape_[1] * inputs[0].shape_[2];
const uint32_t channel = inputs[0].shape_[3];
const int channel = static_cast<int>(inputs[0].shape_[3]);
const int step = channel * length;

#pragma omp parallel for
for (auto n = 0; n < batch_size; ++n) {
ToTensorImpl<xpu>(ctx, inputs, outputs, req, length, channel, n*step);
ToTensorImpl(inputs, outputs, req, length, channel,
normalize_factor, n*step);
}
}
}
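For the batched path above, `length` is the number of pixels per image, `step = channel * length` is the number of elements per image, and image `n` starts at offset `n*step`. A small standalone sketch, not from the PR, of the resulting NHWC-to-NCHW index arithmetic:

```cpp
// Standalone sketch of the index arithmetic used by the CPU path:
// element (n, h, w, c) of the NHWC input maps to (n, c, h, w) of the
// NCHW output, both expressed relative to the image offset n*step.
#include <cstdio>

int main() {
  const int N = 2, H = 4, W = 5, C = 3;
  const int length = H * W;        // pixels per image
  const int step   = C * length;   // elements per image

  const int n = 1, h = 2, w = 3, c = 1;
  const int i = h * W + w;         // pixel index within the image

  const int in_idx  = n * step + i * C + c;       // read from NHWC input
  const int out_idx = n * step + c * length + i;  // write to NCHW output

  std::printf("image %d of %d: in_idx=%d out_idx=%d\n", n, N, in_idx, out_idx);
  return 0;
}
```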
83 changes: 83 additions & 0 deletions src/operator/image/image_random.cu
@@ -21,13 +21,96 @@
* \file image_random.cu
* \brief GPU Implementation of image transformation operators
*/
#include <cuda_runtime_api.h>
#include "./image_random-inl.h"
#include "../elemwise_op_common.h"

namespace mxnet {
namespace op {
namespace image {

using namespace mshadow;

// ToTensor Kernel for 3D input
template<typename xpu, typename Dtype>
__global__ void ToTensorCudaKernel(const Tensor<xpu, 3, Dtype> input,
const Tensor<xpu, 3, float> output,
const int req,
const int N,
const int H,
const int W,
const int C,
const float normalize_factor) {
// We process one image per thread block.
// In the 3D case there is only a single block, so blockIdx.x is not used.
for (int c = 0; c < C; ++c) {
for (int h = threadIdx.y; h < H; h += blockDim.y) {
for (int w = threadIdx.x; w < W; w += blockDim.x) {
KERNEL_ASSIGN(output[c][h][w], req,
input[h][w][c] / normalize_factor);
}
}
}
}

// ToTensor Kernel for 4D input
template<typename xpu, typename Dtype>
__global__ void ToTensorCudaKernel(const Tensor<xpu, 4, Dtype> input,
const Tensor<xpu, 4, float> output,
const int req,
const int N,
const int H,
const int W,
const int C,
const float normalize_factor) {
// We process one image per thread block.
const int n = blockIdx.x;

for (int c = 0; c < C; ++c) {
for (int h = threadIdx.y; h < H; h += blockDim.y) {
for (int w = threadIdx.x; w < W; w += blockDim.x) {
KERNEL_ASSIGN(output[n][c][h][w], req,
input[n][h][w][c] / normalize_factor);
}
}
}
}
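Both kernels assign one image per thread block and stride `threadIdx.y`/`threadIdx.x` over the image rows and columns. A hypothetical host-side reference, not part of the PR, that computes the same transform and can be used to validate the kernel output on small tensors:

```cpp
// Hypothetical reference for what the kernels compute:
// output[n][c][h][w] = input[n][h][w][c] / normalize_factor.
#include <cassert>
#include <vector>

std::vector<float> ToTensorReference(const std::vector<float>& in,
                                     int N, int H, int W, int C,
                                     float normalize_factor = 255.0f) {
  std::vector<float> out(in.size());
  for (int n = 0; n < N; ++n)
    for (int h = 0; h < H; ++h)
      for (int w = 0; w < W; ++w)
        for (int c = 0; c < C; ++c)
          out[((n * C + c) * H + h) * W + w] =
              in[((n * H + h) * W + w) * C + c] / normalize_factor;
  return out;
}

int main() {
  const int N = 1, H = 2, W = 2, C = 3;
  std::vector<float> img(N * H * W * C, 127.5f);
  std::vector<float> t = ToTensorReference(img, N, H, W, C);
  assert(t[0] == 0.5f);  // 127.5f / 255.0f is exactly 0.5f
  return 0;
}
```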

template<typename DType, typename T1, typename T2>
void ToTensorImplCUDA(mshadow::Stream<gpu> *s,
const T1 input,
const T2 output,
const int req,
const float normalize_factor) {
int blocks, H, W, C, N;
cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
if (std::is_same<T1, Tensor<gpu, 3, DType>>::value) {
// 3D Input - (H, W, C)
N = 0;
H = input.size(0);
W = input.size(1);
C = input.size(2);
blocks = 1;
} else {
// 4D Input - (N, H, W, C)
N = input.size(0);
H = input.size(1);
W = input.size(2);
C = input.size(3);
blocks = N > 0 ? N : 1;
}
// One block per image.
// A (32, 32) thread block is sufficient here: the per-element computation is
// minimal, and the overhead of CUDA setting up these threads is negligible.
ToTensorCudaKernel<gpu, DType>
<<<blocks, dim3(32, 32), 0, stream>>>(input, output,
req, N, H, W, C, normalize_factor);
MSHADOW_CUDA_POST_KERNEL_CHECK(ToTensorCudaKernel);
}
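The launch uses one block per image with a fixed `dim3(32, 32)` thread block. If one wanted to verify that choice, a sketch along these lines (placeholder kernel and arguments, not part of the PR) times a candidate configuration with CUDA events:

```cuda
// Sketch (not from the PR): time a candidate launch configuration with CUDA
// events. SomeKernel and its arguments are placeholders.
#include <cuda_runtime_api.h>

float TimeLaunchMs(cudaStream_t stream, dim3 grid, dim3 block) {
  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);

  cudaEventRecord(start, stream);
  // SomeKernel<<<grid, block, 0, stream>>>(/* args */);
  cudaEventRecord(stop, stream);
  cudaEventSynchronize(stop);

  float ms = 0.0f;
  cudaEventElapsedTime(&ms, start, stop);
  cudaEventDestroy(start);
  cudaEventDestroy(stop);
  return ms;
}
```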

NNVM_REGISTER_OP(_image_to_tensor)
.set_attr<FCompute>("FCompute<gpu>", ToTensorOpForward<gpu>);
