Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 96 additions & 89 deletions paddle/phi/kernels/gpu/grid_sample_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,132 +23,123 @@

namespace phi {

// Converts a normalized grid coordinate into a pixel-space coordinate along
// one axis of extent `size`.
//
//   align_corners == true : -1 maps to 0,    +1 maps to size - 1
//   align_corners == false: -1 maps to -0.5, +1 maps to size - 0.5
//
// T is the floating-point coordinate type. IndexT is the integer type used
// for extents; it is deduced from `size`, so callers that spell only
// `Unnormalize<T>(...)` keep working.
template <typename T, typename IndexT>
static __forceinline__ __device__ T Unnormalize(T coord,
                                                IndexT size,
                                                bool align_corners) {
  return align_corners ? ((coord + 1.f) / 2) * (size - 1)
                       : ((coord + 1.f) * size - 1) / 2;
}

// Clamps `in` to the closed range [0, max_value - 1].
//
// T is the coordinate type; IndexT is the integer type of the extent. The
// upper bound is converted to T before the comparison so both operands of
// min/max share one type.
template <typename T, typename IndexT>
static __forceinline__ __device__ T ClipIndexes(T in, IndexT max_value) {
  return min(static_cast<T>(max_value - 1), max(in, static_cast<T>(0)));
}

// Reflects `in` back into the interval [twice_low / 2, twice_high / 2].
//
// The bounds are passed doubled so that half-integer limits (for example
// -0.5) can be expressed with an integer type. Returns 0 when the interval
// is empty (twice_low == twice_high).
//
// T is the coordinate type; IndexT is the integer type of the bounds.
template <typename T, typename IndexT>
static __forceinline__ __device__ T ReflectIndexes(T in,
                                                   IndexT twice_low,
                                                   IndexT twice_high) {
  if (twice_low == twice_high) {
    return static_cast<T>(0);
  }
  // Named `low` rather than `min` so the local does not shadow min().
  const T low = static_cast<T>(twice_low) / 2;
  const T span = static_cast<T>(twice_high - twice_low) / 2;
  // Distance from the lower bound; the sign is irrelevant for reflection.
  in = fabs(in - low);
  const T extra = fmod(in, span);
  // Number of whole spans crossed. The float-to-integer conversion is made
  // explicit instead of relying on an implicit narrowing initialization.
  const IndexT flips = static_cast<IndexT>(floor(in / span));
  // Odd flip count: walking back from the upper bound; even: forward from low.
  return (flips & 1) ? span - extra + low : extra + low;
}

// Maps one normalized grid coordinate to a source-pixel coordinate along an
// axis of extent `size`, applying the requested padding behaviour.
//
//   PaddingMode::border  : unnormalize, then clamp to [0, size - 1].
//   PaddingMode::reflect : unnormalize, reflect about the axis limits, then
//                          clamp to [0, size - 1].
//   any other mode       : unnormalize only.
//
// The result is passed through SafeDownGradeToIntRange before returning.
// NOTE(review): that helper is defined outside this view; presumably it keeps
// the value representable as an int -- confirm against its definition.
//
// T is the coordinate type; IndexT is the integer type of the extent.
template <typename T, typename IndexT>
static __forceinline__ __device__ T ComputePositions(T coord,
                                                     IndexT size,
                                                     PaddingMode padding_mode,
                                                     bool align_corners) {
  coord = Unnormalize(coord, size, align_corners);
  if (padding_mode == PaddingMode::border) {
    coord = ClipIndexes(coord, size);
  } else if (padding_mode == PaddingMode::reflect) {
    // Doubled reflection limits: [0, size - 1] when corners are aligned,
    // [-0.5, size - 0.5] otherwise. Holding them in IndexT locals lets the
    // call below resolve by argument deduction with no int/IndexT mismatch.
    const IndexT twice_low = align_corners ? 0 : -1;
    const IndexT twice_high = align_corners ? 2 * (size - 1) : 2 * size - 1;
    coord = ReflectIndexes(coord, twice_low, twice_high);
    coord = ClipIndexes(coord, size);
  }
  return SafeDownGradeToIntRange(coord);
}

template <typename T>
__global__ void GridSampleCudaKernel(const int nthreads,
int n,
int out_c,
int out_h,
int out_w,
int in_h,
int in_w,
const T* input,
const T* grid,
T* output,
template <typename T, typename IndexT>
__global__ void GridSampleCudaKernel(IndexT n,
IndexT out_c,
IndexT out_hw,
IndexT in_h,
IndexT in_w,
const T* __restrict__ input,
const T* __restrict__ grid,
T* __restrict__ output,
const Mode mode,
const PaddingMode padding_mode,
bool align_corners) {
int inp_sN = out_c * in_h * in_w;

int inp_sC = in_h * in_w;
int inp_sH = in_w;
int inp_sW = 1;
int grid_sN = out_h * out_w * 2;
int grid_sH = out_w * 2;
int grid_sW = 2;
int grid_sCoor = 1;
int out_sN = out_c * out_h * out_w;
int out_sC = out_h * out_w;
int out_sH = out_w;
int out_sW = 1;
CUDA_KERNEL_LOOP(index, nthreads) {
const int w = index % out_w;
const int h = (index / out_w) % out_h;
const int n = index / (out_h * out_w);
const int grid_offset = n * grid_sN + h * grid_sH + w * grid_sW;
IndexT nthreads = n * out_hw;
IndexT inp_sN = out_c * (in_h * in_w);
IndexT inp_sC = in_h * in_w;
IndexT inp_sH = in_w;
IndexT inp_sW = 1;
IndexT grid_sNHW = 2;
IndexT grid_sCoor = 1;
IndexT out_sN = out_c * out_hw;
IndexT out_sC = out_hw;
IndexT out_sHW = 1;
CUDA_KERNEL_LOOP_TYPE(index, nthreads, IndexT) {
const IndexT hw = index % out_hw;
const IndexT n = index / out_hw;
const IndexT grid_offset = index * grid_sNHW;

T ix = grid[grid_offset];
T iy = grid[grid_offset + grid_sCoor];

ix = ComputePositions(ix, in_w, padding_mode, align_corners);
iy = ComputePositions(iy, in_h, padding_mode, align_corners);
if (mode == Mode::bilinear) {
int ix_nw = static_cast<int>(floor(ix));
int iy_nw = static_cast<int>(floor(iy));
int ix_ne = ix_nw + 1;
int iy_ne = iy_nw;
int ix_sw = ix_nw;
int iy_sw = iy_nw + 1;
int ix_se = ix_nw + 1;
int iy_se = iy_nw + 1;
IndexT ix_nw = floor(ix);
IndexT iy_nw = floor(iy);
IndexT ix_ne = ix_nw + 1;
IndexT iy_ne = iy_nw;
IndexT ix_sw = ix_nw;
IndexT iy_sw = iy_nw + 1;
IndexT ix_se = ix_nw + 1;
IndexT iy_se = iy_nw + 1;

T nw = (ix_se - ix) * (iy_se - iy);
T ne = (ix - ix_sw) * (iy_sw - iy);
T sw = (ix_ne - ix) * (iy - iy_ne);
T se = (ix - ix_nw) * (iy - iy_nw);

auto inp_offset_NC = n * inp_sN;
IndexT inp_offset_NC = n * inp_sN;
T* out_ptr_NCHW = output + (n * out_sN + hw * out_sHW);

auto out_ptr_NCHW = output + n * out_sN + h * out_sH + w * out_sW;
for (int c = 0; c < out_c;
for (IndexT c = 0; c < out_c;
++c, inp_offset_NC += inp_sC, out_ptr_NCHW += out_sC) {
*out_ptr_NCHW = static_cast<T>(0);
T value{0};
if (InBounds(iy_nw, ix_nw, in_h, in_w)) {
*out_ptr_NCHW +=
input[inp_offset_NC + iy_nw * inp_sH + ix_nw * inp_sW] * nw;
value += input[inp_offset_NC + iy_nw * inp_sH + ix_nw * inp_sW] * nw;
}
if (InBounds(iy_ne, ix_ne, in_h, in_w)) {
*out_ptr_NCHW +=
input[inp_offset_NC + iy_ne * inp_sH + ix_ne * inp_sW] * ne;
value += input[inp_offset_NC + iy_ne * inp_sH + ix_ne * inp_sW] * ne;
}
if (InBounds(iy_sw, ix_sw, in_h, in_w)) {
*out_ptr_NCHW +=
input[inp_offset_NC + iy_sw * inp_sH + ix_sw * inp_sW] * sw;
value += input[inp_offset_NC + iy_sw * inp_sH + ix_sw * inp_sW] * sw;
}
if (InBounds(iy_se, ix_se, in_h, in_w)) {
*out_ptr_NCHW +=
input[inp_offset_NC + iy_se * inp_sH + ix_se * inp_sW] * se;
value += input[inp_offset_NC + iy_se * inp_sH + ix_se * inp_sW] * se;
}
*out_ptr_NCHW = value;
}
} else if (mode == Mode::nearest) {
int ix_nearest = static_cast<int>(std::nearbyint(ix));
int iy_nearest = static_cast<int>(std::nearbyint(iy));
auto inp_offset_NC = n * inp_sN;
auto out_ptr_NCHW = output + n * out_sN + h * out_sH + w * out_sW;
for (int c = 0; c < out_c;
IndexT ix_nearest = std::nearbyint(ix);
IndexT iy_nearest = std::nearbyint(iy);
IndexT inp_offset_NC = n * inp_sN;
T* out_ptr_NCHW = output + (n * out_sN + hw * out_sHW);
for (IndexT c = 0; c < out_c;
++c, inp_offset_NC += inp_sC, out_ptr_NCHW += out_sC) {
if (InBounds(iy_nearest, ix_nearest, in_h, in_w)) {
*out_ptr_NCHW =
Expand Down Expand Up @@ -349,38 +340,54 @@ void GridSampleKernel(const Context& dev_ctx,
}

if (x.dims().size() == 4) {
const int n = grid.dims()[0];
const int out_h = grid.dims()[1];
const int out_w = grid.dims()[2];
const int c = x.dims()[1];
const int in_h = x.dims()[2];
const int in_w = x.dims()[3];
const int64_t n = grid.dims()[0];
const int64_t out_h = grid.dims()[1];
const int64_t out_w = grid.dims()[2];
const int64_t c = x.dims()[1];
const int64_t in_h = x.dims()[2];
const int64_t in_w = x.dims()[3];
VLOG(3) << "n: " << n << "; c: " << c << "; out_h: " << out_h
<< "; out_w: " << out_w;

auto* output_data = dev_ctx.template Alloc<T>(out);
VLOG(3) << "out dims: " << out->dims()[0] << "; " << out->dims()[1] << "; "
<< out->dims()[2] << "; " << out->dims()[3];

int count = static_cast<int>(n * out_h * out_w);
int64_t count = n * out_h * out_w;
auto cu_stream = dev_ctx.stream();
backends::gpu::GpuLaunchConfig config =
backends::gpu::GetGpuLaunchConfig1D(dev_ctx, count);
GridSampleCudaKernel<T>
<<<config.block_per_grid, config.thread_per_block, 0, cu_stream>>>(
count,
n,
c,
out_h,
out_w,
in_h,
in_w,
x.data<T>(),
grid.data<T>(),
output_data,
enum_mode,
enum_padding_mode,
align_corners);
if (x.numel() <= std::numeric_limits<int>::max() &&
grid.numel() <= std::numeric_limits<int>::max() &&
out->numel() <= std::numeric_limits<int>::max()) {
GridSampleCudaKernel<T, int>
<<<config.block_per_grid, config.thread_per_block, 0, cu_stream>>>(
n,
c,
out_h * out_w,
in_h,
in_w,
x.data<T>(),
grid.data<T>(),
output_data,
enum_mode,
enum_padding_mode,
align_corners);
} else {
GridSampleCudaKernel<T, int64_t>
<<<config.block_per_grid, config.thread_per_block, 0, cu_stream>>>(
n,
c,
out_h * out_w,
in_h,
in_w,
x.data<T>(),
grid.data<T>(),
output_data,
enum_mode,
enum_padding_mode,
align_corners);
}
} else {
const int n = grid.dims()[0];
const int out_d = grid.dims()[1];
Expand Down
Loading