From c184b07128efacda031e5ceb52aca603c30d3063 Mon Sep 17 00:00:00 2001 From: ptredak Date: Thu, 8 Nov 2018 15:00:38 -0800 Subject: [PATCH 1/2] Fix launch bounds in spatial transformer --- src/operator/spatial_transformer.cu | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/src/operator/spatial_transformer.cu b/src/operator/spatial_transformer.cu index 33dbe3e7c069..1a16de3bab7c 100644 --- a/src/operator/spatial_transformer.cu +++ b/src/operator/spatial_transformer.cu @@ -36,11 +36,13 @@ __device__ bool between(DType value, int lowerBound, int upperBound) { return (value >= lowerBound && value <= upperBound); } template -__global__ void BilinearSamplingForwardKernel(const int i_c, const int i_h, - const int i_w, const DType* data, - const DType* grid, const int o_n, - const int o_c, const int o_h, - const int o_w, DType* out) { +__global__ void +__launch_bounds__(cuda::kMaxThreadsPerBlock, 1) +BilinearSamplingForwardKernel(const int i_c, const int i_h, + const int i_w, const DType* data, + const DType* grid, const int o_n, + const int o_c, const int o_h, + const int o_w, DType* out) { for (int index = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x; index < o_n * o_c * o_h * o_w; index += blockDim.x * gridDim.x * gridDim.y) { @@ -78,12 +80,14 @@ __global__ void BilinearSamplingForwardKernel(const int i_c, const int i_h, } template -__global__ void BilinearSamplingBackwardKernel(const int i_c, const int i_h, - const int i_w, const DType* grad, - const DType* data, const int o_n, - const int o_c, const int o_h, - const int o_w, DType* g_input, - DType* grid_src) { +__global__ void +__launch_bounds__(cuda::kMaxThreadsPerBlock, 1) +BilinearSamplingBackwardKernel(const int i_c, const int i_h, + const int i_w, const DType* grad, + const DType* data, const int o_n, + const int o_c, const int o_h, + const int o_w, DType* g_input, + DType* grid_src) { for (int index = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x; index < o_n * o_h * o_w; index += blockDim.x * gridDim.x * gridDim.y) { From cb20d537c1471883f711cc4d382d6ec75b8caa23 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Wed, 2 Jan 2019 15:20:31 -0800 Subject: [PATCH 2/2] Adding explanation in comment. --- src/operator/spatial_transformer.cu | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/src/operator/spatial_transformer.cu b/src/operator/spatial_transformer.cu index 1a16de3bab7c..fd330bd4ca87 100644 --- a/src/operator/spatial_transformer.cu +++ b/src/operator/spatial_transformer.cu @@ -35,8 +35,17 @@ template __device__ bool between(DType value, int lowerBound, int upperBound) { return (value >= lowerBound && value <= upperBound); } + template __global__ void +/* + * In order to not generate the code that uses too many + * registers (resulting in too many resources requested + * error) we need to tell the compiler that we will be + * launching this kernel with cuda::kMaxThreadsPerBlock + * threads per block. Setting __launch_bounds__ ensures + * that such configuration can always be launched. + */ __launch_bounds__(cuda::kMaxThreadsPerBlock, 1) BilinearSamplingForwardKernel(const int i_c, const int i_h, const int i_w, const DType* data, @@ -79,6 +88,14 @@ BilinearSamplingForwardKernel(const int i_c, const int i_h, } } +/* + * In order to not generate the code that uses too many + * registers (resulting in too many resources requested + * error) we need to tell the compiler that we will be + * launching this kernel with cuda::kMaxThreadsPerBlock + * threads per block. Setting __launch_bounds__ ensures + * that such configuration can always be launched. + */ template __global__ void __launch_bounds__(cuda::kMaxThreadsPerBlock, 1)