diff --git a/backends/metax_gpu/CMakeLists.txt b/backends/metax_gpu/CMakeLists.txt
index b98f2bcc919..bca1ce7aad4 100755
--- a/backends/metax_gpu/CMakeLists.txt
+++ b/backends/metax_gpu/CMakeLists.txt
@@ -310,8 +310,6 @@ file(
   ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/hinge_loss_grad_kernel.cu
   ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/hinge_loss_kernel.cu
   ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/gru_grad_kernel.cu
-  ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/grid_sample_grad_kernel.cu
-  ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/grid_sample_kernel.cu
   ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/generate_proposals_kernel.cu
   ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/gaussian_inplace_grad_kernel.cu
   ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/gammaln_kernel.cu
diff --git a/backends/metax_gpu/kernels/cuda_kernels/grid_sample_grad_kernel_register.cu b/backends/metax_gpu/kernels/cuda_kernels/grid_sample_grad_kernel_register.cu
deleted file mode 100644
index 83c47dc86db..00000000000
--- a/backends/metax_gpu/kernels/cuda_kernels/grid_sample_grad_kernel_register.cu
+++ /dev/null
@@ -1,23 +0,0 @@
-// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/kernels/grid_sample_grad_kernel.h"
-
-PD_CUSTOM_KERNEL_REGISTER(grid_sample_grad,
-                          metax_gpu,
-                          ALL_LAYOUT,
-                          phi::GridSampleGradKernel,
-                          float,
-                          double) {}
diff --git a/backends/metax_gpu/kernels/cuda_kernels/grid_sample_kernel_register.cu b/backends/metax_gpu/kernels/cuda_kernels/grid_sample_kernel_register.cu
deleted file mode 100644
index a0447405971..00000000000
--- a/backends/metax_gpu/kernels/cuda_kernels/grid_sample_kernel_register.cu
+++ /dev/null
@@ -1,19 +0,0 @@
-// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/kernels/grid_sample_kernel.h"
-
-PD_CUSTOM_KERNEL_REGISTER(
-    grid_sample, metax_gpu, ALL_LAYOUT, phi::GridSampleKernel, float, double) {}
diff --git a/backends/metax_gpu/kernels/metax_kernel/grid_sample_grad_kernel_register.cu b/backends/metax_gpu/kernels/metax_kernel/grid_sample_grad_kernel_register.cu
new file mode 100644
index 00000000000..8aae95bdb22
--- /dev/null
+++ b/backends/metax_gpu/kernels/metax_kernel/grid_sample_grad_kernel_register.cu
@@ -0,0 +1,839 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "kernels/metax_kernel/metax_context.h"
+#include "paddle/phi/backends/dynload/cudnn.h"
+#include "paddle/phi/backends/gpu/gpu_device_function.h"
+#include "paddle/phi/backends/gpu/gpu_info.h"
+#include "paddle/phi/backends/gpu/gpu_launch_config.h"
+#include "paddle/phi/backends/gpu/gpu_primitives.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/full_kernel.h"
+#include "paddle/phi/kernels/funcs/math_function.h"
+#include "paddle/phi/kernels/gpu/grid_sample_utils.h"
+#include "paddle/phi/kernels/grid_sample_grad_kernel.h"
+
+namespace phi {
+
+template <typename T, typename IndexT>
+static __forceinline__ __device__ void AtomicAdd(T* data,
+                                                 IndexT h,
+                                                 IndexT w,
+                                                 IndexT sH,
+                                                 IndexT sW,
+                                                 IndexT H,
+                                                 IndexT W,
+                                                 T delta) {
+  if (InBounds(h, w, H, W)) {
+    phi::CudaAtomicAdd(data + h * sH + w * sW, delta);
+  }
+}
+
+template <typename T, typename IndexT>
+static __forceinline__ __device__ void AtomicAdd3D(T* data,
+                                                   IndexT d,
+                                                   IndexT h,
+                                                   IndexT w,
+                                                   IndexT sD,
+                                                   IndexT sH,
+                                                   IndexT sW,
+                                                   IndexT D,
+                                                   IndexT H,
+                                                   IndexT W,
+                                                   T delta) {
+  if (InBounds3D(d, h, w, D, H, W)) {
+    phi::CudaAtomicAdd(data + d * sD + h * sH + w * sW, delta);
+  }
+}
+
+template <typename T, typename IndexT>
+static __forceinline__ __device__ T
+UnnormalizeWithMask(T coord, IndexT size, bool align_corners, T* grad_in) {
+  if (align_corners) {
+    *grad_in = static_cast<T>(size - 1) / 2;
+    return ((coord + 1.f) / 2) * (size - 1);
+  } else {
+    *grad_in = static_cast<T>(size) / 2;
+    return ((coord + 1.f) * size - 1) / 2;
+  }
+}
+
+template <typename T, typename IndexT>
+static __forceinline__ __device__ T ClipIndexesWithMask(T in,
+                                                        IndexT clip_limit,
+                                                        T* grad_in) {
+  if (in <= static_cast<T>(0)) {
+    *grad_in = static_cast<T>(0);
+    return static_cast<T>(0);
+  } else {
+    T max = static_cast<T>(clip_limit - 1);
+    if (in >= max) {
+      *grad_in = static_cast<T>(0);
+      return max;
+    } else {
+      *grad_in = static_cast<T>(1);
+      return in;
+    }
+  }
+}
+
+template <typename T, typename IndexT>
+static __forceinline__ __device__ T
+ReflectIndexesWithMask(T in, IndexT twice_low, IndexT twice_high, T* grad_in) {
+  if (twice_low == twice_high) {
+    *grad_in = static_cast<T>(0);
+    return static_cast<T>(0);
+  }
+  IndexT grad_in_mult_;
+  T min = static_cast<T>(twice_low) / 2;
+  T span = static_cast<T>(twice_high - twice_low) / 2;
+  in = in - min;
+  if (in < static_cast<T>(0)) {
+    grad_in_mult_ = -1;
+    in = -in;
+  } else {
+    grad_in_mult_ = 1;
+  }
+  T extra = fmod(in, span);
+  IndexT flips = static_cast<IndexT>(floor(in / span));
+  if (flips % 2 == 0) {
+    *grad_in = static_cast<T>(grad_in_mult_);
+    return extra + min;
+  } else {
+    *grad_in = static_cast<T>(-grad_in_mult_);
+    return span - extra + min;
+  }
+}
+
+template <typename T, typename IndexT>
+static __forceinline__ __device__ T
+ComputePositionsWithMask(T coord,
+                         IndexT size,
+                         PaddingMode padding_mode,
+                         bool align_corners,
+                         T* grad_in) {
+  T grad_clip, grad_refl;
+  coord = UnnormalizeWithMask(coord, size, align_corners, grad_in);
+  if (padding_mode == PaddingMode::border) {
+    coord = ClipIndexesWithMask(coord, size, &grad_clip);
+    *grad_in = (*grad_in) * grad_clip;
+  } else if (padding_mode ==
PaddingMode::reflect) { + coord = align_corners ? ReflectIndexesWithMask( + coord, 0, 2 * (size - 1), &grad_refl) + : ReflectIndexesWithMask( + coord, -1, 2 * size - 1, &grad_refl); + coord = ClipIndexesWithMask(coord, size, &grad_clip); + *grad_in = (*grad_in) * grad_refl * grad_clip; + } + return SafeDownGradeToIntRange(coord); +} + +template +__global__ void GridSamplerCudaBackwardKernel(const IndexT nthreads, + const T* grad_output, + const T* input, + const T* grid, + IndexT n, + IndexT out_c, + IndexT out_h, + IndexT out_w, + IndexT in_h, + IndexT in_w, + T* grad_input, + T* grad_grid, + const Mode mode, + const PaddingMode padding_mode, + bool align_corners) { + IndexT inp_sN = out_c * in_h * in_w; + IndexT inp_sC = in_h * in_w; + IndexT inp_sH = in_w; + IndexT inp_sW = 1; + IndexT grid_sN = out_h * out_w * 2; + IndexT grid_sH = out_w * 2; + IndexT grid_sW = 2; + IndexT grid_sCoor = 1; + + IndexT gOut_sN = out_c * out_h * out_w; + IndexT gOut_sC = out_h * out_w; + IndexT gOut_sH = out_w; + IndexT gOut_sW = 1; + + CUDA_KERNEL_LOOP(index, nthreads) { + const IndexT w = index % out_w; + const IndexT h = (index / out_w) % out_h; + const IndexT n = index / (out_h * out_w); + const IndexT grid_offset = n * grid_sN + h * grid_sH + w * grid_sW; + + T ix = grid[grid_offset]; + T iy = grid[grid_offset + grid_sCoor]; + + T gix_mult, giy_mult; + ix = ComputePositionsWithMask( + ix, in_w, padding_mode, align_corners, &gix_mult); + iy = ComputePositionsWithMask( + iy, in_h, padding_mode, align_corners, &giy_mult); + + if (mode == Mode::bilinear) { + IndexT ix_nw = static_cast(floor(ix)); + IndexT iy_nw = static_cast(floor(iy)); + IndexT ix_ne = ix_nw + 1; + IndexT iy_ne = iy_nw; + IndexT ix_sw = ix_nw; + IndexT iy_sw = iy_nw + 1; + IndexT ix_se = ix_nw + 1; + IndexT iy_se = iy_nw + 1; + + T nw = (ix_se - ix) * (iy_se - iy); + T ne = (ix - ix_sw) * (iy_sw - iy); + T sw = (ix_ne - ix) * (iy - iy_ne); + T se = (ix - ix_nw) * (iy - iy_nw); + + T gix = static_cast(0), giy = static_cast(0); + IndexT gOut_offset = n * gOut_sN + h * gOut_sH + w * gOut_sW; + T* gInp_ptr_NC = grad_input + n * inp_sN; + IndexT inp_offset_NC = n * inp_sN; + for (IndexT c = 0; c < out_c; ++c, + inp_offset_NC += inp_sC, + gInp_ptr_NC += inp_sC, + gOut_offset += gOut_sC) { + T gOut = grad_output[gOut_offset]; + + AtomicAdd( + gInp_ptr_NC, iy_nw, ix_nw, inp_sH, inp_sW, in_h, in_w, nw * gOut); + AtomicAdd( + gInp_ptr_NC, iy_ne, ix_ne, inp_sH, inp_sW, in_h, in_w, ne * gOut); + AtomicAdd( + gInp_ptr_NC, iy_sw, ix_sw, inp_sH, inp_sW, in_h, in_w, sw * gOut); + AtomicAdd( + gInp_ptr_NC, iy_se, ix_se, inp_sH, inp_sW, in_h, in_w, se * gOut); + + if (InBounds(iy_nw, ix_nw, in_h, in_w)) { + T nw_val = input[inp_offset_NC + iy_nw * inp_sH + ix_nw * inp_sW]; + gix -= nw_val * (iy_se - iy) * gOut; + giy -= nw_val * (ix_se - ix) * gOut; + } + if (InBounds(iy_ne, ix_ne, in_h, in_w)) { + T ne_val = input[inp_offset_NC + iy_ne * inp_sH + ix_ne * inp_sW]; + gix += ne_val * (iy_sw - iy) * gOut; + giy -= ne_val * (ix - ix_sw) * gOut; + } + if (InBounds(iy_sw, ix_sw, in_h, in_w)) { + T sw_val = input[inp_offset_NC + iy_sw * inp_sH + ix_sw * inp_sW]; + gix -= sw_val * (iy - iy_ne) * gOut; + giy += sw_val * (ix_ne - ix) * gOut; + } + if (InBounds(iy_se, ix_se, in_h, in_w)) { + T se_val = input[inp_offset_NC + iy_se * inp_sH + ix_se * inp_sW]; + gix += se_val * (iy - iy_nw) * gOut; + giy += se_val * (ix - ix_nw) * gOut; + } + } + + if (grad_grid != nullptr) { + T* gGrid_ptr_NHW = grad_grid + index * grid_sW; + gGrid_ptr_NHW[0] = gix_mult * gix; + 
gGrid_ptr_NHW[1] = giy_mult * giy; + } + } else if (mode == Mode::nearest) { + IndexT ix_nearest = static_cast(std::nearbyint(ix)); + IndexT iy_nearest = static_cast(std::nearbyint(iy)); + + IndexT gOut_offset = n * gOut_sN + h * gOut_sH + w * gOut_sW; + T* gInp_ptr_NC = grad_input + n * inp_sN; + for (IndexT c = 0; c < out_c; + ++c, gInp_ptr_NC += inp_sC, gOut_offset += gOut_sC) { + AtomicAdd(gInp_ptr_NC, + iy_nearest, + ix_nearest, + inp_sH, + inp_sW, + in_h, + in_w, + grad_output[gOut_offset]); + } + + if (grad_grid != nullptr) { + T* gGrid_ptr_NHW = grad_grid + index * grid_sW; + gGrid_ptr_NHW[0] = static_cast(0); + gGrid_ptr_NHW[1] = static_cast(0); + } + } + } +} + +template +__global__ void GridSampler3DCudaBackwardKernel(const IndexT nthreads, + const T* grad_output, + const T* input, + const T* grid, + IndexT out_c, + IndexT out_d, + IndexT out_h, + IndexT out_w, + IndexT in_d, + IndexT in_h, + IndexT in_w, + T* grad_input, + T* grad_grid, + const Mode mode, + const PaddingMode padding_mode, + bool align_corners) { + IndexT inp_sW = 1; + IndexT inp_sH = in_w; + IndexT inp_sD = in_h * in_w; + IndexT inp_sC = in_d * inp_sD; + IndexT inp_sN = out_c * inp_sC; + + IndexT grid_sCoor = 1; + IndexT grid_sW = 3; + IndexT grid_sH = out_w * grid_sW; + IndexT grid_sD = out_h * grid_sH; + IndexT grid_sN = out_d * grid_sD; + + IndexT gOut_sW = 1; + IndexT gOut_sH = out_w; + IndexT gOut_sD = out_h * out_w; + IndexT gOut_sC = out_d * gOut_sD; + IndexT gOut_sN = out_c * gOut_sC; + + CUDA_KERNEL_LOOP_TYPE(index, nthreads, IndexT) { + const IndexT w = index % out_w; + const IndexT h = (index / out_w) % out_h; + const IndexT d = (index / (out_h * out_w)) % out_d; + const IndexT n = index / (out_d * out_h * out_w); + const auto grid_offset = + n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW; + + // get the corresponding input x, y, z coordinates from grid + T ix = grid[grid_offset]; + T iy = grid[grid_offset + grid_sCoor]; + T iz = grid[grid_offset + 2 * grid_sCoor]; + + // multipliers for gradients on ix, iy, and iz + T gix_mult, giy_mult, giz_mult; + ix = ComputePositionsWithMask( + ix, in_w, padding_mode, align_corners, &gix_mult); + iy = ComputePositionsWithMask( + iy, in_h, padding_mode, align_corners, &giy_mult); + iz = ComputePositionsWithMask( + iz, in_d, padding_mode, align_corners, &giz_mult); + + if (mode == Mode::bilinear) { + // get corner pixel values from (x, y, z) + // for 4d, we used north-east-south-west + // for 5d, we add top-bottom + IndexT ix_tnw = static_cast(std::floor(ix)); + IndexT iy_tnw = static_cast(std::floor(iy)); + IndexT iz_tnw = static_cast(std::floor(iz)); + + IndexT ix_tne = ix_tnw + 1; + IndexT iy_tne = iy_tnw; + IndexT iz_tne = iz_tnw; + + IndexT ix_tsw = ix_tnw; + IndexT iy_tsw = iy_tnw + 1; + IndexT iz_tsw = iz_tnw; + + IndexT ix_tse = ix_tnw + 1; + IndexT iy_tse = iy_tnw + 1; + IndexT iz_tse = iz_tnw; + + IndexT ix_bnw = ix_tnw; + IndexT iy_bnw = iy_tnw; + IndexT iz_bnw = iz_tnw + 1; + + IndexT ix_bne = ix_tnw + 1; + IndexT iy_bne = iy_tnw; + IndexT iz_bne = iz_tnw + 1; + + IndexT ix_bsw = ix_tnw; + IndexT iy_bsw = iy_tnw + 1; + IndexT iz_bsw = iz_tnw + 1; + + IndexT ix_bse = ix_tnw + 1; + IndexT iy_bse = iy_tnw + 1; + IndexT iz_bse = iz_tnw + 1; + + // get surfaces to each neighbor: + T tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz); + T tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz); + T tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz); + T tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz); + T bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - 
iz_tse); + T bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw); + T bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne); + T bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw); + + T gix = static_cast(0), giy = static_cast(0), + giz = static_cast(0); + IndexT gOut_offset = + n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW; + IndexT inp_offset_NC = n * inp_sN; + T* gInp_ptr_NC = grad_input + n * inp_sN; + for (IndexT c = 0; c < out_c; ++c, + gOut_offset += gOut_sC, + gInp_ptr_NC += inp_sC, + inp_offset_NC += inp_sC) { + T gOut = grad_output[gOut_offset]; + + AtomicAdd3D(gInp_ptr_NC, + iz_tnw, + iy_tnw, + ix_tnw, + inp_sD, + inp_sH, + inp_sW, + in_d, + in_h, + in_w, + tnw * gOut); + AtomicAdd3D(gInp_ptr_NC, + iz_tne, + iy_tne, + ix_tne, + inp_sD, + inp_sH, + inp_sW, + in_d, + in_h, + in_w, + tne * gOut); + AtomicAdd3D(gInp_ptr_NC, + iz_tsw, + iy_tsw, + ix_tsw, + inp_sD, + inp_sH, + inp_sW, + in_d, + in_h, + in_w, + tsw * gOut); + AtomicAdd3D(gInp_ptr_NC, + iz_tse, + iy_tse, + ix_tse, + inp_sD, + inp_sH, + inp_sW, + in_d, + in_h, + in_w, + tse * gOut); + AtomicAdd3D(gInp_ptr_NC, + iz_bnw, + iy_bnw, + ix_bnw, + inp_sD, + inp_sH, + inp_sW, + in_d, + in_h, + in_w, + bnw * gOut); + AtomicAdd3D(gInp_ptr_NC, + iz_bne, + iy_bne, + ix_bne, + inp_sD, + inp_sH, + inp_sW, + in_d, + in_h, + in_w, + bne * gOut); + AtomicAdd3D(gInp_ptr_NC, + iz_bsw, + iy_bsw, + ix_bsw, + inp_sD, + inp_sH, + inp_sW, + in_d, + in_h, + in_w, + bsw * gOut); + AtomicAdd3D(gInp_ptr_NC, + iz_bse, + iy_bse, + ix_bse, + inp_sD, + inp_sH, + inp_sW, + in_d, + in_h, + in_w, + bse * gOut); + + // calculate grad_grid + if (InBounds3D(iz_tnw, iy_tnw, ix_tnw, in_d, in_h, in_w)) { + T tnw_val = input[inp_offset_NC + iz_tnw * inp_sD + iy_tnw * inp_sH + + ix_tnw * inp_sW]; + gix -= tnw_val * (iy_bse - iy) * (iz_bse - iz) * gOut; + giy -= tnw_val * (ix_bse - ix) * (iz_bse - iz) * gOut; + giz -= tnw_val * (ix_bse - ix) * (iy_bse - iy) * gOut; + } + if (InBounds3D(iz_tne, iy_tne, ix_tne, in_d, in_h, in_w)) { + T tne_val = input[inp_offset_NC + iz_tne * inp_sD + iy_tne * inp_sH + + ix_tne * inp_sW]; + gix += tne_val * (iy_bsw - iy) * (iz_bsw - iz) * gOut; + giy -= tne_val * (ix - ix_bsw) * (iz_bsw - iz) * gOut; + giz -= tne_val * (ix - ix_bsw) * (iy_bsw - iy) * gOut; + } + if (InBounds3D(iz_tsw, iy_tsw, ix_tsw, in_d, in_h, in_w)) { + T tsw_val = input[inp_offset_NC + iz_tsw * inp_sD + iy_tsw * inp_sH + + ix_tsw * inp_sW]; + gix -= tsw_val * (iy - iy_bne) * (iz_bne - iz) * gOut; + giy += tsw_val * (ix_bne - ix) * (iz_bne - iz) * gOut; + giz -= tsw_val * (ix_bne - ix) * (iy - iy_bne) * gOut; + } + if (InBounds3D(iz_tse, iy_tse, ix_tse, in_d, in_h, in_w)) { + T tse_val = input[inp_offset_NC + iz_tse * inp_sD + iy_tse * inp_sH + + ix_tse * inp_sW]; + gix += tse_val * (iy - iy_bnw) * (iz_bnw - iz) * gOut; + giy += tse_val * (ix - ix_bnw) * (iz_bnw - iz) * gOut; + giz -= tse_val * (ix - ix_bnw) * (iy - iy_bnw) * gOut; + } + if (InBounds3D(iz_bnw, iy_bnw, ix_bnw, in_d, in_h, in_w)) { + T bnw_val = input[inp_offset_NC + iz_bnw * inp_sD + iy_bnw * inp_sH + + ix_bnw * inp_sW]; + gix -= bnw_val * (iy_tse - iy) * (iz - iz_tse) * gOut; + giy -= bnw_val * (ix_tse - ix) * (iz - iz_tse) * gOut; + giz += bnw_val * (ix_tse - ix) * (iy_tse - iy) * gOut; + } + if (InBounds3D(iz_bne, iy_bne, ix_bne, in_d, in_h, in_w)) { + T bne_val = input[inp_offset_NC + iz_bne * inp_sD + iy_bne * inp_sH + + ix_bne * inp_sW]; + gix += bne_val * (iy_tsw - iy) * (iz - iz_tsw) * gOut; + giy -= bne_val * (ix - ix_tsw) * (iz - iz_tsw) * gOut; + giz += bne_val * (ix - ix_tsw) * 
(iy_tsw - iy) * gOut; + } + if (InBounds3D(iz_bsw, iy_bsw, ix_bsw, in_d, in_h, in_w)) { + T bsw_val = input[inp_offset_NC + iz_bsw * inp_sD + iy_bsw * inp_sH + + ix_bsw * inp_sW]; + gix -= bsw_val * (iy - iy_tne) * (iz - iz_tne) * gOut; + giy += bsw_val * (ix_tne - ix) * (iz - iz_tne) * gOut; + giz += bsw_val * (ix_tne - ix) * (iy - iy_tne) * gOut; + } + if (InBounds3D(iz_bse, iy_bse, ix_bse, in_d, in_h, in_w)) { + T bse_val = input[inp_offset_NC + iz_bse * inp_sD + iy_bse * inp_sH + + ix_bse * inp_sW]; + gix += bse_val * (iy - iy_tnw) * (iz - iz_tnw) * gOut; + giy += bse_val * (ix - ix_tnw) * (iz - iz_tnw) * gOut; + giz += bse_val * (ix - ix_tnw) * (iy - iy_tnw) * gOut; + } + } + if (grad_grid != nullptr) { + T* gGrid_ptr_NDHW = grad_grid + index * grid_sW; + gGrid_ptr_NDHW[0] = gix_mult * gix; + gGrid_ptr_NDHW[1] = giy_mult * giy; + gGrid_ptr_NDHW[2] = giz_mult * giz; + } + } else if (mode == Mode::nearest) { + IndexT ix_nearest = static_cast(std::round(ix)); + IndexT iy_nearest = static_cast(std::round(iy)); + IndexT iz_nearest = static_cast(std::round(iz)); + + // assign nearest neighbor pixel value to output pixel + IndexT gOut_offset = + n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW; + T* gInp_ptr_NC = grad_input + n * inp_sN; + for (IndexT c = 0; c < out_c; + ++c, gOut_offset += gOut_sC, gInp_ptr_NC += inp_sC) { + AtomicAdd3D(gInp_ptr_NC, + iz_nearest, + iy_nearest, + ix_nearest, + inp_sD, + inp_sH, + inp_sW, + in_d, + in_h, + in_w, + grad_output[gOut_offset]); + } + if (grad_grid != nullptr) { + T* gGrid_ptr_NDHW = grad_grid + index * grid_sW; + gGrid_ptr_NDHW[0] = static_cast(0); + gGrid_ptr_NDHW[1] = static_cast(0); + gGrid_ptr_NDHW[2] = static_cast(0); + } + } + } +} + +template +void GridSampleGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& grid, + const DenseTensor& out_grad, + const std::string& mode, + const std::string& padding_mode, + bool align_corners, + DenseTensor* x_grad, + DenseTensor* grid_grad) { + if (out_grad.numel() == 0) { + if (x_grad) { + phi::Full( + dev_ctx, phi::IntArray(common::vectorize(x_grad->dims())), 0, x_grad); + } + if (grid_grad) { + phi::Full(dev_ctx, + phi::IntArray(common::vectorize(grid_grad->dims())), + 0, + grid_grad); + } + return; + } + + PaddingMode enum_padding_mode; + Mode enum_mode; + if (padding_mode == "border") { + enum_padding_mode = PaddingMode::border; + } else if (padding_mode == "reflection") { + enum_padding_mode = PaddingMode::reflect; + } else { + enum_padding_mode = PaddingMode::zeros; + } + + if (mode == "nearest") { + enum_mode = Mode::nearest; + } else { + enum_mode = Mode::bilinear; + } + +#ifndef PADDLE_WITH_HIP + if (condCudnnGridSampler(x, grid) && + enum_padding_mode == PaddingMode::zeros && enum_mode == Mode::bilinear && + align_corners) { + const int64_t N = x.dims()[0]; + const int64_t C = x.dims()[1]; + const int64_t H_in = x.dims()[2]; + const int64_t W_in = x.dims()[3]; + const int64_t H_out = grid.dims()[1]; + const int64_t W_out = grid.dims()[2]; + + // cuDNN handle + cudnnHandle_t handle = GetDnnHandle(dev_ctx.stream(), dev_ctx.GetPlace()); + + // Create and set Tensor descriptors (NCHW) for x/y + cudnnTensorDescriptor_t x_desc, dx_desc, y_desc; + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnCreateTensorDescriptor(&x_desc)); + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnCreateTensorDescriptor(&dx_desc)); + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnCreateTensorDescriptor(&y_desc)); + + const cudnnDataType_t cudnn_dtype = + std::is_same::value ? 
CUDNN_DATA_FLOAT : CUDNN_DATA_DOUBLE; + + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnSetTensor4dDescriptor(x_desc, + CUDNN_TENSOR_NCHW, + cudnn_dtype, + static_cast(N), + static_cast(C), + static_cast(H_in), + static_cast(W_in))); + + // The shape of dx is consistent with that of x + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnSetTensor4dDescriptor(dx_desc, + CUDNN_TENSOR_NCHW, + cudnn_dtype, + static_cast(N), + static_cast(C), + static_cast(H_in), + static_cast(W_in))); + + // The shape of y is consistent with out_grad + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnSetTensor4dDescriptor(y_desc, + CUDNN_TENSOR_NCHW, + cudnn_dtype, + static_cast(N), + static_cast(C), + static_cast(H_out), + static_cast(W_out))); + + // Spatial Transformer descriptor: specifies sampler type and output + // dimension (N, C, H_out, W_out) + cudnnSpatialTransformerDescriptor_t st_desc; + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnCreateSpatialTransformerDescriptor(&st_desc)); + int st_dims[4] = {static_cast(N), + static_cast(C), + static_cast(H_out), + static_cast(W_out)}; + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnSetSpatialTransformerNdDescriptor( + st_desc, CUDNN_SAMPLER_BILINEAR, cudnn_dtype, 4, st_dims)); + + // data pointer + const T* x_data = x.data(); + const T* grid_data = grid.data(); + const T* dy_data = out_grad.data(); + + T* dx_data = dev_ctx.template Alloc(x_grad); + phi::funcs::SetConstant()(dev_ctx, x_grad, static_cast(0)); + + T* dgrid_data = nullptr; + if (grid_grad) { + dgrid_data = dev_ctx.template Alloc(grid_grad); + } + + // alpha/beta + using AlphaBetaT = typename std:: + conditional::value, float, double>::type; + const AlphaBetaT one = static_cast(1.0); + const AlphaBetaT zero = static_cast(0.0); + + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnSpatialTfSamplerBackward( + handle, + st_desc, + static_cast(&one), // alpha (for dx) + x_desc, + static_cast(x_data), + static_cast(&zero), // beta (for dx) + dx_desc, + static_cast(dx_data), + static_cast(&one), // alpha (for dgrid) + y_desc, + static_cast(dy_data), + static_cast(grid_data), + static_cast(&zero), // beta (for dgrid) + static_cast(dgrid_data))); + + // resource release + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnDestroySpatialTransformerDescriptor(st_desc)); + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnDestroyTensorDescriptor(x_desc)); + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnDestroyTensorDescriptor(dx_desc)); + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnDestroyTensorDescriptor(y_desc)); + return; + } +#endif + + bool use_int32_index = x.numel() <= std::numeric_limits::max() && + grid.numel() <= std::numeric_limits::max() && + out_grad.numel() <= std::numeric_limits::max(); + + if (x.dims().size() == 4) { + const int64_t n = grid.dims()[0]; + const int64_t out_h = grid.dims()[1]; + const int64_t out_w = grid.dims()[2]; + const int64_t c = x.dims()[1]; + const int64_t in_h = x.dims()[2]; + const int64_t in_w = x.dims()[3]; + + dev_ctx.template Alloc(x_grad); + phi::funcs::SetConstant()(dev_ctx, x_grad, static_cast(0)); + + T* grid_grad_data = nullptr; + if (grid_grad != nullptr) { + grid_grad_data = dev_ctx.template Alloc(grid_grad); + } + + int64_t count = n * out_h * out_w; + auto cu_stream = dev_ctx.stream(); + backends::gpu::GpuLaunchConfig config = + backends::gpu::GetGpuLaunchConfig1D(dev_ctx, count); + +#define LAUNCH_KERNEL(INDEX_TYPE) \ + GridSamplerCudaBackwardKernel \ + <<>>( \ + count, \ + out_grad.data(), \ + x.data(), \ + grid.data(), \ + n, \ + c, \ + 
+        out_h,                                                          \
+        out_w,                                                          \
+        in_h,                                                           \
+        in_w,                                                           \
+        x_grad->data<T>(),                                              \
+        grid_grad_data,                                                 \
+        enum_mode,                                                      \
+        enum_padding_mode,                                              \
+        align_corners);
+    if (use_int32_index) {
+      LAUNCH_KERNEL(int32_t)
+    } else {
+      LAUNCH_KERNEL(int64_t)
+    }
+#undef LAUNCH_KERNEL
+  } else {
+    const int64_t out_d = grid.dims()[1];
+    const int64_t out_h = grid.dims()[2];
+    const int64_t out_w = grid.dims()[3];
+    const int64_t n = x.dims()[0];
+    const int64_t c = x.dims()[1];
+    const int64_t in_d = x.dims()[2];
+    const int64_t in_h = x.dims()[3];
+    const int64_t in_w = x.dims()[4];
+
+    dev_ctx.template Alloc<T>(x_grad);
+    phi::funcs::SetConstant<Context, T>()(dev_ctx, x_grad, static_cast<T>(0));
+
+    T* grid_grad_data = nullptr;
+    if (grid_grad != nullptr) {
+      grid_grad_data = dev_ctx.template Alloc<T>(grid_grad);
+    }
+
+    int64_t count = static_cast<int64_t>(n * out_d * out_h * out_w);
+    auto cu_stream = dev_ctx.stream();
+    backends::gpu::GpuLaunchConfig config =
+        backends::gpu::GetGpuLaunchConfig1D(dev_ctx, count);
+
+#define LAUNCH_KERNEL(INDEX_TYPE)                                         \
+  GridSampler3DCudaBackwardKernel<T, INDEX_TYPE>                          \
+      <<<config.block_per_grid, config.thread_per_block, 0, cu_stream>>>( \
+          count,                                                          \
+          out_grad.data<T>(),                                             \
+          x.data<T>(),                                                    \
+          grid.data<T>(),                                                 \
+          c,                                                              \
+          out_d,                                                          \
+          out_h,                                                          \
+          out_w,                                                          \
+          in_d,                                                           \
+          in_h,                                                           \
+          in_w,                                                           \
+          x_grad->data<T>(),                                              \
+          grid_grad_data,                                                 \
+          enum_mode,                                                      \
+          enum_padding_mode,                                              \
+          align_corners);
+    if (use_int32_index) {
+      LAUNCH_KERNEL(int32_t)
+    } else {
+      LAUNCH_KERNEL(int64_t)
+    }
+#undef LAUNCH_KERNEL
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_PLUGIN_KERNEL(grid_sample_grad,
+                          metax_gpu,
+                          ALL_LAYOUT,
+                          phi::GridSampleGradKernel,
+                          float,
+                          double) {}
diff --git a/backends/metax_gpu/kernels/metax_kernel/grid_sample_kernel_register.cu b/backends/metax_gpu/kernels/metax_kernel/grid_sample_kernel_register.cu
new file mode 100644
index 00000000000..71050c264c6
--- /dev/null
+++ b/backends/metax_gpu/kernels/metax_kernel/grid_sample_kernel_register.cu
@@ -0,0 +1,527 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "glog/logging.h"
+#include "kernels/metax_kernel/metax_context.h"
+#include "paddle/phi/backends/dynload/cudnn.h"
+#include "paddle/phi/backends/gpu/gpu_info.h"
+#include "paddle/phi/backends/gpu/gpu_launch_config.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/gpu/grid_sample_utils.h"
+#include "paddle/phi/kernels/grid_sample_kernel.h"
+
+namespace phi {
+
+template <typename T, typename IndexT>
+static __forceinline__ __device__ T Unnormalize(T coord,
+                                                IndexT size,
+                                                bool align_corners) {
+  return align_corners ?
((coord + 1.f) / 2) * (size - 1) + : ((coord + 1.f) * size - 1) / 2; +} + +template +static __forceinline__ __device__ T ClipIndexes(T in, IndexT max_value) { + return min(static_cast(max_value - 1), max(in, static_cast(0))); +} + +template +static __forceinline__ __device__ T ReflectIndexes(T in, + IndexT twice_low, + IndexT twice_high) { + if (twice_low == twice_high) { + return static_cast(0); + } + T min = static_cast(twice_low) / 2; + T span = static_cast(twice_high - twice_low) / 2; + in = fabs(in - min); + T extra = fmod(in, span); + IndexT flips = floor(in / span); + return (flips & 1) ? span - extra + min : extra + min; // cond ? odd : even +} + +template +static __forceinline__ __device__ T ComputePositions(T coord, + IndexT size, + PaddingMode padding_mode, + bool align_corners) { + coord = Unnormalize(coord, size, align_corners); + if (padding_mode == PaddingMode::border) { + coord = ClipIndexes(coord, size); + } else if (padding_mode == PaddingMode::reflect) { + coord = align_corners ? ReflectIndexes(coord, 0, 2 * (size - 1)) + : ReflectIndexes(coord, -1, 2 * size - 1); + coord = ClipIndexes(coord, size); + } + return SafeDownGradeToIntRange(coord); +} + +template +__global__ void GridSampleCudaKernel(IndexT n, + IndexT out_c, + IndexT out_hw, + IndexT in_h, + IndexT in_w, + const T* __restrict__ input, + const T* __restrict__ grid, + T* __restrict__ output, + const Mode mode, + const PaddingMode padding_mode, + bool align_corners) { + IndexT nthreads = n * out_hw; + IndexT inp_sN = out_c * (in_h * in_w); + IndexT inp_sC = in_h * in_w; + IndexT inp_sH = in_w; + IndexT inp_sW = 1; + IndexT grid_sNHW = 2; + IndexT grid_sCoor = 1; + IndexT out_sN = out_c * out_hw; + IndexT out_sC = out_hw; + IndexT out_sHW = 1; + CUDA_KERNEL_LOOP_TYPE(index, nthreads, IndexT) { + const IndexT hw = index % out_hw; + const IndexT n = index / out_hw; + const IndexT grid_offset = index * grid_sNHW; + + T ix = grid[grid_offset]; + T iy = grid[grid_offset + grid_sCoor]; + + ix = ComputePositions(ix, in_w, padding_mode, align_corners); + iy = ComputePositions(iy, in_h, padding_mode, align_corners); + if (mode == Mode::bilinear) { + IndexT ix_nw = floor(ix); + IndexT iy_nw = floor(iy); + IndexT ix_ne = ix_nw + 1; + IndexT iy_ne = iy_nw; + IndexT ix_sw = ix_nw; + IndexT iy_sw = iy_nw + 1; + IndexT ix_se = ix_nw + 1; + IndexT iy_se = iy_nw + 1; + + T nw = (ix_se - ix) * (iy_se - iy); + T ne = (ix - ix_sw) * (iy_sw - iy); + T sw = (ix_ne - ix) * (iy - iy_ne); + T se = (ix - ix_nw) * (iy - iy_nw); + + IndexT inp_offset_NC = n * inp_sN; + T* out_ptr_NCHW = output + (n * out_sN + hw * out_sHW); + + for (IndexT c = 0; c < out_c; + ++c, inp_offset_NC += inp_sC, out_ptr_NCHW += out_sC) { + T value{0}; + if (InBounds(iy_nw, ix_nw, in_h, in_w)) { + value += input[inp_offset_NC + iy_nw * inp_sH + ix_nw * inp_sW] * nw; + } + if (InBounds(iy_ne, ix_ne, in_h, in_w)) { + value += input[inp_offset_NC + iy_ne * inp_sH + ix_ne * inp_sW] * ne; + } + if (InBounds(iy_sw, ix_sw, in_h, in_w)) { + value += input[inp_offset_NC + iy_sw * inp_sH + ix_sw * inp_sW] * sw; + } + if (InBounds(iy_se, ix_se, in_h, in_w)) { + value += input[inp_offset_NC + iy_se * inp_sH + ix_se * inp_sW] * se; + } + *out_ptr_NCHW = value; + } + } else if (mode == Mode::nearest) { + IndexT ix_nearest = std::nearbyint(ix); + IndexT iy_nearest = std::nearbyint(iy); + IndexT inp_offset_NC = n * inp_sN; + T* out_ptr_NCHW = output + (n * out_sN + hw * out_sHW); + for (IndexT c = 0; c < out_c; + ++c, inp_offset_NC += inp_sC, out_ptr_NCHW += out_sC) { + if 
(InBounds(iy_nearest, ix_nearest, in_h, in_w)) { + *out_ptr_NCHW = + input[inp_offset_NC + iy_nearest * inp_sH + ix_nearest * inp_sW]; + } else { + *out_ptr_NCHW = static_cast(0); + } + } + } + } +} + +template +__global__ void GridSample3DCudaKernel(const IndexT nthreads, + IndexT out_c, + IndexT out_d, + IndexT out_h, + IndexT out_w, + IndexT in_d, + IndexT in_h, + IndexT in_w, + const T* input, + const T* grid, + T* output, + const Mode interpolation_mode, + const PaddingMode padding_mode, + bool align_corners) { + IndexT inp_sW = 1; + IndexT inp_sH = in_w; + IndexT inp_sD = in_h * in_w; + IndexT inp_sC = in_d * inp_sD; + IndexT inp_sN = out_c * inp_sC; + + IndexT grid_sCoor = 1; + IndexT grid_sW = 3; + IndexT grid_sH = out_w * grid_sW; + IndexT grid_sD = out_h * grid_sH; + IndexT grid_sN = out_d * grid_sD; + + IndexT out_sW = 1; + IndexT out_sH = out_w; + IndexT out_sD = out_h * out_w; + IndexT out_sC = out_d * out_sD; + IndexT out_sN = out_c * out_sC; + + CUDA_KERNEL_LOOP_TYPE(index, nthreads, IndexT) { + const IndexT w = index % out_w; + const IndexT h = (index / out_w) % out_h; + const IndexT d = (index / (out_h * out_w)) % out_d; + const IndexT n = index / (out_d * out_h * out_w); + const IndexT grid_offset = + n * grid_sN + d * grid_sD + h * grid_sH + w * grid_sW; + // get the corresponding input x, y, z coordinates from grid + T ix = grid[grid_offset]; + T iy = grid[grid_offset + grid_sCoor]; + T iz = grid[grid_offset + 2 * grid_sCoor]; + ix = ComputePositions(ix, in_w, padding_mode, align_corners); + iy = ComputePositions(iy, in_h, padding_mode, align_corners); + iz = ComputePositions(iz, in_d, padding_mode, align_corners); + if (interpolation_mode == Mode::bilinear) { + // get corner pixel values from (x, y, z) + // for 4d, we used north-east-south-west + // for 5d, we add top-bottom + IndexT ix_tnw = static_cast(std::floor(ix)); + IndexT iy_tnw = static_cast(std::floor(iy)); + IndexT iz_tnw = static_cast(std::floor(iz)); + + IndexT ix_tne = ix_tnw + 1; + IndexT iy_tne = iy_tnw; + IndexT iz_tne = iz_tnw; + + IndexT ix_tsw = ix_tnw; + IndexT iy_tsw = iy_tnw + 1; + IndexT iz_tsw = iz_tnw; + + IndexT ix_tse = ix_tnw + 1; + IndexT iy_tse = iy_tnw + 1; + IndexT iz_tse = iz_tnw; + + IndexT ix_bnw = ix_tnw; + IndexT iy_bnw = iy_tnw; + IndexT iz_bnw = iz_tnw + 1; + + IndexT ix_bne = ix_tnw + 1; + IndexT iy_bne = iy_tnw; + IndexT iz_bne = iz_tnw + 1; + + IndexT ix_bsw = ix_tnw; + IndexT iy_bsw = iy_tnw + 1; + IndexT iz_bsw = iz_tnw + 1; + + IndexT ix_bse = ix_tnw + 1; + IndexT iy_bse = iy_tnw + 1; + IndexT iz_bse = iz_tnw + 1; + + // get surfaces to each neighbor: + T tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz); + T tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz); + T tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz); + T tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz); + T bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse); + T bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw); + T bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne); + T bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw); + + const T* inp_ptr_NC = input + n * inp_sN; + T* out_ptr_NCDHW = + output + (n * out_sN + d * out_sD + h * out_sH + w * out_sW); + for (IndexT c = 0; c < out_c; + ++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) { + *out_ptr_NCDHW = static_cast(0); + if (InBounds3D(iz_tnw, iy_tnw, ix_tnw, in_d, in_h, in_w)) { + *out_ptr_NCDHW += + inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW] * + tnw; + } + if (InBounds3D(iz_tne, iy_tne, ix_tne, in_d, in_h, in_w)) { + 
*out_ptr_NCDHW += + inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW] * + tne; + } + if (InBounds3D(iz_tsw, iy_tsw, ix_tsw, in_d, in_h, in_w)) { + *out_ptr_NCDHW += + inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW] * + tsw; + } + if (InBounds3D(iz_tse, iy_tse, ix_tse, in_d, in_h, in_w)) { + *out_ptr_NCDHW += + inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW] * + tse; + } + if (InBounds3D(iz_bnw, iy_bnw, ix_bnw, in_d, in_h, in_w)) { + *out_ptr_NCDHW += + inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW] * + bnw; + } + if (InBounds3D(iz_bne, iy_bne, ix_bne, in_d, in_h, in_w)) { + *out_ptr_NCDHW += + inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW] * + bne; + } + if (InBounds3D(iz_bsw, iy_bsw, ix_bsw, in_d, in_h, in_w)) { + *out_ptr_NCDHW += + inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW] * + bsw; + } + if (InBounds3D(iz_bse, iy_bse, ix_bse, in_d, in_h, in_w)) { + *out_ptr_NCDHW += + inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW] * + bse; + } + } + } else if (interpolation_mode == Mode::nearest) { + IndexT ix_nearest = static_cast(std::nearbyint(ix)); + IndexT iy_nearest = static_cast(std::nearbyint(iy)); + IndexT iz_nearest = static_cast(std::nearbyint(iz)); + + // assign nearest neighbor pixel value to output pixel + const T* inp_ptr_NC = input + n * inp_sN; + T* out_ptr_NCDHW = + output + (n * out_sN + d * out_sD + h * out_sH + w * out_sW); + for (IndexT c = 0; c < out_c; + ++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) { + if (InBounds3D(iz_nearest, iy_nearest, ix_nearest, in_d, in_h, in_w)) { + *out_ptr_NCDHW = + inp_ptr_NC[iz_nearest * inp_sD + iy_nearest * inp_sH + + ix_nearest * inp_sW]; + } else { + *out_ptr_NCDHW = static_cast(0); + } + } + } + } +} + +template +void GridSampleKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& grid, + const std::string& mode, + const std::string& padding_mode, + bool align_corners, + DenseTensor* out) { + if (out && out->numel() == 0) { + dev_ctx.template Alloc(out); + return; + } + PaddingMode enum_padding_mode; + Mode enum_mode; + if (padding_mode == "border") { + enum_padding_mode = PaddingMode::border; + } else if (padding_mode == "reflection") { + enum_padding_mode = PaddingMode::reflect; + } else { + enum_padding_mode = PaddingMode::zeros; + } + + if (mode == "nearest") { + enum_mode = Mode::nearest; + } else { + enum_mode = Mode::bilinear; + } + +#ifndef PADDLE_WITH_HIP + if (condCudnnGridSampler(x, grid) && + enum_padding_mode == PaddingMode::zeros && enum_mode == Mode::bilinear && + align_corners) { + const int64_t N = x.dims()[0]; + const int64_t C = x.dims()[1]; + const int64_t H_in = x.dims()[2]; + const int64_t W_in = x.dims()[3]; + const int64_t H_out = grid.dims()[1]; + const int64_t W_out = grid.dims()[2]; + + out->Resize({N, C, H_out, W_out}); + auto* out_data = dev_ctx.template Alloc(out); + + cudnnHandle_t handle = GetDnnHandle(dev_ctx.stream(), dev_ctx.GetPlace()); + + // Create and set Tensor descriptors (NCHW) for x and out + cudnnTensorDescriptor_t x_desc, y_desc; + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnCreateTensorDescriptor(&x_desc)); + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnCreateTensorDescriptor(&y_desc)); + + const cudnnDataType_t cudnn_dtype = + std::is_same::value ? 
CUDNN_DATA_FLOAT : CUDNN_DATA_DOUBLE; + + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnSetTensor4dDescriptor(x_desc, + CUDNN_TENSOR_NCHW, + cudnn_dtype, + static_cast(N), + static_cast(C), + static_cast(H_in), + static_cast(W_in))); + + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnSetTensor4dDescriptor(y_desc, + CUDNN_TENSOR_NCHW, + cudnn_dtype, + static_cast(N), + static_cast(C), + static_cast(H_out), + static_cast(W_out))); + + // Spatial Transformer descriptor: specifies sampler type and output + // dimension (N, C, H_out, W_out) + cudnnSpatialTransformerDescriptor_t st_desc; + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnCreateSpatialTransformerDescriptor(&st_desc)); + int st_dims[4] = {static_cast(N), + static_cast(C), + static_cast(H_out), + static_cast(W_out)}; + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnSetSpatialTransformerNdDescriptor( + st_desc, CUDNN_SAMPLER_BILINEAR, cudnn_dtype, 4, st_dims)); + + const T* x_data = x.data(); + const T* grid_data = grid.data(); + using AlphaBetaT = typename std:: + conditional::value, float, double>::type; + const AlphaBetaT alpha = static_cast(1.0); + const AlphaBetaT beta = static_cast(0.0); + + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cudnnSpatialTfSamplerForward( + handle, + st_desc, + static_cast(&alpha), + x_desc, + static_cast(x_data), + static_cast(grid_data), + static_cast(&beta), + y_desc, + static_cast(out_data))); + + // resource release + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnDestroySpatialTransformerDescriptor(st_desc)); + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnDestroyTensorDescriptor(x_desc)); + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cudnnDestroyTensorDescriptor(y_desc)); + return; + } +#endif + + bool use_int32_index = x.numel() <= std::numeric_limits::max() && + grid.numel() <= std::numeric_limits::max() && + out->numel() <= std::numeric_limits::max(); + + if (x.dims().size() == 4) { + const int64_t n = grid.dims()[0]; + const int64_t out_h = grid.dims()[1]; + const int64_t out_w = grid.dims()[2]; + const int64_t c = x.dims()[1]; + const int64_t in_h = x.dims()[2]; + const int64_t in_w = x.dims()[3]; + VLOG(3) << "n: " << n << "; c: " << c << "; out_h: " << out_h + << "; out_w: " << out_w; + + auto* output_data = dev_ctx.template Alloc(out); + VLOG(3) << "out dims: " << out->dims()[0] << "; " << out->dims()[1] << "; " + << out->dims()[2] << "; " << out->dims()[3]; + + int64_t count = n * out_h * out_w; + auto cu_stream = dev_ctx.stream(); + backends::gpu::GpuLaunchConfig config = + backends::gpu::GetGpuLaunchConfig1D(dev_ctx, count); + +#define LAUNCH_KERNEL(INDEX_TYPE) \ + GridSampleCudaKernel \ + <<>>( \ + n, \ + c, \ + out_h * out_w, \ + in_h, \ + in_w, \ + x.data(), \ + grid.data(), \ + output_data, \ + enum_mode, \ + enum_padding_mode, \ + align_corners) + if (use_int32_index) { + LAUNCH_KERNEL(int); + } else { + LAUNCH_KERNEL(int64_t); + } +#undef LAUNCH_KERNEL + } else { + const int64_t n = grid.dims()[0]; + const int64_t out_d = grid.dims()[1]; + const int64_t out_h = grid.dims()[2]; + const int64_t out_w = grid.dims()[3]; + const int64_t c = x.dims()[1]; + const int64_t in_d = x.dims()[2]; + const int64_t in_h = x.dims()[3]; + const int64_t in_w = x.dims()[4]; + + VLOG(3) << "n: " << n << "; c: " << c << "; out_d: " << out_d + << "; out_h: " << out_h << "; out_w: " << out_w; + + auto* output_data = dev_ctx.template Alloc(out); + VLOG(3) << "out dims: " << out->dims()[0] << "; " << out->dims()[1] << "; " + << out->dims()[2] << "; " << out->dims()[3] << "; " + << 
out->dims()[4]; + + int64_t count = n * out_d * out_h * out_w; + auto cu_stream = dev_ctx.stream(); + backends::gpu::GpuLaunchConfig config = + backends::gpu::GetGpuLaunchConfig1D(dev_ctx, count); + +#define LAUNCH_KERNEL(INDEX_TYPE) \ + GridSample3DCudaKernel \ + <<>>( \ + count, \ + c, \ + out_d, \ + out_h, \ + out_w, \ + in_d, \ + in_h, \ + in_w, \ + x.data(), \ + grid.data(), \ + output_data, \ + enum_mode, \ + enum_padding_mode, \ + align_corners) + if (use_int32_index) { + LAUNCH_KERNEL(int); + } else { + LAUNCH_KERNEL(int64_t); + } +#undef LAUNCH_KERNEL + } +} + +} // namespace phi + +PD_REGISTER_PLUGIN_KERNEL( + grid_sample, metax_gpu, ALL_LAYOUT, phi::GridSampleKernel, float, double) {} diff --git a/backends/metax_gpu/kernels/metax_kernel/weight_only_linear_kernel.cu b/backends/metax_gpu/kernels/metax_kernel/weight_only_linear_kernel.cu index eae8c8c0301..d2f39ccf751 100644 --- a/backends/metax_gpu/kernels/metax_kernel/weight_only_linear_kernel.cu +++ b/backends/metax_gpu/kernels/metax_kernel/weight_only_linear_kernel.cu @@ -35,6 +35,7 @@ void WeightOnlyLinearKernel(const Context& dev_ctx, const int32_t group_size, DenseTensor* out) { dev_ctx.template Alloc(out); + auto stream = dev_ctx.stream(); const T* x_data = x.data(); const int8_t* weight_data = weight.data(); const T* bias_data = bias ? bias.get().data() : nullptr; @@ -128,7 +129,7 @@ void WeightOnlyLinearKernel(const Context& dev_ctx, k, n, n}; - mctlass_op(arguments); + mctlass_op(arguments, NULL, stream); } else { mctlassGemmScaleOp_w8a16_bias mctlass_op; typename mctlassGemmScaleOp_w8a16_bias::Arguments arguments{
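Note for reviewers: the backward kernels in this patch rely on a bounds-guarded atomic scatter-add into grad_input, since several sampling points can hit the same input pixel and the floor(x)+1 neighbours can fall outside the image. The following is a minimal standalone CUDA sketch of that pattern only; it is not part of the patch, and the helper/kernel names (in_bounds, scatter_add_guarded) are hypothetical illustrations rather than phi APIs.

// sketch.cu — illustrative only; compile with `nvcc sketch.cu`
#include <cstdio>
#include <cuda_runtime.h>

__device__ __forceinline__ bool in_bounds(int h, int w, int H, int W) {
  return h >= 0 && h < H && w >= 0 && w < W;
}

// Each thread scatters one contribution into grad_input at (h, w). Threads
// may collide on the same pixel, so atomicAdd is required, and out-of-range
// neighbours are silently skipped — the same guard AtomicAdd/AtomicAdd3D
// apply above via InBounds/InBounds3D.
__global__ void scatter_add_guarded(float* grad_input,
                                    const int* hs,
                                    const int* ws,
                                    const float* deltas,
                                    int n,
                                    int H,
                                    int W) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) return;
  int h = hs[i], w = ws[i];
  if (in_bounds(h, w, H, W)) {
    atomicAdd(grad_input + h * W + w, deltas[i]);
  }
}

int main() {
  const int H = 2, W = 2, n = 3;
  int hs[n] = {0, 0, 5};  // third index is deliberately out of bounds
  int ws[n] = {1, 1, 1};
  float deltas[n] = {0.5f, 0.25f, 1.0f};

  float* d_grad = nullptr;
  int *d_hs = nullptr, *d_ws = nullptr;
  float* d_deltas = nullptr;
  cudaMalloc(&d_grad, H * W * sizeof(float));
  cudaMemset(d_grad, 0, H * W * sizeof(float));
  cudaMalloc(&d_hs, n * sizeof(int));
  cudaMalloc(&d_ws, n * sizeof(int));
  cudaMalloc(&d_deltas, n * sizeof(float));
  cudaMemcpy(d_hs, hs, n * sizeof(int), cudaMemcpyHostToDevice);
  cudaMemcpy(d_ws, ws, n * sizeof(int), cudaMemcpyHostToDevice);
  cudaMemcpy(d_deltas, deltas, n * sizeof(float), cudaMemcpyHostToDevice);

  scatter_add_guarded<<<1, 32>>>(d_grad, d_hs, d_ws, d_deltas, n, H, W);

  float grad[H * W];
  cudaMemcpy(grad, d_grad, sizeof(grad), cudaMemcpyDeviceToHost);
  printf("grad[0][1] = %f\n", grad[1]);  // expect 0.75: two in-bounds hits
  return 0;
}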