From dad33b51b4a3a05f9bb94f781628494a409e80e4 Mon Sep 17 00:00:00 2001 From: Sandeep Krishnamurthy Date: Thu, 14 Feb 2019 11:55:43 -0800 Subject: [PATCH] Performance improvement in Normalize GPU Kernel (#14139) * New CPU kernel for normalize * New GPU kernel for Normalize * Add launch bounds and increase threads to 32*32 * do not hardcode number of threads * Try fix windows build failure * make channels as int to fix windows build issues with omp * Simplify cuda kernels with 1 D thread block * Minor refactoring * Revert thread dim for ToTensor operator --- src/operator/image/image_random-inl.h | 307 ++++++++++++++++---------- src/operator/image/image_random.cu | 170 ++++++++++++-- 2 files changed, 336 insertions(+), 141 deletions(-) diff --git a/src/operator/image/image_random-inl.h b/src/operator/image/image_random-inl.h index 392fff4dbf81..0f4d173be79a 100644 --- a/src/operator/image/image_random-inl.h +++ b/src/operator/image/image_random-inl.h @@ -54,6 +54,35 @@ void ToTensorImplCUDA(mshadow::Stream *s, const T2 output, const int req, const float normalize_factor); + +template +void NormalizeImplCUDA(mshadow::Stream *s, + const DType *input, + DType *output, + const int req, + const int N, + const int C, + const int H, + const int W, + const float mean_d0, + const float mean_d1, + const float mean_d2, + const float std_d0, + const float std_d1, + const float std_d2); + +template +void NormalizeBackwardImplCUDA(mshadow::Stream *s, + const DType *out_grad, + DType *in_grad, + const int req, + const int N, + const int C, + const int H, + const int W, + const float std_d0, + const float std_d1, + const float std_d2); #endif // MXNET_USE_CUDA // Shape and Type inference for image to tensor operator @@ -254,156 +283,165 @@ inline bool NormalizeOpType(const nnvm::NodeAttrs& attrs, return out_attrs->at(0) != -1; } -template -struct normalize_forward { - template - MSHADOW_XINLINE static void Map(uint32_t c, DType* out_data, const DType* in_data, - const float mean_d0, const float mean_d1, const float mean_d2, - const float std_d0, const float std_d1, const float std_d2, - const int length, const int step) { - float mean, std; - switch (c) { - case 0 : mean = mean_d0; - std = std_d0; - break; - case 1 : mean = mean_d1; - std = std_d1; - break; - case 2 : mean = mean_d2; - std = std_d2; - break; - } - #pragma omp parallel for - for (int i = 0; i < length; ++i) { - KERNEL_ASSIGN(out_data[step + c*length + i], req, - (in_data[step + c*length + i] - mean) / std); - } +template +inline void Normalize(DType* out_data, + const DType* in_data, + const int length, + const int channels, + const int step, + const std::vector mean, + const std::vector std) { + // Microsoft Visual C++ compiler does not support omp collapse + #ifdef _MSC_VER + #pragma omp parallel for + #else + #pragma omp parallel for collapse(2) + #endif // _MSC_VER + for (int c = 0; c < channels; ++c) { + for (int i = 0; i < length; ++i) { + KERNEL_ASSIGN(out_data[step + c*length + i], req, + (in_data[step + c*length + i] - mean[c]) / std[c]); } -}; - -template -void NormalizeImpl(const OpContext &ctx, - const std::vector &inputs, - const std::vector &outputs, - const std::vector &req, - const float mean_d0, const float mean_d1, - const float mean_d2, const float std_d0, - const float std_d1, const float std_d2, - const int length, - const uint32_t channel, - const int step = 0) { - mshadow::Stream *s = ctx.get_stream(); + } +} - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { - MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { - DType* input = inputs[0].dptr(); - DType* output = outputs[0].dptr(); - mxnet_op::Kernel, xpu>::Launch( - s, channel, output, input, mean_d0, mean_d1, mean_d2, - std_d0, std_d1, std_d2, length, step); - }); +inline void NormalizeImpl(const std::vector &inputs, + const std::vector &outputs, + const std::vector &req, + const int length, + const int channels, + const int step, + const std::vector mean, + const std::vector std) { + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + DType* input = inputs[0].dptr(); + DType* output = outputs[0].dptr(); + Normalize(output, input, length, channels, step, + mean, std); }); + }); } template void NormalizeOpForward(const nnvm::NodeAttrs &attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { CHECK_EQ(inputs.size(), 1U); CHECK_EQ(outputs.size(), 1U); CHECK_EQ(req.size(), 1U); const NormalizeParam ¶m = nnvm::get(attrs.parsed); - // Note: We need mean and std_dev in the kernel. - // It is costly (device copy) to pass it as vector, for gpu kernel. - // Hence, passing it as below for performance. - float mean_d0, mean_d1, mean_d2; - float std_d0, std_d1, std_d2; - - // Mean and Std can be 1 or 3 D only. + // Mean and Std can be 1 or 3D only. + std::vector mean(3); + std::vector std(3); if (param.mean.ndim() == 1) { - mean_d0 = mean_d1 = mean_d2 = param.mean[0]; + mean[0] = mean[1] = mean[3] = param.mean[0]; } else { - mean_d0 = param.mean[0]; - mean_d1 = param.mean[1]; - mean_d2 = param.mean[2]; + mean[0] = param.mean[0]; + mean[1] = param.mean[1]; + mean[2] = param.mean[2]; } if (param.std.ndim() == 1) { - std_d0 = std_d1 = std_d2 = param.std[0]; + std[0] = std[1] = std[2] = param.std[0]; } else { - std_d0 = param.std[0]; - std_d1 = param.std[1]; - std_d2 = param.std[2]; + std[0] = param.std[0]; + std[1] = param.std[1]; + std[2] = param.std[2]; } - // 3D input (c, h, w) - if (inputs[0].ndim() == 3) { + if (std::is_same::value) { + #if MXNET_USE_CUDA + mshadow::Stream *s = ctx.get_stream(); + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, { + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + int N, C, H, W; + DType *input = nullptr; + DType *output = nullptr; + if (inputs[0].ndim() == 3) { + N = 1; + C = static_cast(inputs[0].shape_[0]); + H = static_cast(inputs[0].shape_[1]); + W = static_cast(inputs[0].shape_[2]); + input = (inputs[0].get(s)).dptr_; + output = (outputs[0].get(s)).dptr_; + } else { + N = static_cast(inputs[0].shape_[0]); + C = static_cast(inputs[0].shape_[1]); + H = static_cast(inputs[0].shape_[2]); + W = static_cast(inputs[0].shape_[3]); + input = (inputs[0].get(s)).dptr_; + output = (outputs[0].get(s)).dptr_; + } + NormalizeImplCUDA(s, input, output, req_type, + N, C, H, W, + mean[0], mean[1], mean[2], + std[0], std[1], std[2]); + }); + }); + #else + LOG(FATAL) << "Compile with USE_CUDA=1 to use Normalize operator on GPU."; + #endif // MXNET_USE_CUDA + } else if (inputs[0].ndim() == 3) { + // 3D input (c, h, w) const int length = inputs[0].shape_[1] * inputs[0].shape_[2]; - const uint32_t channel = inputs[0].shape_[0]; - NormalizeImpl(ctx, inputs, outputs, req, mean_d0, mean_d1, mean_d2, - std_d0, std_d1, std_d2, length, channel); + const int channel = static_cast(inputs[0].shape_[0]); + const int step = 0; + NormalizeImpl(inputs, outputs, req, length, channel, step, mean, std); } else if (inputs[0].ndim() == 4) { // 4D input (n, c, h, w) const int batch_size = inputs[0].shape_[0]; const int length = inputs[0].shape_[2] * inputs[0].shape_[3]; - const uint32_t channel = inputs[0].shape_[1]; + const int channel = static_cast(inputs[0].shape_[1]); const int step = channel * length; #pragma omp parallel for for (auto n = 0; n < batch_size; ++n) { - NormalizeImpl(ctx, inputs, outputs, req, mean_d0, mean_d1, mean_d2, - std_d0, std_d1, std_d2, length, channel, n*step); + NormalizeImpl(inputs, outputs, req, length, channel, n*step, mean, std); } } } // Backward function -template -struct normalize_backward { - template - MSHADOW_XINLINE static void Map(uint32_t c, DType* in_grad, const DType* out_grad, - const float std_d0, const float std_d1, const float std_d2, - const int length, const int step) { - // d/dx{(x - mean) / std_dev} => (1 / std_dev) - float std_dev; - switch (c) { - case 0 : std_dev = std_d0; - break; - case 1 : std_dev = std_d1; - break; - case 2 : std_dev = std_d2; - break; - } - +template +inline void NormalizeBackward(const DType* out_grad, + DType* in_grad, + const int length, + const int channels, + const int step, + const std::vector std) { + // Microsoft Visual C++ compiler does not support omp collapse + #ifdef _MSC_VER #pragma omp parallel for + #else + #pragma omp parallel for collapse(2) + #endif // _MSC_VER + for (int c = 0; c < channels; ++c) { for (int i = 0; i < length; ++i) { KERNEL_ASSIGN(in_grad[step + c*length + i], req, - out_grad[step + c*length + i] * (1.0 / std_dev)); + out_grad[step + c*length + i] * (1.0 / std[c])); } } -}; - -template -void NormalizeBackwardImpl(const OpContext &ctx, - const std::vector &inputs, - const std::vector &outputs, - const std::vector &req, - const float std_d0, const float std_d1, const float std_d2, - const int length, - const uint32_t channel, - const int step = 0) { - mshadow::Stream *s = ctx.get_stream(); +} +inline void NormalizeBackwardImpl(const std::vector &inputs, + const std::vector &outputs, + const std::vector &req, + const int length, + const int channels, + const int step, + const std::vector std + ) { MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { DType* out_grad = inputs[0].dptr(); DType* in_grad = outputs[0].dptr(); - mxnet_op::Kernel, xpu>::Launch( - s, channel, in_grad, out_grad, std_d0, std_d1, std_d2, length, step); + NormalizeBackward(out_grad, in_grad, length, + channels, step, std); }); }); } @@ -419,37 +457,66 @@ void NormalizeOpBackward(const nnvm::NodeAttrs &attrs, CHECK_EQ(req.size(), 1U); const NormalizeParam ¶m = nnvm::get(attrs.parsed); - float std_d0, std_d1, std_d2; - - // Std can be 1 or 3 D only + // Std can be 1 or 3D only. + std::vector std(3); if (param.std.ndim() == 1) { - std_d0 = std_d1 = std_d2 = param.std[0]; + std[0] = std[1] = std[2] = param.std[0]; } else { - std_d0 = param.std[0]; - std_d1 = param.std[1]; - std_d2 = param.std[2]; + std[0] = param.std[0]; + std[1] = param.std[1]; + std[2] = param.std[2]; } // Note: inputs[0] is out_grad const TBlob& in_data = inputs[1]; - // 3D input (c, h, w) - if (in_data.ndim() == 3) { + if (std::is_same::value) { + #if MXNET_USE_CUDA + mshadow::Stream *s = ctx.get_stream(); + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + int N, C, H, W; + DType *in_grad = nullptr; + DType *out_grad = nullptr; + if (in_data.ndim() == 3) { + N = 1; + C = static_cast(in_data.shape_[0]); + H = static_cast(in_data.shape_[1]); + W = static_cast(in_data.shape_[2]); + out_grad = (inputs[0].get(s)).dptr_; + in_grad = (outputs[0].get(s)).dptr_; + } else { + N = static_cast(in_data.shape_[0]); + C = static_cast(in_data.shape_[1]); + H = static_cast(in_data.shape_[2]); + W = static_cast(in_data.shape_[3]); + out_grad = (inputs[0].get(s)).dptr_; + in_grad = (outputs[0].get(s)).dptr_; + } + NormalizeBackwardImplCUDA(s, out_grad, in_grad, req_type, + N, C, H, W, + std[0], std[1], std[2]); + }); + }); + #else + LOG(FATAL) << "Compile with USE_CUDA=1 to use Normalize backward operator on GPU."; + #endif // MXNET_USE_CUDA + } else if (in_data.ndim() == 3) { + // 3D input (c, h, w) const int length = in_data.shape_[1] * in_data.shape_[2]; - const uint32_t channel = in_data.shape_[0]; - NormalizeBackwardImpl(ctx, inputs, outputs, req, std_d0, std_d1, std_d2, length, channel); + const int channel = static_cast(in_data.shape_[0]); + const int step = 0; + NormalizeBackwardImpl(inputs, outputs, req, length, channel, step, std); } else if (in_data.ndim() == 4) { // 4D input (n, c, h, w) const int batch_size = in_data.shape_[0]; const int length = in_data.shape_[2] * in_data.shape_[3]; - const uint32_t channel = in_data.shape_[1]; + const int channel = static_cast(in_data.shape_[1]); const int step = channel * length; #pragma omp parallel for for (auto n = 0; n < batch_size; ++n) { - NormalizeBackwardImpl(ctx, inputs, outputs, req, - std_d0, std_d1, std_d2, length, - channel, n*step); + NormalizeBackwardImpl(inputs, outputs, req, length, channel, n*step, std); } } } diff --git a/src/operator/image/image_random.cu b/src/operator/image/image_random.cu index 6fe53832a89e..5dfbbd3b6b00 100644 --- a/src/operator/image/image_random.cu +++ b/src/operator/image/image_random.cu @@ -32,15 +32,25 @@ namespace image { using namespace mshadow; // ToTensor Kernel for 3D input +/* + * In order to not generate the code that uses too many + * registers (resulting in too many resources requested + * error) we need to tell the compiler that we will be + * launching this kernel with cuda::kMaxThreadsPerBlock + * threads per block. Setting __launch_bounds__ ensures + * that such configuration can always be launched. + */ template -__global__ void ToTensorCudaKernel(const Tensor input, - const Tensor output, - const int req, - const int N, - const int H, - const int W, - const int C, - const float normalize_factor) { +__global__ void +__launch_bounds__(cuda::kMaxThreadsPerBlock, 1) +ToTensorCudaKernel(const Tensor input, + const Tensor output, + const int req, + const int N, + const int H, + const int W, + const int C, + const float normalize_factor) { // We process one image per thread block. // In 3D case, we have only 1 block i.e., blockIdx.x // We do not use it. @@ -56,14 +66,16 @@ __global__ void ToTensorCudaKernel(const Tensor input, // ToTensor Kernel for 4D input template -__global__ void ToTensorCudaKernel(const Tensor input, - const Tensor output, - const int req, - const int N, - const int H, - const int W, - const int C, - const float normalize_factor) { +__global__ void +__launch_bounds__(cuda::kMaxThreadsPerBlock, 1) +ToTensorCudaKernel(const Tensor input, + const Tensor output, + const int req, + const int N, + const int H, + const int W, + const int C, + const float normalize_factor) { // We process one image per thread block. const int n = blockIdx.x; @@ -99,18 +111,134 @@ void ToTensorImplCUDA(mshadow::Stream *s, W = input.size(2); C = input.size(3); blocks = N > 0 ? N : 1; - blocks = N; } - // One block per image. - // Number of threads = (32, 32) is optimal, because, - // computation is minimal and overhead of CUDA preparing - // all threads is minimal. + ToTensorCudaKernel <<>>(input, output, req, N, H, W, C, normalize_factor); MSHADOW_CUDA_POST_KERNEL_CHECK(ToTensorCudaKernel); } +// Normalize Forward CUDA Kernel +template +__global__ void +__launch_bounds__(cuda::kMaxThreadsPerBlock, 1) +NormalizeCudaKernel(const DType* input, + DType* output, + const int req, + const int N, + const int C, + const int H, + const int W, + const float mean_d0, + const float mean_d1, + const float mean_d2, + const float std_d0, + const float std_d1, + const float std_d2) { + // We process one image per thread block. + const int n = blockIdx.x; + const int length = H * W; + const int step = C * length * n; + + float mean = mean_d0; + float std = std_d0; + for (int c = 0; c < C; ++c) { + switch (c) { + case 0 : break; + case 1 : mean = mean_d1; + std = std_d1; + break; + case 2 : mean = mean_d2; + std = std_d2; + break; + } + for (int i = threadIdx.x; i < length; i += blockDim.x) { + KERNEL_ASSIGN(*(output + step + i + (c * length)), req, + (*(input + step + i + (c * length)) - mean) / std); + } + } +} + +template +void NormalizeImplCUDA(mshadow::Stream *s, + const DType* input, + DType* output, + const int req, + const int N, + const int C, + const int H, + const int W, + const float mean_d0, + const float mean_d1, + const float mean_d2, + const float std_d0, + const float std_d1, + const float std_d2) { + cudaStream_t stream = mshadow::Stream::GetStream(s); + NormalizeCudaKernel + // 1 image per block. N is batch size. + <<>>(input, output, + req, N, C, H, W, mean_d0, mean_d1, mean_d2, + std_d0, std_d1, std_d2); + MSHADOW_CUDA_POST_KERNEL_CHECK(NormalizeCudaKernel); +} + +// Normalize Backward Kernel +template +__global__ void +__launch_bounds__(cuda::kMaxThreadsPerBlock, 1) +NormalizeBackwardCudaKernel(const DType *out_grad, + DType *in_grad, + const int req, + const int N, + const int C, + const int H, + const int W, + const float std_d0, + const float std_d1, + const float std_d2) { + // We process one image per thread block. + const int n = blockIdx.x; + const int length = H * W; + const int step = C * length * n; + + float std = std_d0; + for (int c = 0; c < C; ++c) { + switch (c) { + case 0 : break; + case 1 : std = std_d1; + break; + case 2 : std = std_d2; + break; + } + for (int i = threadIdx.x; i < length; i += blockDim.x) { + KERNEL_ASSIGN(*(in_grad + step + i + (c * length)), req, + *(out_grad + step + i + (c * length)) * (1.0 / std)); + } + } +} + +template +void NormalizeBackwardImplCUDA(mshadow::Stream *s, + const DType *out_grad, + DType *in_grad, + const int req, + const int N, + const int C, + const int H, + const int W, + const float std_d0, + const float std_d1, + const float std_d2) { + cudaStream_t stream = mshadow::Stream::GetStream(s); + NormalizeBackwardCudaKernel + // 1 image per block. N is batch size. + <<>>(out_grad, in_grad, + req, N, C, H, W, std_d0, std_d1, std_d2); + MSHADOW_CUDA_POST_KERNEL_CHECK(NormalizeBackwardCudaKernel); +} + NNVM_REGISTER_OP(_image_to_tensor) .set_attr("FCompute", ToTensorOpForward);