Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 47 additions & 37 deletions paddle/phi/kernels/cpu/grid_sample_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -248,13 +248,15 @@ static void GatherOutputGradToInputGrad(const DenseTensor& output_grad,
for (int i = 0; i < n; i++) {
for (int k = 0; k < out_h; k++) {
for (int l = 0; l < out_w; l++) {
if (IsInBound(
x_t(i, k, l), y_t(i, k, l), (T)(in_w - 1), (T)(in_h - 1))) {
if (IsInBound<int>(static_cast<int>(x_t(i, k, l)),
static_cast<int>(y_t(i, k, l)),
(in_w - 1),
(in_h - 1))) {
for (int j = 0; j < c; j++) {
input_grad_t(i,
j,
static_cast<int>(round(y_t(i, k, l))),
static_cast<int>(round(x_t(i, k, l)))) +=
static_cast<int>(y_t(i, k, l)),
static_cast<int>(x_t(i, k, l))) +=
output_grad_t(i, j, k, l) * d1_t(i, k, l) * d2_t(i, k, l);
}
}
Expand Down Expand Up @@ -293,18 +295,18 @@ static void Gather3DOutputGradToInputGrad(const DenseTensor& output_grad,
for (int m = 0; m < out_d; m++) {
for (int k = 0; k < out_h; k++) {
for (int l = 0; l < out_w; l++) {
if (IsInBound3D(x_t(i, m, k, l),
y_t(i, m, k, l),
z_t(i, m, k, l),
(T)(in_w - 1),
(T)(in_h - 1),
(T)(in_d - 1))) {
if (IsInBound3D<int>(static_cast<int>(x_t(i, m, k, l)),
static_cast<int>(y_t(i, m, k, l)),
static_cast<int>(z_t(i, m, k, l)),
(in_w - 1),
(in_h - 1),
(in_d - 1))) {
for (int j = 0; j < c; j++) {
input_grad_t(i,
j,
static_cast<int>(round(z_t(i, m, k, l))),
static_cast<int>(round(y_t(i, m, k, l))),
static_cast<int>(round(x_t(i, m, k, l)))) +=
static_cast<int>(z_t(i, m, k, l)),
static_cast<int>(y_t(i, m, k, l)),
static_cast<int>(x_t(i, m, k, l))) +=
output_grad_t(i, j, m, k, l) * d1_t(i, m, k, l) *
d2_t(i, m, k, l) * d3_t(i, m, k, l);
}
Expand Down Expand Up @@ -590,13 +592,15 @@ static void GatherOutputGradToInputGrad(const DenseTensor& output_grad,
for (int i = 0; i < n; i++) {
for (int k = 0; k < out_h; k++) {
for (int l = 0; l < out_w; l++) {
if (IsInBound(
x_t(i, k, l), y_t(i, k, l), (T)(in_w - 1), (T)(in_h - 1))) {
if (IsInBound<int>(static_cast<int>(std::nearbyint(x_t(i, k, l))),
static_cast<int>(std::nearbyint(y_t(i, k, l))),
(in_w - 1),
(in_h - 1))) {
for (int j = 0; j < c; j++) {
input_grad_t(i,
j,
static_cast<int>(round(y_t(i, k, l))),
static_cast<int>(round(x_t(i, k, l)))) +=
static_cast<int>(std::nearbyint(y_t(i, k, l))),
static_cast<int>(std::nearbyint(x_t(i, k, l)))) +=
output_grad_t(i, j, k, l);
}
}
Expand Down Expand Up @@ -628,18 +632,19 @@ static void Gather3DOutputGradToInputGrad(const DenseTensor& output_grad,
for (int m = 0; m < out_d; m++) {
for (int k = 0; k < out_h; k++) {
for (int l = 0; l < out_w; l++) {
if (IsInBound3D(x_t(i, m, k, l),
y_t(i, m, k, l),
z_t(i, m, k, l),
(T)(in_w - 1),
(T)(in_h - 1),
(T)(in_d - 1))) {
if (IsInBound3D<int>(
static_cast<int>(std::nearbyint(x_t(i, m, k, l))),
static_cast<int>(std::nearbyint(y_t(i, m, k, l))),
static_cast<int>(std::nearbyint(z_t(i, m, k, l))),
(in_w - 1),
(in_h - 1),
(in_d - 1))) {
for (int j = 0; j < c; j++) {
input_grad_t(i,
j,
static_cast<int>(round(z_t(i, m, k, l))),
static_cast<int>(round(y_t(i, m, k, l))),
static_cast<int>(round(x_t(i, m, k, l)))) +=
static_cast<int>(std::nearbyint(z_t(i, m, k, l))),
static_cast<int>(std::nearbyint(y_t(i, m, k, l))),
static_cast<int>(std::nearbyint(x_t(i, m, k, l)))) +=
output_grad_t(i, j, m, k, l);
}
}
Expand Down Expand Up @@ -673,6 +678,13 @@ void GridSampleGradKernel(const Context& dev_ctx,
return;
}

std::string enum_mode;
if (mode == "nearest") {
enum_mode = "nearest";
} else {
enum_mode = "bilinear";
}

if (x.dims().size() == 4) {
const int n = static_cast<int>(grid.dims()[0]);
const int out_h = static_cast<int>(grid.dims()[1]);
Expand Down Expand Up @@ -704,7 +716,10 @@ void GridSampleGradKernel(const Context& dev_ctx,
&grid_y,
&grid_x_scale,
&grid_y_scale);
if (mode == "bilinear") {
if (enum_mode == "nearest") {
GatherOutputGradToInputGrad<T>(out_grad, x_grad, grid_x, grid_y);

} else if (enum_mode == "bilinear") {
GatherBilinearGrad<T>(dev_ctx,
x,
out_grad,
Expand All @@ -714,12 +729,6 @@ void GridSampleGradKernel(const Context& dev_ctx,
&grid_y_scale,
x_grad,
grid_grad);
} else {
auto grid_x_t = EigenTensor<T, 3>::From(grid_x);
auto grid_y_t = EigenTensor<T, 3>::From(grid_y);
grid_x_t = grid_x_t.round();
grid_y_t = grid_y_t.round();
GatherOutputGradToInputGrad<T>(out_grad, x_grad, grid_x, grid_y);
}
} else {
const int n = static_cast<int>(grid.dims()[0]);
Expand Down Expand Up @@ -757,7 +766,11 @@ void GridSampleGradKernel(const Context& dev_ctx,
&grid_x_scale,
&grid_y_scale,
&grid_z_scale);
if (mode == "bilinear") {
if (enum_mode == "nearest") {
Gather3DOutputGradToInputGrad<T>(
out_grad, x_grad, grid_x, grid_y, grid_z);

} else if (enum_mode == "bilinear") {
Gather3DBilinearGrad<T>(dev_ctx,
x,
out_grad,
Expand All @@ -769,9 +782,6 @@ void GridSampleGradKernel(const Context& dev_ctx,
&grid_z_scale,
x_grad,
grid_grad);
} else {
Gather3DOutputGradToInputGrad<T>(
out_grad, x_grad, grid_x, grid_y, grid_z);
}
}
}
Expand Down
24 changes: 14 additions & 10 deletions paddle/phi/kernels/cpu/grid_sample_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,14 @@ void GridSampleKernel(const Context& dev_ctx,
dev_ctx.template Alloc<T>(out);
return;
}

std::string enum_mode;
if (mode == "nearest") {
enum_mode = "nearest";
} else {
enum_mode = "bilinear";
}

if (x.dims().size() == 4) {
const int n = static_cast<int>(grid.dims()[0]);
const int out_h = static_cast<int>(grid.dims()[1]);
Expand All @@ -338,14 +346,10 @@ void GridSampleKernel(const Context& dev_ctx,
&grid_x,
&grid_y);

if (mode == "bilinear") {
if (enum_mode == "bilinear") {
BilinearInter<T>(dev_ctx, x, &grid_x, &grid_y, out);
} else if (mode == "nearest") {
auto grid_x_t = EigenTensor<T, 3>::From(grid_x);
auto grid_y_t = EigenTensor<T, 3>::From(grid_y);
grid_x_t = grid_x_t.round();
grid_y_t = grid_y_t.round();
GetGridPointValue<T>(x, out, grid_x, grid_y);
} else if (enum_mode == "nearest") {
GetGridPointValue_nearest<T>(x, out, grid_x, grid_y);
}
} else {
const int n = static_cast<int>(grid.dims()[0]);
Expand All @@ -372,10 +376,10 @@ void GridSampleKernel(const Context& dev_ctx,
&grid_x,
&grid_y,
&grid_z);
if (mode == "bilinear") {
if (enum_mode == "bilinear") {
Bilinear3DInter<T>(dev_ctx, x, &grid_x, &grid_y, &grid_z, out);
} else if (mode == "nearest") {
Get3DGridPointValue<T>(x, out, grid_x, grid_y, grid_z);
} else if (enum_mode == "nearest") {
Get3DGridPointValue_nearest<T>(x, out, grid_x, grid_y, grid_z);
}
}
}
Expand Down
118 changes: 101 additions & 17 deletions paddle/phi/kernels/cpu/grid_sample_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@ void Unnormalize(const CPUContext& dev_ctx,
auto& place = *dev_ctx.eigen_device();
auto grid_slice_t = EigenTensor<T, 3>::From(*grid_slice);

if (!align_corners) {
if (align_corners) {
auto factor = static_cast<T>(max_val * 0.5);
grid_slice_t.device(place) = (grid_slice_t + static_cast<T>(1)) * factor;
} else {
auto factor = static_cast<T>((max_val + 1) * 0.5);
grid_slice_t.device(place) =
(grid_slice_t + static_cast<T>(1)) * factor - static_cast<T>(0.5);
} else {
auto factor = static_cast<T>(max_val * 0.5);
grid_slice_t.device(place) = (grid_slice_t + static_cast<T>(1)) * factor;
}
}

Expand Down Expand Up @@ -89,14 +89,51 @@ void GetGridPointValue(const DenseTensor& input,
for (int i = 0; i < n; i++) {
for (int k = 0; k < out_h; k++) {
for (int l = 0; l < out_w; l++) {
if (IsInBound(
x_t(i, k, l), y_t(i, k, l), (T)(in_w - 1), (T)(in_h - 1))) {
if (IsInBound<int>(static_cast<int>(x_t(i, k, l)),
static_cast<int>(y_t(i, k, l)),
(in_w - 1),
(in_h - 1))) {
for (int j = 0; j < c; j++) {
output_t(i, j, k, l) = input_t(i,
j,
static_cast<int>(y_t(i, k, l)),
static_cast<int>(x_t(i, k, l)));
}
}
}
}
}
}

template <typename T>
void GetGridPointValue_nearest(const DenseTensor& input,
DenseTensor* output,
const DenseTensor& x,
const DenseTensor& y) {
const int n = input.dims()[0];
const int c = input.dims()[1];
const int in_h = input.dims()[2];
const int in_w = input.dims()[3];
const int out_h = x.dims()[1];
const int out_w = x.dims()[2];
auto x_t = EigenTensor<T, 3>::From(x);
auto y_t = EigenTensor<T, 3>::From(y);
auto output_t = EigenTensor<T, 4>::From(*output).setConstant((T)0);
auto input_t = EigenTensor<T, 4>::From(input);

for (int i = 0; i < n; i++) {
for (int k = 0; k < out_h; k++) {
for (int l = 0; l < out_w; l++) {
if (IsInBound<int>(static_cast<int>(std::nearbyint(x_t(i, k, l))),
static_cast<int>(std::nearbyint(y_t(i, k, l))),
(in_w - 1),
(in_h - 1))) {
for (int j = 0; j < c; j++) {
output_t(i, j, k, l) =
input_t(i,
j,
static_cast<int>(round(y_t(i, k, l))),
static_cast<int>(round(x_t(i, k, l))));
static_cast<int>(std::nearbyint(y_t(i, k, l))),
static_cast<int>(std::nearbyint(x_t(i, k, l))));
}
}
}
Expand Down Expand Up @@ -207,19 +244,66 @@ void Get3DGridPointValue(const DenseTensor& input,
for (int m = 0; m < out_d; m++) {
for (int k = 0; k < out_h; k++) {
for (int l = 0; l < out_w; l++) {
if (IsInBound3D(x_t(i, m, k, l),
y_t(i, m, k, l),
z_t(i, m, k, l),
(T)(in_w - 1),
(T)(in_h - 1),
(T)(in_d - 1))) {
if (IsInBound3D<int>(static_cast<int>(x_t(i, m, k, l)),
static_cast<int>(y_t(i, m, k, l)),
static_cast<int>(z_t(i, m, k, l)),
(in_w - 1),
(in_h - 1),
(in_d - 1))) {
for (int j = 0; j < c; j++) {
output_t(i, j, m, k, l) =
input_t(i,
j,
static_cast<int>(z_t(i, m, k, l)),
static_cast<int>(y_t(i, m, k, l)),
static_cast<int>(x_t(i, m, k, l)));
}
}
}
}
}
}
}

template <typename T>
void Get3DGridPointValue_nearest(const DenseTensor& input,
DenseTensor* output,
const DenseTensor& x,
const DenseTensor& y,
const DenseTensor& z) {
const int n = input.dims()[0];
const int c = input.dims()[1];
const int in_d = input.dims()[2];
const int in_h = input.dims()[3];
const int in_w = input.dims()[4];
const int out_d = x.dims()[1];
const int out_h = x.dims()[2];
const int out_w = x.dims()[3];
auto x_t = EigenTensor<T, 4>::From(x);
auto y_t = EigenTensor<T, 4>::From(y);
auto z_t = EigenTensor<T, 4>::From(z);
auto output_t =
EigenTensor<T, 5>::From(*output).setConstant(static_cast<T>(0.0));
auto input_t = EigenTensor<T, 5>::From(input);

for (int i = 0; i < n; i++) {
for (int m = 0; m < out_d; m++) {
for (int k = 0; k < out_h; k++) {
for (int l = 0; l < out_w; l++) {
if (IsInBound3D<int>(
static_cast<int>(std::nearbyint(x_t(i, m, k, l))),
static_cast<int>(std::nearbyint(y_t(i, m, k, l))),
static_cast<int>(std::nearbyint(z_t(i, m, k, l))),
(in_w - 1),
(in_h - 1),
(in_d - 1))) {
for (int j = 0; j < c; j++) {
output_t(i, j, m, k, l) =
input_t(i,
j,
static_cast<int>(round(z_t(i, m, k, l))),
static_cast<int>(round(y_t(i, m, k, l))),
static_cast<int>(round(x_t(i, m, k, l))));
static_cast<int>(std::nearbyint(z_t(i, m, k, l))),
static_cast<int>(std::nearbyint(y_t(i, m, k, l))),
static_cast<int>(std::nearbyint(x_t(i, m, k, l))));
}
}
}
Expand Down
6 changes: 3 additions & 3 deletions paddle/phi/kernels/gpu/grid_sample_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -293,9 +293,9 @@ __global__ void GridSample3DCudaKernel(const IndexT nthreads,
}
}
} else if (interpolation_mode == Mode::nearest) {
IndexT ix_nearest = static_cast<IndexT>(std::round(ix));
IndexT iy_nearest = static_cast<IndexT>(std::round(iy));
IndexT iz_nearest = static_cast<IndexT>(std::round(iz));
IndexT ix_nearest = static_cast<IndexT>(std::nearbyint(ix));
IndexT iy_nearest = static_cast<IndexT>(std::nearbyint(iy));
IndexT iz_nearest = static_cast<IndexT>(std::nearbyint(iz));

// assign nearest neighbor pixel value to output pixel
const T* inp_ptr_NC = input + n * inp_sN;
Expand Down
15 changes: 14 additions & 1 deletion test/legacy_test/test_grid_sampler_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,16 +379,29 @@ def setUp(self):
}

def test_check_output(self):
self.check_output_with_place(core.CPUPlace(), check_pir=True)
if core.is_compiled_with_cuda():
self.check_output_with_place(core.CUDAPlace(0), check_pir=True)
self.check_output(check_pir=True)

def test_check_grad_normal(self):
self.check_grad(
self.check_grad_with_place(
core.CPUPlace(),
['X', 'Grid'],
'Output',
max_relative_error=0.01,
numeric_grad_delta=self.numeric_grad_delta,
check_pir=True,
)
if core.is_compiled_with_cuda():
self.check_grad_with_place(
core.CUDAPlace(0),
['X', 'Grid'],
'Output',
max_relative_error=0.01,
numeric_grad_delta=self.numeric_grad_delta,
check_pir=True,
)

def initTestCase(self):
self.x_shape = (2, 3, 8, 8)
Expand Down