diff --git a/paddle/phi/kernels/gpu/grid_sample_kernel.cu b/paddle/phi/kernels/gpu/grid_sample_kernel.cu
index 8dc9ba0a8f3fe1..a1013c0d55c1f6 100644
--- a/paddle/phi/kernels/gpu/grid_sample_kernel.cu
+++ b/paddle/phi/kernels/gpu/grid_sample_kernel.cu
@@ -23,23 +23,23 @@
 namespace phi {
 
-template <typename T>
+template <typename T, typename IndexT>
 static __forceinline__ __device__ T Unnormalize(T coord,
-                                                int size,
+                                                IndexT size,
                                                 bool align_corners) {
   return align_corners ? ((coord + 1.f) / 2) * (size - 1)
                        : ((coord + 1.f) * size - 1) / 2;
 }
 
-template <typename T>
-static __forceinline__ __device__ T ClipIndexes(T in, int max_value) {
+template <typename T, typename IndexT>
+static __forceinline__ __device__ T ClipIndexes(T in, IndexT max_value) {
   return min(static_cast<T>(max_value - 1), max(in, static_cast<T>(0)));
 }
 
-template <typename T>
+template <typename T, typename IndexT>
 static __forceinline__ __device__ T ReflectIndexes(T in,
-                                                   int twice_low,
-                                                   int twice_high) {
+                                                   IndexT twice_low,
+                                                   IndexT twice_high) {
   if (twice_low == twice_high) {
     return static_cast<T>(0);
   }
   T min = static_cast<T>(twice_low) / 2;
@@ -47,58 +47,52 @@ static __forceinline__ __device__ T ReflectIndexes(T in,
   T span = static_cast<T>(twice_high - twice_low) / 2;
   in = fabs(in - min);
   T extra = fmod(in, span);
-  int flips = static_cast<int>(floor(in / span));
+  IndexT flips = floor(in / span);
   return (flips & 1) ? span - extra + min : extra + min;  // cond ? odd : even
 }
 
-template <typename T>
+template <typename T, typename IndexT>
 static __forceinline__ __device__ T ComputePositions(T coord,
-                                                     int size,
+                                                     IndexT size,
                                                      PaddingMode padding_mode,
                                                      bool align_corners) {
-  coord = Unnormalize<T>(coord, size, align_corners);
+  coord = Unnormalize<T, IndexT>(coord, size, align_corners);
   if (padding_mode == PaddingMode::border) {
     coord = ClipIndexes(coord, size);
   } else if (padding_mode == PaddingMode::reflect) {
-    coord = align_corners ? ReflectIndexes(coord, 0, 2 * (size - 1))
-                          : ReflectIndexes(coord, -1, 2 * size - 1);
+    coord = align_corners ? ReflectIndexes<T, IndexT>(coord, 0, 2 * (size - 1))
+                          : ReflectIndexes<T, IndexT>(coord, -1, 2 * size - 1);
     coord = ClipIndexes(coord, size);
   }
   return SafeDownGradeToIntRange(coord);
 }
 
-template <typename T>
-__global__ void GridSampleCudaKernel(const int nthreads,
-                                     int n,
-                                     int out_c,
-                                     int out_h,
-                                     int out_w,
-                                     int in_h,
-                                     int in_w,
-                                     const T* input,
-                                     const T* grid,
-                                     T* output,
+template <typename T, typename IndexT>
+__global__ void GridSampleCudaKernel(IndexT n,
+                                     IndexT out_c,
+                                     IndexT out_hw,
+                                     IndexT in_h,
+                                     IndexT in_w,
+                                     const T* __restrict__ input,
+                                     const T* __restrict__ grid,
+                                     T* __restrict__ output,
                                      const Mode mode,
                                      const PaddingMode padding_mode,
                                      bool align_corners) {
-  int inp_sN = out_c * in_h * in_w;
-
-  int inp_sC = in_h * in_w;
-  int inp_sH = in_w;
-  int inp_sW = 1;
-  int grid_sN = out_h * out_w * 2;
-  int grid_sH = out_w * 2;
-  int grid_sW = 2;
-  int grid_sCoor = 1;
-  int out_sN = out_c * out_h * out_w;
-  int out_sC = out_h * out_w;
-  int out_sH = out_w;
-  int out_sW = 1;
-  CUDA_KERNEL_LOOP(index, nthreads) {
-    const int w = index % out_w;
-    const int h = (index / out_w) % out_h;
-    const int n = index / (out_h * out_w);
-    const int grid_offset = n * grid_sN + h * grid_sH + w * grid_sW;
+  IndexT nthreads = n * out_hw;
+  IndexT inp_sN = out_c * (in_h * in_w);
+  IndexT inp_sC = in_h * in_w;
+  IndexT inp_sH = in_w;
+  IndexT inp_sW = 1;
+  IndexT grid_sNHW = 2;
+  IndexT grid_sCoor = 1;
+  IndexT out_sN = out_c * out_hw;
+  IndexT out_sC = out_hw;
+  IndexT out_sHW = 1;
+  CUDA_KERNEL_LOOP_TYPE(index, nthreads, IndexT) {
+    const IndexT hw = index % out_hw;
+    const IndexT n = index / out_hw;
+    const IndexT grid_offset = index * grid_sNHW;
 
     T ix = grid[grid_offset];
     T iy = grid[grid_offset + grid_sCoor];
@@ -106,49 +100,46 @@ __global__ void GridSampleCudaKernel(const int nthreads,
     ix = ComputePositions(ix, in_w, padding_mode, align_corners);
     iy = ComputePositions(iy, in_h, padding_mode, align_corners);
     if (mode == Mode::bilinear) {
-      int ix_nw = static_cast<int>(floor(ix));
-      int iy_nw = static_cast<int>(floor(iy));
-      int ix_ne = ix_nw + 1;
-      int iy_ne = iy_nw;
-      int ix_sw = ix_nw;
-      int iy_sw = iy_nw + 1;
-      int ix_se = ix_nw + 1;
-      int iy_se = iy_nw + 1;
+      IndexT ix_nw = floor(ix);
+      IndexT iy_nw = floor(iy);
+      IndexT ix_ne = ix_nw + 1;
+      IndexT iy_ne = iy_nw;
+      IndexT ix_sw = ix_nw;
+      IndexT iy_sw = iy_nw + 1;
+      IndexT ix_se = ix_nw + 1;
+      IndexT iy_se = iy_nw + 1;
 
       T nw = (ix_se - ix) * (iy_se - iy);
       T ne = (ix - ix_sw) * (iy_sw - iy);
       T sw = (ix_ne - ix) * (iy - iy_ne);
       T se = (ix - ix_nw) * (iy - iy_nw);
 
-      auto inp_offset_NC = n * inp_sN;
+      IndexT inp_offset_NC = n * inp_sN;
+      T* out_ptr_NCHW = output + (n * out_sN + hw * out_sHW);
 
-      auto out_ptr_NCHW = output + n * out_sN + h * out_sH + w * out_sW;
-      for (int c = 0; c < out_c;
+      for (IndexT c = 0; c < out_c;
           ++c, inp_offset_NC += inp_sC, out_ptr_NCHW += out_sC) {
-        *out_ptr_NCHW = static_cast<T>(0);
+        T value{0};
         if (InBounds(iy_nw, ix_nw, in_h, in_w)) {
-          *out_ptr_NCHW +=
-              input[inp_offset_NC + iy_nw * inp_sH + ix_nw * inp_sW] * nw;
+          value += input[inp_offset_NC + iy_nw * inp_sH + ix_nw * inp_sW] * nw;
         }
         if (InBounds(iy_ne, ix_ne, in_h, in_w)) {
-          *out_ptr_NCHW +=
-              input[inp_offset_NC + iy_ne * inp_sH + ix_ne * inp_sW] * ne;
        }
        if (InBounds(iy_sw, ix_sw, in_h, in_w)) {
-          *out_ptr_NCHW +=
-              input[inp_offset_NC + iy_sw * inp_sH + ix_sw * inp_sW] * sw;
+          value += input[inp_offset_NC + iy_sw * inp_sH + ix_sw * inp_sW] * sw;
        }
        if (InBounds(iy_se, ix_se, in_h, in_w)) {
-          *out_ptr_NCHW +=
-              input[inp_offset_NC + iy_se * inp_sH + ix_se * inp_sW] * se;
+          value += input[inp_offset_NC + iy_se * inp_sH + ix_se * inp_sW] * se;
        }
+        *out_ptr_NCHW = value;
      }
    } else if (mode == Mode::nearest) {
-      int ix_nearest = static_cast<int>(std::nearbyint(ix));
-      int iy_nearest = static_cast<int>(std::nearbyint(iy));
-      auto inp_offset_NC = n * inp_sN;
-      auto out_ptr_NCHW = output + n * out_sN + h * out_sH + w * out_sW;
-      for (int c = 0; c < out_c;
+      IndexT ix_nearest = std::nearbyint(ix);
+      IndexT iy_nearest = std::nearbyint(iy);
+      IndexT inp_offset_NC = n * inp_sN;
+      T* out_ptr_NCHW = output + (n * out_sN + hw * out_sHW);
+      for (IndexT c = 0; c < out_c;
           ++c, inp_offset_NC += inp_sC, out_ptr_NCHW += out_sC) {
        if (InBounds(iy_nearest, ix_nearest, in_h, in_w)) {
          *out_ptr_NCHW =
@@ -349,12 +340,12 @@ void GridSampleKernel(const Context& dev_ctx,
   }
 
   if (x.dims().size() == 4) {
-    const int n = grid.dims()[0];
-    const int out_h = grid.dims()[1];
-    const int out_w = grid.dims()[2];
-    const int c = x.dims()[1];
-    const int in_h = x.dims()[2];
-    const int in_w = x.dims()[3];
+    const int64_t n = grid.dims()[0];
+    const int64_t out_h = grid.dims()[1];
+    const int64_t out_w = grid.dims()[2];
+    const int64_t c = x.dims()[1];
+    const int64_t in_h = x.dims()[2];
+    const int64_t in_w = x.dims()[3];
     VLOG(3) << "n: " << n << "; c: " << c << "; out_h: " << out_h
             << "; out_w: " << out_w;
 
@@ -362,25 +353,41 @@ void GridSampleKernel(const Context& dev_ctx,
     VLOG(3) << "out dims: " << out->dims()[0] << "; " << out->dims()[1]
            << "; " << out->dims()[2] << "; " << out->dims()[3];
-    int count = static_cast<int>(n * out_h * out_w);
+    int64_t count = n * out_h * out_w;
     auto cu_stream = dev_ctx.stream();
     backends::gpu::GpuLaunchConfig config =
         backends::gpu::GetGpuLaunchConfig1D(dev_ctx, count);
-    GridSampleCudaKernel<T>
-        <<<config.block_per_grid, config.thread_per_block, 0, cu_stream>>>(
-            count,
-            n,
-            c,
-            out_h,
-            out_w,
-            in_h,
-            in_w,
-            x.data<T>(),
-            grid.data<T>(),
-            output_data,
-            enum_mode,
-            enum_padding_mode,
-            align_corners);
+    if (x.numel() <= std::numeric_limits<int>::max() &&
+        grid.numel() <= std::numeric_limits<int>::max() &&
+        out->numel() <= std::numeric_limits<int>::max()) {
+      GridSampleCudaKernel<T, int>
+          <<<config.block_per_grid, config.thread_per_block, 0, cu_stream>>>(
+              n,
+              c,
+              out_h * out_w,
+              in_h,
+              in_w,
+              x.data<T>(),
+              grid.data<T>(),
+              output_data,
+              enum_mode,
+              enum_padding_mode,
+              align_corners);
+    } else {
+      GridSampleCudaKernel<T, int64_t>
+          <<<config.block_per_grid, config.thread_per_block, 0, cu_stream>>>(
+              n,
+              c,
+              out_h * out_w,
+              in_h,
+              in_w,
+              x.data<T>(),
+              grid.data<T>(),
+              output_data,
+              enum_mode,
+              enum_padding_mode,
+              align_corners);
+    }
   } else {
     const int n = grid.dims()[0];
     const int out_d = grid.dims()[1];
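
Note on the launch-site change above: the kernel is now templated on an index
type IndexT, and the host code picks int when every tensor's numel() fits in
32-bit range (cheaper index arithmetic and fewer registers on the GPU) and
falls back to int64_t otherwise, which is what prevents offset overflow for
tensors with more than 2^31 - 1 elements. Below is a minimal standalone sketch
of the same dispatch pattern, under stated assumptions: MyCopyKernel and
LaunchWithIndexType are illustrative names, not part of this PR or of Paddle.

#include <algorithm>
#include <cstdint>
#include <limits>

// Sketch: a grid-stride kernel templated on the index type, so that all
// offset arithmetic happens in IndexT (int or int64_t), as in the diff above.
template <typename T, typename IndexT>
__global__ void MyCopyKernel(IndexT n,
                             const T* __restrict__ in,
                             T* __restrict__ out) {
  IndexT stride = static_cast<IndexT>(blockDim.x) * gridDim.x;
  for (IndexT i = static_cast<IndexT>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < n; i += stride) {
    out[i] = in[i];  // every index stays within IndexT
  }
}

// Dispatch on the element count: instantiate with 32-bit indices when they
// provably fit, with 64-bit indices otherwise.
template <typename T>
void LaunchWithIndexType(int64_t numel, const T* in, T* out,
                         cudaStream_t stream) {
  constexpr int kBlock = 256;
  int64_t blocks = (numel + kBlock - 1) / kBlock;
  // Cap the grid size; the grid-stride loop covers any remainder.
  int grid = static_cast<int>(
      std::min<int64_t>(blocks, std::numeric_limits<int>::max()));
  if (numel <= std::numeric_limits<int>::max()) {
    MyCopyKernel<T, int><<<grid, kBlock, 0, stream>>>(
        static_cast<int>(numel), in, out);
  } else {
    MyCopyKernel<T, int64_t><<<grid, kBlock, 0, stream>>>(numel, in, out);
  }
}

The runtime check costs one comparison per launch on the host, while the int
instantiation can noticeably reduce index-arithmetic cost in the common
small-tensor case, which is presumably why the PR dispatches rather than
always using int64_t.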