diff --git a/paddle/common/flags.cc b/paddle/common/flags.cc
index 8efd244f671439..46bb82b776a562 100644
--- a/paddle/common/flags.cc
+++ b/paddle/common/flags.cc
@@ -2248,3 +2248,15 @@ PHI_DEFINE_EXPORTED_bool(
     force_stride_compute_contig_out,
     false,
     "Whether force Stride_Compute_Kernel output contiguous.");
+
+/**
+ * Torch Compatible related FLAG
+ * Name: FLAGS_torch_compatible_kernel
+ * Since Version: 3.2.2
+ * Value Range: bool, default=false
+ * Example:
+ * Note: Whether to use the torch-compatible version of a kernel.
+ */
+PHI_DEFINE_EXPORTED_bool(torch_compatible_kernel,
+                         false,
+                         "Whether to use the torch-compatible version of a kernel.");
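The flag is registered through the standard PHI exported-flag machinery, so it can be set either via the `FLAGS_torch_compatible_kernel` environment variable at process start or at runtime through `paddle.set_flags`. A minimal sketch of both routes, assuming the flag behaves like every other exported bool flag:

```python
# Route 1: environment variable, set before Paddle initializes:
#   FLAGS_torch_compatible_kernel=true python train.py

# Route 2: runtime toggle through the public flags API.
import paddle

paddle.set_flags({'FLAGS_torch_compatible_kernel': True})
print(paddle.get_flags(['FLAGS_torch_compatible_kernel']))
# -> {'FLAGS_torch_compatible_kernel': True}
```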
diff --git a/paddle/phi/kernels/funcs/pooling.cu b/paddle/phi/kernels/funcs/pooling.cu
index f6fbe18490fa19..06bcee3be384c1 100644
--- a/paddle/phi/kernels/funcs/pooling.cu
+++ b/paddle/phi/kernels/funcs/pooling.cu
@@ -22,6 +22,7 @@ limitations under the License. */
 #include <hip/hip_runtime.h>
 #endif
 
+#include "paddle/common/flags.h"
 #include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/backends/gpu/gpu_primitives.h"
 #include "paddle/phi/kernels/funcs/distribution_helper.h"
@@ -30,6 +31,8 @@ limitations under the License. */
 #include "paddle/phi/kernels/funcs/random.cuh"
 #include "paddle/phi/kernels/funcs/reduce_function.h"
 
+COMMON_DECLARE_bool(torch_compatible_kernel);
+
 namespace phi {
 namespace funcs {
 
@@ -92,6 +95,20 @@ struct FastDivModForPoolingWithMoreStaff {
         stride_h(stride_height) {}
 };
 
+static __device__ inline int p_start(int size,
+                                     int pad,
+                                     int kernel,
+                                     int stride) {
+  return (size + pad < kernel) ? 0 : (size + pad - kernel) / stride + 1;
+}
+
+static __device__ inline int p_end(int size,
+                                   int pad,
+                                   int pooled_size,
+                                   int stride) {
+  return std::min((size + pad) / stride + 1, pooled_size);
+}
+
 template <typename FastDivModForPooling, typename IndexT>
 __device__ void OffsetPreparationFor4Dimension(IndexT index,
                                                bool channel_last,
@@ -474,6 +491,57 @@ __global__ void KernelMaxPool2DGrad(const IndexT nthreads,
   }
 }
 
+template <typename T, typename IndexT>
+__global__ void KernelMaxPool2DGradCompatible(
+    const T* input_data,
+    const T* output_data,
+    const T* output_grad,
+    const IndexT batch_size,
+    const IndexT channels,
+    const IndexT input_height,
+    const IndexT input_width,
+    const IndexT output_height,
+    const IndexT output_width,
+    const IndexT ksize_height,
+    const IndexT ksize_width,
+    const IndexT stride_height,
+    const IndexT stride_width,
+    const IndexT padding_height,
+    const IndexT padding_width,
+    T* input_grad,
+    FastDivModForPooling divmods,
+    bool channel_last = false) {
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
+
+  CUDA_KERNEL_LOOP(index, input_height * input_width) {
+    IndexT h = index / input_width;
+    IndexT w = index - h * input_width;
+    IndexT phstart = p_start(h, padding_height, ksize_height, stride_height);
+    IndexT phend = p_end(h, padding_height, output_height, stride_height);
+    IndexT pwstart = p_start(w, padding_width, ksize_width, stride_width);
+    IndexT pwend = p_end(w, padding_width, output_width, stride_width);
+    for (IndexT n = blockIdx.y; n < batch_size; n += gridDim.y) {
+      for (IndexT c = blockIdx.z; c < channels; c += gridDim.z) {
+        T input_data_value =
+            input_data[(n * channels + c) * input_height * input_width + index];
+        MPType gradient = static_cast<MPType>(0.0f);
+        IndexT offset = (n * channels + c) * output_height * output_width;
+        for (int ph = phstart; ph < phend; ++ph) {
+          for (int pw = pwstart; pw < pwend; ++pw) {
+            T output_data_value = output_data[ph * output_width + pw + offset];
+            if (output_data_value == input_data_value) {
+              gradient += static_cast<MPType>(
+                  output_grad[ph * output_width + pw + offset]);
+            }
+          }
+        }
+        input_grad[(n * channels + c) * input_height * input_width + index] =
+            static_cast<T>(gradient);
+      }
+    }
+  }
+}
+
 template <typename PoolProcess, typename T>
 void Pool2dDirectCUDAFunctor<PoolProcess, T>::operator()(
     const T* input,
@@ -879,6 +947,8 @@ class MaxPool2dGradFunctor {
       const std::vector<int>& paddings,
      const std::string data_format,
       DenseTensor* input_grad) {
+    static const int kBlockThreads = 1024;
+
     bool channel_last = (data_format == "NHWC");
 
     const int64_t batch_size = input.dims()[0];
@@ -913,55 +983,111 @@ class MaxPool2dGradFunctor {
 
     int64_t nthreads =
         batch_size * output_channels * output_height * output_width;
-    int64_t blocks = (nthreads + 1024 - 1) / 1024;
-    dim3 threads(1024, 1);
-    dim3 grid(blocks, 1);
+    dim3 threads(kBlockThreads, 1);
 
     if (input.numel() <= std::numeric_limits<int>::max() &&
         output.numel() <= std::numeric_limits<int>::max()) {
       auto pool_divmods = FastDivModForPooling(
           input_channels, output_width, output_height);
-      KernelMaxPool2DGrad<T, int>
-          <<<grid, threads, 0, context.stream()>>>(nthreads,
-                                                   input_data,
-                                                   output_data,
-                                                   output_grad_data,
-                                                   input_channels,
-                                                   input_height,
-                                                   input_width,
-                                                   output_height,
-                                                   output_width,
-                                                   ksize_height,
-                                                   ksize_width,
-                                                   stride_height,
-                                                   stride_width,
-                                                   padding_height,
-                                                   padding_width,
-                                                   input_grad_data,
-                                                   pool_divmods,
-                                                   channel_last);
+      if (FLAGS_torch_compatible_kernel) {
+        int64_t blocks =
+            (input_width * input_height + kBlockThreads - 1) / kBlockThreads;
+        dim3 grid(blocks, batch_size, input_channels);
+        // NOTE: input.numel() <= std::numeric_limits<int>::max() &&
+        // output.numel() <= std::numeric_limits<int>::max()
+        KernelMaxPool2DGradCompatible<T, int>
+            <<<grid, threads, 0, context.stream()>>>(input_data,
+                                                     output_data,
+                                                     output_grad_data,
+                                                     batch_size,
+                                                     input_channels,
+                                                     input_height,
+                                                     input_width,
+                                                     output_height,
+                                                     output_width,
+                                                     ksize_height,
+                                                     ksize_width,
+                                                     stride_height,
+                                                     stride_width,
+                                                     padding_height,
+                                                     padding_width,
+                                                     input_grad_data,
+                                                     pool_divmods,
+                                                     channel_last);
+      } else {
+        int64_t blocks = (nthreads + kBlockThreads - 1) / kBlockThreads;
+        dim3 grid(blocks, 1);
+        // NOTE: input.numel() <= std::numeric_limits<int>::max() &&
+        // output.numel() <= std::numeric_limits<int>::max()
+        KernelMaxPool2DGrad<T, int>
+            <<<grid, threads, 0, context.stream()>>>(nthreads,
+                                                     input_data,
+                                                     output_data,
+                                                     output_grad_data,
+                                                     input_channels,
+                                                     input_height,
+                                                     input_width,
+                                                     output_height,
+                                                     output_width,
+                                                     ksize_height,
+                                                     ksize_width,
+                                                     stride_height,
+                                                     stride_width,
+                                                     padding_height,
+                                                     padding_width,
+                                                     input_grad_data,
+                                                     pool_divmods,
+                                                     channel_last);
+      }
     } else {
       auto pool_divmods = FastDivModForPooling(
           input_channels, output_width, output_height);
-      KernelMaxPool2DGrad<T, int64_t>
-          <<<grid, threads, 0, context.stream()>>>(nthreads,
-                                                   input_data,
-                                                   output_data,
-                                                   output_grad_data,
-                                                   input_channels,
-                                                   input_height,
-                                                   input_width,
-                                                   output_height,
-                                                   output_width,
-                                                   ksize_height,
-                                                   ksize_width,
-                                                   stride_height,
-                                                   stride_width,
-                                                   padding_height,
-                                                   padding_width,
-                                                   input_grad_data,
-                                                   pool_divmods,
-                                                   channel_last);
+      if (FLAGS_torch_compatible_kernel) {
+        int64_t blocks =
+            (input_width * input_height + kBlockThreads - 1) / kBlockThreads;
+        dim3 grid(blocks, batch_size, input_channels);
+        KernelMaxPool2DGradCompatible<T, int64_t>
+            <<<grid, threads, 0, context.stream()>>>(input_data,
+                                                     output_data,
+                                                     output_grad_data,
+                                                     batch_size,
+                                                     input_channels,
+                                                     input_height,
+                                                     input_width,
+                                                     output_height,
+                                                     output_width,
+                                                     ksize_height,
+                                                     ksize_width,
+                                                     stride_height,
+                                                     stride_width,
+                                                     padding_height,
+                                                     padding_width,
+                                                     input_grad_data,
+                                                     pool_divmods,
+                                                     channel_last);
+      } else {
+        int64_t blocks = (nthreads + kBlockThreads - 1) / kBlockThreads;
+        dim3 grid(blocks, 1);
+        KernelMaxPool2DGrad<T, int64_t>
+            <<<grid, threads, 0, context.stream()>>>(nthreads,
+                                                     input_data,
+                                                     output_data,
+                                                     output_grad_data,
+                                                     input_channels,
+                                                     input_height,
+                                                     input_width,
+                                                     output_height,
+                                                     output_width,
+                                                     ksize_height,
+                                                     ksize_width,
+                                                     stride_height,
+                                                     stride_width,
+                                                     padding_height,
+                                                     padding_width,
+                                                     input_grad_data,
+                                                     pool_divmods,
+                                                     channel_last);
+      }
     }
   }
 };
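Where the existing `KernelMaxPool2DGrad` parallelizes over output elements and scatters into `input_grad` with atomic adds, the compatible kernel inverts the mapping: one thread per input pixel, with `p_start`/`p_end` bounding the range of pooling windows that cover that pixel, gathering `output_grad` wherever the stored output max equals the input value. Each `input_grad` element is written exactly once, so the backward pass is deterministic, and the `n`/`c` loops stride by `gridDim.y`/`gridDim.z` so the kernel tolerates a grid smaller than batch × channels. A NumPy sketch of the same gather logic (function and argument names here are illustrative, NCHW layout only, no dilation):

```python
import numpy as np

def p_start(size, pad, kernel, stride):
    # Index of the first pooling window that still covers position `size`.
    return 0 if size + pad < kernel else (size + pad - kernel) // stride + 1

def p_end(size, pad, pooled_size, stride):
    # One past the last window covering position `size`, clamped to the grid.
    return min((size + pad) // stride + 1, pooled_size)

def max_pool2d_grad_compatible(x, out, dout, ksize, stride, pad):
    """Gather-style backward: x is (N,C,H,W); out, dout are (N,C,PH,PW)."""
    n_, c_, h_, w_ = x.shape
    ph_, pw_ = out.shape[2:]
    dx = np.zeros_like(x)
    for n in range(n_):
        for c in range(c_):
            for h in range(h_):
                for w in range(w_):
                    hs = p_start(h, pad[0], ksize[0], stride[0])
                    he = p_end(h, pad[0], ph_, stride[0])
                    ws = p_start(w, pad[1], ksize[1], stride[1])
                    we = p_end(w, pad[1], pw_, stride[1])
                    # Accumulate grad from every covering window whose
                    # max equals this input value.
                    win = out[n, c, hs:he, ws:we] == x[n, c, h, w]
                    dx[n, c, h, w] = dout[n, c, hs:he, ws:we][win].sum()
    return dx
```

One caveat of comparing values rather than recorded argmax indices: if a window contains several equal maxima, every matching position accumulates that window's gradient, which can differ from an index-based formulation under ties.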
diff --git a/test/legacy_test/test_pool2d_api.py b/test/legacy_test/test_pool2d_api.py
index 27b1986b79bf0c..6bc7d1b497fc97 100644
--- a/test/legacy_test/test_pool2d_api.py
+++ b/test/legacy_test/test_pool2d_api.py
@@ -769,6 +769,15 @@ def test_pool2d_static(self):
             self.check_lp_float16_static(place)
         paddle.disable_static()
 
+    def test_torch_compatible(self):
+        paddle.set_flags({'FLAGS_torch_compatible_kernel': 1})
+        paddle.enable_static()
+        for place in self.places:
+            self.check_max_static_results(place)
+        paddle.disable_static()
+        # Reset the flag so later tests run with the default kernel.
+        paddle.set_flags({'FLAGS_torch_compatible_kernel': 0})
+
     def test_pool2d(self):
         for place in self.places:
             self.check_max_dygraph_results(place)
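The new test re-runs the existing static-graph max pooling checks with the flag enabled. For a quick interactive spot check, a dygraph sketch along these lines (shapes illustrative; the flag only switches the CUDA kernel, so a GPU build is assumed) compares gradients with the flag off and on:

```python
import numpy as np
import paddle
import paddle.nn.functional as F

x_np = np.random.rand(2, 3, 8, 8).astype('float32')

grads = {}
for flag in (False, True):
    paddle.set_flags({'FLAGS_torch_compatible_kernel': flag})
    x = paddle.to_tensor(x_np, stop_gradient=False)
    y = F.max_pool2d(x, kernel_size=2, stride=2)
    y.sum().backward()
    grads[flag] = x.grad.numpy()

# With distinct input values there are no ties, so both kernels
# should produce the same gradient.
np.testing.assert_allclose(grads[False], grads[True], rtol=1e-6)
```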