12 changes: 12 additions & 0 deletions paddle/common/flags.cc
@@ -2248,3 +2248,15 @@ PHI_DEFINE_EXPORTED_bool(
force_stride_compute_contig_out,
false,
"Whether force Stride_Compute_Kernel output contiguous.");

/**
* Torch Compatible related FLAG
* Name: FLAGS_torch_compatible_kernel
* Since Version: 3.2.2
* Value Range: bool, default=false
* Example:
* Note: Whether to use the torch-compatible version of the kernel.
*/
PHI_DEFINE_EXPORTED_bool(torch_compatible_kernel,
false,
"Whether use torch compatible version kernel.");
208 changes: 167 additions & 41 deletions paddle/phi/kernels/funcs/pooling.cu
@@ -22,6 +22,7 @@ limitations under the License. */
#include <hiprand_kernel.h>
#endif

#include "paddle/common/flags.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/kernels/funcs/distribution_helper.h"
@@ -30,6 +31,8 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/random.cuh"
#include "paddle/phi/kernels/funcs/reduce_function.h"

COMMON_DECLARE_bool(torch_compatible_kernel);

namespace phi {
namespace funcs {

@@ -92,6 +95,20 @@ struct FastDivModForPoolingWithMoreStaff {
stride_h(stride_height) {}
};

// First pooled-output index whose window covers input position `size`.
static __device__ inline int p_start(int size,
int pad,
int kernel,
int stride) {
return (size + pad < kernel) ? 0 : (size + pad - kernel) / stride + 1;
}

// One past the last pooled-output index whose window covers `size`,
// clamped to `pooled_size`.
static __device__ inline int p_end(int size,
int pad,
int pooled_size,
int stride) {
return std::min((size + pad) / stride + 1, pooled_size);
}
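
A note on the two helpers above: for an input coordinate size, they return the half-open range [p_start, p_end) of pooled-output indices p whose windows [p*stride - pad, p*stride - pad + kernel) cover that coordinate. In other words, they invert the window map so the gradient kernel can gather rather than scatter. A quick numeric check of that reading, in plain Python with the same semantics assumed:

    def p_start(size, pad, kernel, stride):
        return 0 if size + pad < kernel else (size + pad - kernel) // stride + 1

    def p_end(size, pad, pooled_size, stride):
        return min((size + pad) // stride + 1, pooled_size)

    # size = 5, pad = 1, kernel = 3, stride = 2, pooled_size = 4:
    # windows start at -1, 1, 3, 5, so position 5 lies in [3, 6) and [5, 8),
    # i.e. pooled indices 2 and 3.
    assert (p_start(5, 1, 3, 2), p_end(5, 1, 4, 2)) == (2, 4)
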

template <typename FastDivModForPooling, typename IndexT>
__device__ void OffsetPreparationFor4Dimension(IndexT index,
bool channel_last,
@@ -474,6 +491,56 @@ __global__ void KernelMaxPool2DGrad(const IndexT nthreads,
}
}

template <typename T, typename IndexT>
__global__ void KernelMaxPool2DGradCompatible(
const T* input_data,
const T* output_data,
const T* output_grad,
const IndexT batch_size,
const IndexT channels,
const IndexT input_height,
const IndexT input_width,
const IndexT output_height,
const IndexT output_width,
const IndexT ksize_height,
const IndexT ksize_width,
const IndexT stride_height,
const IndexT stride_width,
const IndexT padding_height,
const IndexT padding_width,
T* input_grad,
FastDivModForPooling<IndexT> divmods,  // currently unused in this kernel
bool channel_last = false) {  // unused: this path assumes NCHW layout
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;

CUDA_KERNEL_LOOP(index, input_height * input_width) {
IndexT h = index / input_width;
IndexT w = index - h * input_width;
IndexT phstart = p_start(h, padding_height, ksize_height, stride_height);
IndexT phend = p_end(h, padding_height, output_height, stride_height);
IndexT pwstart = p_start(w, padding_width, ksize_width, stride_width);
IndexT pwend = p_end(w, padding_width, output_width, stride_width);
for (IndexT n = blockIdx.y; n < batch_size; n += gridDim.y) {
for (IndexT c = blockIdx.z; c < channels; c += gridDim.z) {
MPType gradient = static_cast<MPType>(0.0f);
IndexT offset = (n * channels + c) * output_height * output_width;
// Index the input plane belonging to this (n, c), not plane (0, 0).
T input_data_value =
input_data[(n * channels + c) * input_height * input_width + index];
for (int ph = phstart; ph < phend; ++ph) {
for (int pw = pwstart; pw < pwend; ++pw) {
T output_data_value = output_data[ph * output_width + pw + offset];
if (output_data_value == input_data_value) {
gradient += static_cast<MPType>(
output_grad[ph * output_width + pw + offset]);
}
}
}
input_grad[(n * channels + c) * input_height * input_width + index] =
static_cast<T>(gradient);
}
}
}
}
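
What sets this kernel apart from the existing KernelMaxPool2DGrad, as I read the existing code: the old kernel parallelizes over output elements and scatters gradients into input_grad with atomic adds, while this one parallelizes over input elements and gathers, so each input cell is written exactly once. That makes the backward pass deterministic and mirrors the structure of PyTorch's max_pool_backward_nchw. A NumPy sketch of the same gather scheme, written here for illustration rather than taken from the PR:

    import numpy as np

    def max_pool2d_grad_gather(x, out, dout, ksize, stride, pad):
        """NCHW max-pool backward, gather style: every input cell scans the
        output windows covering it and sums the gradients of matching maxima."""
        nb, nc, ih, iw = x.shape
        oh, ow = out.shape[2:]
        dx = np.zeros_like(x)
        for n in range(nb):
            for c in range(nc):
                for h in range(ih):
                    for w in range(iw):
                        phs = 0 if h + pad < ksize else (h + pad - ksize) // stride + 1
                        phe = min((h + pad) // stride + 1, oh)
                        pws = 0 if w + pad < ksize else (w + pad - ksize) // stride + 1
                        pwe = min((w + pad) // stride + 1, ow)
                        g = 0.0
                        for ph in range(phs, phe):
                            for pw in range(pws, pwe):
                                if out[n, c, ph, pw] == x[n, c, h, w]:
                                    g += dout[n, c, ph, pw]
                        dx[n, c, h, w] = g
        return dx

One caveat: comparing values instead of a stored argmax index routes gradient to every tied maximum inside a window, so bit-exact parity with index-based implementations holds only when window maxima are unique.
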

template <typename PoolProcess, typename T>
void Pool2dDirectCUDAFunctor<PoolProcess, T>::operator()(
const T* input,
@@ -879,6 +946,8 @@ class MaxPool2dGradFunctor<phi::GPUContext, T> {
const std::vector<int64_t>& paddings,
const std::string data_format,
DenseTensor* input_grad) {
static const int kBlockThreads = 1024;

bool channel_last = (data_format == "NHWC");

const int64_t batch_size = input.dims()[0];
@@ -913,55 +982,112 @@

int64_t nthreads =
batch_size * output_channels * output_height * output_width;
dim3 threads(kBlockThreads, 1);

if (input.numel() <= std::numeric_limits<int>::max() &&
output.numel() <= std::numeric_limits<int>::max()) {
auto pool_divmods = FastDivModForPooling<int>(
input_channels, output_width, output_height);
if (FLAGS_torch_compatible_kernel) {
int64_t blocks =
(input_width * input_height + kBlockThreads - 1) / kBlockThreads;
dim3 grid(blocks, batch_size, input_channels);
// NOTE: input.numel() <= std::numeric_limits<int>::max() &&
// output.numel() <= std::numeric_limits<int>::max()
KernelMaxPool2DGradCompatible<T, int>
<<<grid, threads, 0, dev_ctx.stream()>>>(input_data,
output_data,
output_grad_data,
batch_size,
input_channels,
input_height,
input_width,
output_height,
output_width,
ksize_height,
ksize_width,
stride_height,
stride_width,
padding_height,
padding_width,
input_grad_data,
pool_divmods,
channel_last);
} else {
int64_t blocks = (nthreads + kBlockThreads - 1) / kBlockThreads;
dim3 grid(blocks, 1);
// NOTE: input.numel() <= std::numeric_limits<int>::max() &&
// output.numel() <= std::numeric_limits<int>::max()
KernelMaxPool2DGrad<T, int>
<<<grid, threads, 0, dev_ctx.stream()>>>(nthreads,
input_data,
output_data,
output_grad_data,
input_channels,
input_height,
input_width,
output_height,
output_width,
ksize_height,
ksize_width,
stride_height,
stride_width,
padding_height,
padding_width,
input_grad_data,
pool_divmods,
channel_last);
}

} else {
auto pool_divmods = FastDivModForPooling<int64_t>(
input_channels, output_width, output_height);
if (FLAGS_torch_compatible_kernel) {
int64_t blocks =
(input_width * input_height + kBlockThreads - 1) / kBlockThreads;
dim3 grid(blocks, batch_size, input_channels);
KernelMaxPool2DGradCompatible<T, int64_t>
<<<grid, threads, 0, dev_ctx.stream()>>>(input_data,
output_data,
output_grad_data,
batch_size,
input_channels,
input_height,
input_width,
output_height,
output_width,
ksize_height,
ksize_width,
stride_height,
stride_width,
padding_height,
padding_width,
input_grad_data,
pool_divmods,
channel_last);
} else {
int64_t blocks = (nthreads + kBlockThreads - 1) / kBlockThreads;
dim3 grid(blocks, 1);
KernelMaxPool2DGrad<T, int64_t>
<<<grid, threads, 0, dev_ctx.stream()>>>(nthreads,
input_data,
output_data,
output_grad_data,
input_channels,
input_height,
input_width,
output_height,
output_width,
ksize_height,
ksize_width,
stride_height,
stride_width,
padding_height,
padding_width,
input_grad_data,
pool_divmods,
channel_last);
}
}
}
};
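
A sanity check on the two launch shapes above, assuming a hypothetical 32x3x64x64 input: the stock path spreads the grid over all output elements, while the compatible path covers one 64x64 plane on x and puts batch and channels on y and z, which the kernel then walks with its n += gridDim.y and c += gridDim.z strides:

    # Hypothetical shapes; this only checks the arithmetic.
    kBlockThreads = 1024
    H = W = 64
    blocks = (H * W + kBlockThreads - 1) // kBlockThreads  # 4
    grid = (blocks, 32, 3)  # (x, batch_size, input_channels)

One thing to keep in mind: gridDim.y and gridDim.z are capped at 65535 on current NVIDIA GPUs, and this launcher passes batch_size and input_channels through unclamped, although the in-kernel grid-stride loops would tolerate a clamped grid.
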
7 changes: 7 additions & 0 deletions test/legacy_test/test_pool2d_api.py
@@ -769,6 +769,13 @@ def test_pool2d_static(self):
        self.check_lp_float16_static(place)
    paddle.disable_static()

def test_torch_compatible(self):
    paddle.set_flags({'FLAGS_torch_compatible_kernel': 1})
    paddle.enable_static()
    for place in self.places:
        self.check_max_static_results(place)
    paddle.disable_static()
    # Restore the default so later tests exercise the stock kernel.
    paddle.set_flags({'FLAGS_torch_compatible_kernel': 0})

def test_pool2d(self):
    for place in self.places:
        self.check_max_dygraph_results(place)