From ab5a0cf6cf87f046d98397edbced251fe6173d6c Mon Sep 17 00:00:00 2001 From: Sandeep Krishnamurthy Date: Mon, 11 Feb 2019 15:44:17 -0800 Subject: [PATCH] Performance improvement in ToTensor GPU Kernel (#14099) * CPU implementation without Kernel launch/map * Optimal CUDA support for 3D ToTensor operator * Add CUDA kernel for 4D inputs * Fix failing CPU tests for totensor * disable warning on windows * try fix in instance norm windows build failure * Guard omp parallel collapse for windows * Remove warning supression to check if it is ok * fix lint issues * Address code review comments --- src/operator/image/image_random-inl.h | 100 ++++++++++++++++++-------- src/operator/image/image_random.cu | 83 +++++++++++++++++++++ 2 files changed, 153 insertions(+), 30 deletions(-) diff --git a/src/operator/image/image_random-inl.h b/src/operator/image/image_random-inl.h index 448016341f21..392fff4dbf81 100644 --- a/src/operator/image/image_random-inl.h +++ b/src/operator/image/image_random-inl.h @@ -43,8 +43,18 @@ namespace mxnet { namespace op { namespace image { -// There are no parameters for this operator. -// Hence, no arameter registration. +using namespace mshadow; + +#if MXNET_USE_CUDA +// NOTE: Kernel launch/map was extremely costly. +// Hence, we use separate CUDA kernels for these operators. +template +void ToTensorImplCUDA(mshadow::Stream *s, + const T1 input, + const T2 output, + const int req, + const float normalize_factor); +#endif // MXNET_USE_CUDA // Shape and Type inference for image to tensor operator inline bool ToTensorShape(const nnvm::NodeAttrs& attrs, @@ -78,37 +88,39 @@ inline bool ToTensorType(const nnvm::NodeAttrs& attrs, } // Operator Implementation - -template -struct totensor_forward { - template - MSHADOW_XINLINE static void Map(uint32_t c, float* out_data, const DType* in_data, - const int length, const int channel, const int step, - const float normalize_factor = 255.0f) { - #pragma omp parallel for +template +inline void ToTensor(float* out_data, const DType* in_data, + const int length, + const int channels, + const float normalize_factor, + const int step) { + // Microsoft Visual C++ compiler does not support omp collapse + #ifdef _MSC_VER + #pragma omp parallel for + #else + #pragma omp parallel for collapse(2) + #endif // _MSC_VER + for (int c = 0; c < channels; ++c) { for (int i = 0; i < length; ++i) { KERNEL_ASSIGN(out_data[step + c*length + i], req, - (in_data[step + i*channel + c]) / normalize_factor); + (in_data[step + i*channels + c]) / normalize_factor); } } -}; - -template -void ToTensorImpl(const OpContext &ctx, - const std::vector &inputs, - const std::vector &outputs, - const std::vector &req, - const int length, - const uint32_t channel, - const int step = 0) { - mshadow::Stream *s = ctx.get_stream(); +} +inline void ToTensorImpl(const std::vector &inputs, + const std::vector &outputs, + const std::vector &req, + const int length, + const int channel, + const float normalize_factor, + const int step) { MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, { MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { float* output = outputs[0].dptr(); DType* input = inputs[0].dptr(); - mxnet_op::Kernel, xpu>::Launch( - s, channel, output, input, length, channel, step); + ToTensor(output, input, length, channel, + normalize_factor, step); }); }); } @@ -123,24 +135,52 @@ void ToTensorOpForward(const nnvm::NodeAttrs &attrs, CHECK_EQ(outputs.size(), 1U); CHECK_EQ(req.size(), 1U); + // We do not use temp buffer when performance the operation. + // Hence, this check is necessary. CHECK_EQ(req[0], kWriteTo) << "`to_tensor` does not support inplace updates"; - // 3D Input - (h, w, c) - if (inputs[0].ndim() == 3) { + const float normalize_factor = 255.0f; + + if (std::is_same::value) { + #if MXNET_USE_CUDA + mshadow::Stream *s = ctx.get_stream(); + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, { + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + if (inputs[0].ndim() == 3) { + Tensor input = inputs[0].get(s); + Tensor output = outputs[0].get(s); + ToTensorImplCUDA, Tensor> + (s, input, output, req_type, normalize_factor); + } else { + Tensor input = inputs[0].get(s); + Tensor output = outputs[0].get(s); + ToTensorImplCUDA, Tensor> + (s, input, output, req_type, normalize_factor); + } + }); + }); + #else + LOG(FATAL) << "Compile with USE_CUDA=1 to use ToTensor operator on GPU."; + #endif // MXNET_USE_CUDA + } else if (inputs[0].ndim() == 3) { + // 3D Input - (h, w, c) const int length = inputs[0].shape_[0] * inputs[0].shape_[1]; - const uint32_t channel = inputs[0].shape_[2]; - ToTensorImpl(ctx, inputs, outputs, req, length, channel); + const int channel = static_cast(inputs[0].shape_[2]); + const int step = 0; + ToTensorImpl(inputs, outputs, req, length, + channel, normalize_factor, step); } else if (inputs[0].ndim() == 4) { // 4D input (n, h, w, c) const int batch_size = inputs[0].shape_[0]; const int length = inputs[0].shape_[1] * inputs[0].shape_[2]; - const uint32_t channel = inputs[0].shape_[3]; + const int channel = static_cast(inputs[0].shape_[3]); const int step = channel * length; #pragma omp parallel for for (auto n = 0; n < batch_size; ++n) { - ToTensorImpl(ctx, inputs, outputs, req, length, channel, n*step); + ToTensorImpl(inputs, outputs, req, length, channel, + normalize_factor, n*step); } } } diff --git a/src/operator/image/image_random.cu b/src/operator/image/image_random.cu index 5f9aff27e85b..6fe53832a89e 100644 --- a/src/operator/image/image_random.cu +++ b/src/operator/image/image_random.cu @@ -21,6 +21,7 @@ * \file image_random.cu * \brief GPU Implementation of image transformation operators */ +#include #include "./image_random-inl.h" #include "../elemwise_op_common.h" @@ -28,6 +29,88 @@ namespace mxnet { namespace op { namespace image { +using namespace mshadow; + +// ToTensor Kernel for 3D input +template +__global__ void ToTensorCudaKernel(const Tensor input, + const Tensor output, + const int req, + const int N, + const int H, + const int W, + const int C, + const float normalize_factor) { + // We process one image per thread block. + // In 3D case, we have only 1 block i.e., blockIdx.x + // We do not use it. + for (int c = 0; c < C; ++c) { + for (int h = threadIdx.y; h < H; h += blockDim.y) { + for (int w = threadIdx.x; w < W; w += blockDim.x) { + KERNEL_ASSIGN(output[c][h][w], req, + input[h][w][c] / normalize_factor); + } + } + } +} + +// ToTensor Kernel for 4D input +template +__global__ void ToTensorCudaKernel(const Tensor input, + const Tensor output, + const int req, + const int N, + const int H, + const int W, + const int C, + const float normalize_factor) { + // We process one image per thread block. + const int n = blockIdx.x; + + for (int c = 0; c < C; ++c) { + for (int h = threadIdx.y; h < H; h += blockDim.y) { + for (int w = threadIdx.x; w < W; w += blockDim.x) { + KERNEL_ASSIGN(output[n][c][h][w], req, + input[n][h][w][c] / normalize_factor); + } + } + } +} + +template +void ToTensorImplCUDA(mshadow::Stream *s, + const T1 input, + const T2 output, + const int req, + const float normalize_factor) { + int blocks, H, W, C, N; + cudaStream_t stream = mshadow::Stream::GetStream(s); + if (std::is_same>::value) { + // 3D Input - (H, W, C) + N = 0; + H = input.size(0); + W = input.size(1); + C = input.size(2); + blocks = 1; + } else { + // 4D Input - (N, H, W, C) + N = input.size(0); + H = input.size(1); + W = input.size(2); + C = input.size(3); + blocks = N > 0 ? N : 1; + blocks = N; + } + // One block per image. + // Number of threads = (32, 32) is optimal, because, + // computation is minimal and overhead of CUDA preparing + // all threads is minimal. + ToTensorCudaKernel + <<>>(input, output, + req, N, H, W, C, normalize_factor); + MSHADOW_CUDA_POST_KERNEL_CHECK(ToTensorCudaKernel); +} + NNVM_REGISTER_OP(_image_to_tensor) .set_attr("FCompute", ToTensorOpForward);