Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix launch bounds in spatial transformer #13188

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 32 additions & 11 deletions src/operator/spatial_transformer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,23 @@ template<typename DType>
/*
 * Device-side helper: returns true iff value lies within the
 * closed interval [lowerBound, upperBound].
 */
__device__ bool between(DType value, int lowerBound, int upperBound) {
  const bool notBelow = (value >= lowerBound);
  const bool notAbove = (value <= upperBound);
  return notBelow && notAbove;
}

template<typename DType>
__global__ void BilinearSamplingForwardKernel(const int i_c, const int i_h,
const int i_w, const DType* data,
const DType* grid, const int o_n,
const int o_c, const int o_h,
const int o_w, DType* out) {
__global__ void
/*
 * To avoid generating code that uses too many registers
 * (which would trigger a "too many resources requested
 * for launch" error), we tell the compiler that this
 * kernel will be launched with at most
 * cuda::kMaxThreadsPerBlock threads per block. Setting
 * __launch_bounds__ guarantees that such a configuration
 * can always be launched.
 */
__launch_bounds__(cuda::kMaxThreadsPerBlock, 1)
BilinearSamplingForwardKernel(const int i_c, const int i_h,
const int i_w, const DType* data,
const DType* grid, const int o_n,
const int o_c, const int o_h,
const int o_w, DType* out) {
for (int index = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
index < o_n * o_c * o_h * o_w;
index += blockDim.x * gridDim.x * gridDim.y) {
Expand Down Expand Up @@ -77,13 +88,23 @@ __global__ void BilinearSamplingForwardKernel(const int i_c, const int i_h,
}
}

/*
 * To avoid generating code that uses too many registers
 * (which would trigger a "too many resources requested
 * for launch" error), we tell the compiler that this
 * kernel will be launched with at most
 * cuda::kMaxThreadsPerBlock threads per block. Setting
 * __launch_bounds__ guarantees that such a configuration
 * can always be launched.
 */
template<typename DType>
__global__ void BilinearSamplingBackwardKernel(const int i_c, const int i_h,
const int i_w, const DType* grad,
const DType* data, const int o_n,
const int o_c, const int o_h,
const int o_w, DType* g_input,
DType* grid_src) {
__global__ void
__launch_bounds__(cuda::kMaxThreadsPerBlock, 1)
BilinearSamplingBackwardKernel(const int i_c, const int i_h,
const int i_w, const DType* grad,
const DType* data, const int o_n,
const int o_c, const int o_h,
const int o_w, DType* g_input,
DType* grid_src) {
for (int index = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
index < o_n * o_h * o_w;
index += blockDim.x * gridDim.x * gridDim.y) {
Expand Down