Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Adding explanation in comment.
Browse files Browse the repository at this point in the history
  • Loading branch information
ptrendx authored and KellenSunderland committed Jan 16, 2019
1 parent 156127d commit cb0d731
Showing 1 changed file with 17 additions and 0 deletions.
17 changes: 17 additions & 0 deletions src/operator/spatial_transformer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,17 @@ template<typename DType>
__device__ bool between(DType value, int lowerBound, int upperBound) {
return (value >= lowerBound && value <= upperBound);
}

template<typename DType>
__global__ void
/*
* In order to not generate the code that uses too many
* registers (resulting in too many resources requested
* error) we need to tell the compiler that we will be
* launching this kernel with cuda::kMaxThreadsPerBlock
* threads per block. Setting __launch_bounds__ ensures
* that such configuration can always be launched.
*/
__launch_bounds__(cuda::kMaxThreadsPerBlock, 1)
BilinearSamplingForwardKernel(const int i_c, const int i_h,
const int i_w, const DType* data,
Expand Down Expand Up @@ -79,6 +88,14 @@ BilinearSamplingForwardKernel(const int i_c, const int i_h,
}
}

/*
* In order to not generate the code that uses too many
* registers (resulting in too many resources requested
* error) we need to tell the compiler that we will be
* launching this kernel with cuda::kMaxThreadsPerBlock
* threads per block. Setting __launch_bounds__ ensures
* that such configuration can always be launched.
*/
template<typename DType>
__global__ void
__launch_bounds__(cuda::kMaxThreadsPerBlock, 1)
Expand Down

0 comments on commit cb0d731

Please sign in to comment.