Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix launch bounds in spatial transformer
Browse files Browse the repository at this point in the history
  • Loading branch information
ptrendx committed Nov 8, 2018
1 parent 99534c9 commit f322d69
Showing 1 changed file with 15 additions and 11 deletions.
26 changes: 15 additions & 11 deletions src/operator/spatial_transformer.cu
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,13 @@ __device__ bool between(DType value, int lowerBound, int upperBound) {
return (value >= lowerBound && value <= upperBound);
}
template<typename DType>
__global__ void BilinearSamplingForwardKernel(const int i_c, const int i_h,
const int i_w, const DType* data,
const DType* grid, const int o_n,
const int o_c, const int o_h,
const int o_w, DType* out) {
__global__ void
__launch_bounds__(cuda::kMaxThreadsPerBlock, 1)
BilinearSamplingForwardKernel(const int i_c, const int i_h,
const int i_w, const DType* data,
const DType* grid, const int o_n,
const int o_c, const int o_h,
const int o_w, DType* out) {
for (int index = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
index < o_n * o_c * o_h * o_w;
index += blockDim.x * gridDim.x * gridDim.y) {
Expand Down Expand Up @@ -78,12 +80,14 @@ __global__ void BilinearSamplingForwardKernel(const int i_c, const int i_h,
}

template<typename DType>
__global__ void BilinearSamplingBackwardKernel(const int i_c, const int i_h,
const int i_w, const DType* grad,
const DType* data, const int o_n,
const int o_c, const int o_h,
const int o_w, DType* g_input,
DType* grid_src) {
__global__ void
__launch_bounds__(cuda::kMaxThreadsPerBlock, 1)
BilinearSamplingBackwardKernel(const int i_c, const int i_h,
const int i_w, const DType* grad,
const DType* data, const int o_n,
const int o_c, const int o_h,
const int o_w, DType* g_input,
DType* grid_src) {
for (int index = (blockIdx.x + blockIdx.y * gridDim.x) * blockDim.x + threadIdx.x;
index < o_n * o_h * o_w;
index += blockDim.x * gridDim.x * gridDim.y) {
Expand Down

0 comments on commit f322d69

Please sign in to comment.