
add maximum limit for grid of index_select #41127

Merged
merged 10 commits on Apr 3, 2022
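This PR caps every CUDA launch grid at the device's reported maximum (GetCUDAMaxGridDimSize) and replaces the per-file LimitGridDim helpers with one shared copy in paddle::platform. A grid computed directly from an element count can exceed the hardware limit (gridDim.y and gridDim.z top out at 65535 on current NVIDIA GPUs), and such a launch fails outright. Clamping is only safe when the kernel itself covers the leftover elements, typically with a grid-stride loop. A minimal standalone sketch of that pattern (plain CUDA; the kernel and its names are illustrative, not taken from this PR):

__global__ void scale_by_two(float* data, long long n) {
  // Grid-stride loop: still visits every element even when gridDim.x
  // was clamped below the one-thread-per-element block count.
  for (long long i =
           static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;
       i < n; i += static_cast<long long>(gridDim.x) * blockDim.x) {
    data[i] *= 2.0f;
  }
}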
8 changes: 8 additions & 0 deletions paddle/fluid/platform/device/gpu/gpu_launch_config.h
@@ -170,6 +170,14 @@ inline GpuLaunchConfig GetGpuLaunchConfig2D(
return config;
}

+template <typename Context>
+void LimitGridDim(const Context& ctx, dim3* grid_dim) {
+auto max_grid_dim = reinterpret_cast<const platform::CUDADeviceContext&>(ctx)
+.GetCUDAMaxGridDimSize();
+grid_dim->x = grid_dim->x < max_grid_dim[0] ? grid_dim->x : max_grid_dim[0];
+grid_dim->y = grid_dim->y < max_grid_dim[1] ? grid_dim->y : max_grid_dim[1];
+grid_dim->z = grid_dim->z < max_grid_dim[2] ? grid_dim->z : max_grid_dim[2];
+}
} // namespace platform
} // namespace paddle

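For readers outside the Paddle tree: the helper above should behave like a clamp against cudaDeviceProp::maxGridSize. A self-contained sketch of the equivalent logic (an assumption about equivalence, not the PR's code):

#include <cuda_runtime.h>
#include <algorithm>

// Clamp each grid dimension to what the device actually supports.
void LimitGridDimSketch(int device_id, dim3* grid) {
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, device_id);
  // maxGridSize is {x, y, z}; y and z are typically 65535.
  grid->x = std::min<unsigned int>(grid->x, prop.maxGridSize[0]);
  grid->y = std::min<unsigned int>(grid->y, prop.maxGridSize[1]);
  grid->z = std::min<unsigned int>(grid->z, prop.maxGridSize[2]);
}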
44 changes: 18 additions & 26 deletions paddle/phi/kernels/funcs/elementwise_grad_base.h
@@ -24,6 +24,7 @@ limitations under the License. */
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/phi/kernels/primitive/kernel_primitives.h"

#endif
@@ -49,14 +50,6 @@ namespace phi {
namespace funcs {
using DDim = phi::DDim;

-template <typename T>
-void LimitGridDim(const GPUContext &ctx, T *grid_dim) {
-auto max_grid_dim = ctx.GetCUDAMaxGridDimSize()[0];
-if (*grid_dim > max_grid_dim) {
-*grid_dim = max_grid_dim;
-}
-}

template <typename T, typename DX_OP, typename DY_OP, typename Tout = T>
void CommonGradBroadcastCPU(const DenseTensor &x,
const DenseTensor &y,
@@ -978,17 +971,17 @@ static void ElemwiseGradBroadcast1CUDA(gpuStream_t stream,
constexpr int half_warp = 16;
if (w < half_warp || h < half_warp) {
int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
-int gird_size = w;
-ElemwiseGradBroadcast1CUDAKernel<<<gird_size, block_size, 0, stream>>>(
+int grid_size = w;
+ElemwiseGradBroadcast1CUDAKernel<<<grid_size, block_size, 0, stream>>>(
x, y, out, dout, h, w, is_xsize_larger, dx_op, dy_op, dx, dy);
} else {
// suppose performance improves as h increases.
dim3 block_size = dim3(BLOCK_X, BLOCK_Y);
-int grid_size = (w + BLOCK_X - 1) / BLOCK_X;
+dim3 grid_size = dim3((w + BLOCK_X - 1) / BLOCK_X);
auto gplace = phi::GPUPlace();
auto *ctx = static_cast<GPUContext *>(
paddle::platform::DeviceContextPool::Instance().Get(gplace));
-LimitGridDim(*ctx, &grid_size);
+paddle::platform::LimitGridDim(*ctx, &grid_size);
FastElemwiseGradBroadcast1CUDAKernel<<<grid_size, block_size, 0, stream>>>(
x, y, out, dout, h, w, is_xsize_larger, dx_op, dy_op, dx, dy);
}
Expand All @@ -1009,13 +1002,12 @@ static void ElemwiseGradBroadcast2CUDA(gpuStream_t stream,
T *dx,
T *dy) {
int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, pre * post);
-int gird_size = n;
-int grid_size = n;
+dim3 grid_size = dim3(n);
auto gplace = phi::GPUPlace();
auto *ctx = static_cast<GPUContext *>(
paddle::platform::DeviceContextPool::Instance().Get(gplace));
-LimitGridDim(*ctx, &grid_size);
-ElemwiseGradBroadcast2CUDAKernel<<<gird_size, block_size, 0, stream>>>(
+paddle::platform::LimitGridDim(*ctx, &grid_size);
+ElemwiseGradBroadcast2CUDAKernel<<<grid_size, block_size, 0, stream>>>(
x, y, out, dout, pre, n, post, is_xsize_larger, dx_op, dy_op, dx, dy);
}

@@ -1216,8 +1208,8 @@ void CommonGradBroadcastCUDA(const DenseTensor &x,
is_y);
} else {
dim3 block_size = dim3(BLOCK_X, BLOCK_Y);
-int grid_size = (w + BLOCK_X - 1) / BLOCK_X;
-LimitGridDim(ctx, &grid_size);
+dim3 grid_size = dim3((w + BLOCK_X - 1) / BLOCK_X);
+paddle::platform::LimitGridDim(ctx, &grid_size);
FastCommonGradBroadcastCUDAKernelHeight<<<grid_size,
block_size,
0,
@@ -1253,8 +1245,8 @@ void CommonGradBroadcastCUDA(const DenseTensor &x,
is_y);
} else {
dim3 block_size = dim3(BLOCK_X, BLOCK_Y);
-int grid_size = (w + BLOCK_X - 1) / BLOCK_X;
-LimitGridDim(ctx, &grid_size);
+dim3 grid_size = dim3((w + BLOCK_X - 1) / BLOCK_X);
+paddle::platform::LimitGridDim(ctx, &grid_size);
FastCommonGradBroadcastCUDAKernelHeight<<<grid_size,
block_size,
0,
@@ -1350,8 +1342,8 @@ void CommonGradBroadcastCUDA(const DenseTensor &x,
<< " post:" << post;

int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, mid);
-int grid_size = pre * post;
-LimitGridDim(ctx, &grid_size);
+dim3 grid_size = dim3(pre * post);
+paddle::platform::LimitGridDim(ctx, &grid_size);

FastCommonGradBroadcastAllCUDAKernel<<<grid_size, block_size, 0, stream>>>(
x_data,
@@ -1392,8 +1384,8 @@ void CommonGradBroadcastCUDA(const DenseTensor &x,
1,
std::multiplies<int>());
int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, mid);
-int grid_size = pre * post;
-LimitGridDim(ctx, &grid_size);
+dim3 grid_size = dim3(pre * post);
+paddle::platform::LimitGridDim(ctx, &grid_size);
// we need to calc y offset with blockid, so do x_pre/y_pre to get left
// size.
if (k_pre != pre) k_pre = pre / k_pre;
@@ -1423,8 +1415,8 @@ void CommonGradBroadcastCUDA(const DenseTensor &x,
1,
std::multiplies<int>());
int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, mid);
-int grid_size = pre * post;
-LimitGridDim(ctx, &grid_size);
+dim3 grid_size = dim3(pre * post);
+paddle::platform::LimitGridDim(ctx, &grid_size);
if (k_pre != pre) k_pre = pre / k_pre;

FastCommonGradBroadcastOneCUDAKernel<<<grid_size,
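Two details worth noting in the hunks above. First, the old ElemwiseGradBroadcast2CUDA clamped an int grid_size but launched the kernel with the misspelled gird_size, so the clamp never took effect; folding both into one dim3 grid_size removes that bug. Second, dim3 built from a single value sets y and z to 1, so the launch shape is unchanged while the shared helper can now clamp all three dimensions. A quick illustration (values assumed):

dim3 grid = dim3(4096);  // grid.x == 4096, grid.y == 1, grid.z == 1
// <<<grid, block>>> behaves exactly like <<<4096, block>>>.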
16 changes: 4 additions & 12 deletions paddle/phi/kernels/funcs/reduce_function.h
@@ -33,6 +33,7 @@ namespace cub = hipcub;
#endif

#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/fast_divmod.h"
#include "paddle/phi/api/ext/dispatch.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
@@ -309,7 +310,7 @@ struct ReduceConfig {
: reduce_dims_origin(origin_reduce_dims), x_dim(origin_x_dim) {}

// get the parameters of reduceKernel
-void Run(const paddle::platform::Place& place) {
+void Run(const KPDevice& dev_ctx) {
// step1: update the reduce_dim left_dim and x_dim
SetReduceDim();

@@ -323,7 +324,7 @@
SetBlockDim();

// step5: limit the grid to prevent thread overflow
-LimitGridDim(place);
+paddle::platform::LimitGridDim(dev_ctx, &grid);
}

// when should_reduce_again is true, we need to malloc temp space for temp data
@@ -607,15 +608,6 @@ struct ReduceConfig {
grid = grid_dim;
}

-void LimitGridDim(const paddle::platform::Place& place) {
-auto* ctx = static_cast<paddle::platform::CUDADeviceContext*>(
-paddle::platform::DeviceContextPool::Instance().Get(place));
-std::array<int, 3> max_grid_dim = ctx->GetCUDAMaxGridDimSize();
-grid.x = grid.x < max_grid_dim[0] ? grid.x : max_grid_dim[0];
-grid.y = grid.y < max_grid_dim[1] ? grid.y : max_grid_dim[1];
-grid.z = grid.z < max_grid_dim[2] ? grid.z : max_grid_dim[2];
-}

public:
std::vector<int> reduce_dims_origin;
std::vector<int> reduce_dim;
@@ -1072,7 +1064,7 @@ void ReduceKernel(const KPDevice& dev_ctx,

auto x_dim = phi::vectorize<int>(x.dims());
auto config = ReduceConfig<Ty>(origin_reduce_dims, x_dim);
-config.Run(x.place());
+config.Run(dev_ctx);
int numel = x.numel();
// after config.run()
// SetOutputData for ReduceHigherDim when should_reduce_again is true,
9 changes: 1 addition & 8 deletions paddle/phi/kernels/gpu/index_sample_grad_kernel.cu
@@ -26,13 +26,6 @@
namespace phi {

namespace {
-template <typename Context>
-void LimitGridDim(const Context& ctx, dim3* grid_dim) {
-auto max_grid_dim =
-reinterpret_cast<const phi::GPUContext&>(ctx).GetCUDAMaxGridDimSize();
-grid_dim->x = grid_dim->x < max_grid_dim[0] ? grid_dim->x : max_grid_dim[0];
-grid_dim->y = grid_dim->y < max_grid_dim[1] ? grid_dim->y : max_grid_dim[1];
-}
#define PREDEFINED_BLOCK_SIZE_X 512
#define PREDEFINED_BLOCK_SIZE 1024
#define MIN(a, b) ((a) < (b) ? (a) : (b))
@@ -107,7 +100,7 @@ void IndexSampleGradKernel(const Context& ctx,
dim3 block_dim(block_width, block_height);
dim3 grid_dim((index_length + block_dim.x - 1) / block_dim.x,
(batch_size + block_dim.y - 1) / block_dim.y);
-LimitGridDim(ctx, &grid_dim);
+paddle::platform::LimitGridDim(ctx, &grid_dim);

phi::funcs::SetConstant<Context, T> set_zero;
set_zero(ctx, x_grad, static_cast<T>(0));
9 changes: 1 addition & 8 deletions paddle/phi/kernels/gpu/index_sample_kernel.cu
@@ -25,13 +25,6 @@
namespace phi {

namespace {
-template <typename Context>
-void LimitGridDim(const Context& ctx, dim3* grid_dim) {
-auto max_grid_dim =
-reinterpret_cast<const phi::GPUContext&>(ctx).GetCUDAMaxGridDimSize();
-grid_dim->x = grid_dim->x < max_grid_dim[0] ? grid_dim->x : max_grid_dim[0];
-grid_dim->y = grid_dim->y < max_grid_dim[1] ? grid_dim->y : max_grid_dim[1];
-}
#define PREDEFINED_BLOCK_SIZE_X 512
#define PREDEFINED_BLOCK_SIZE 1024
#define MIN(a, b) ((a) < (b) ? (a) : (b))
@@ -95,7 +88,7 @@ void IndexSampleKernel(const Context& ctx,
dim3 block_dim(block_width, block_height);
dim3 grid_dim((index_length + block_dim.x - 1) / block_dim.x,
(batch_size + block_dim.y - 1) / block_dim.y);
-LimitGridDim(ctx, &grid_dim);
+paddle::platform::LimitGridDim(ctx, &grid_dim);

if (index_type == DataType::INT64) {
const int64_t* index_data = index.data<int64_t>();
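Both index_sample kernels previously carried file-local LimitGridDim copies that clamped only x and y; they now share the paddle::platform helper, which also clamps z. The y clamp is the one that matters here, because grid_dim.y scales with batch_size. A self-contained illustration of the overflow (shapes assumed for the example):

#include <cuda_runtime.h>
#include <algorithm>
#include <cstdio>

int main() {
  long long index_length = 1024, batch_size = 1LL << 20;
  dim3 block(512, 2);
  dim3 grid((index_length + block.x - 1) / block.x,
            (batch_size + block.y - 1) / block.y);  // grid.y = 524288
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, 0);
  // Unclamped, grid.y exceeds maxGridSize[1] (65535 on current GPUs)
  // and the launch would fail with an invalid-configuration error.
  grid.y = std::min<unsigned int>(grid.y, prop.maxGridSize[1]);
  std::printf("clamped grid.y = %u\n", grid.y);
  return 0;
}

The clamp stays correct only if the kernel strides past the grid, as the index_select kernels below do with CUDA_KERNEL_LOOP.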
25 changes: 11 additions & 14 deletions paddle/phi/kernels/gpu/index_select_grad_kernel.cu
@@ -14,6 +14,7 @@

#include "paddle/phi/kernels/index_select_grad_kernel.h"

#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/core/kernel_registry.h"
@@ -23,8 +24,6 @@ DECLARE_bool(cudnn_deterministic);

namespace phi {

-using paddle::platform::PADDLE_CUDA_NUM_THREADS;

template <typename T, typename IndexT>
__global__ void index_select_grad_cuda_kernel(const T* output_grad,
T* input_grad,
@@ -89,25 +88,23 @@ void IndexSelectGradKernel(const Context& ctx,

auto stream = ctx.stream();

-index_select_grad_init<
-T><<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
-PADDLE_CUDA_NUM_THREADS,
-0,
-stream>>>(in_grad_data, numel);
+int block_dim = 256;
+dim3 grid_dim = dim3((numel + block_dim - 1) / block_dim);
+paddle::platform::LimitGridDim(ctx, &grid_dim);

-int blocks =
-(out_nums + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS;
-int threads = PADDLE_CUDA_NUM_THREADS;
+index_select_grad_init<T><<<grid_dim, block_dim, 0, stream>>>(in_grad_data,
+numel);

if (FLAGS_cudnn_deterministic) {
VLOG(2) << "Run grad kernel of index_select with single thread.";
-blocks = 1;
-threads = 1;
+block_dim = 1;
+grid_dim.x = 1;
}

if (index_type == phi::DataType::INT64) {
const int64_t* index_data = index.data<int64_t>();
-index_select_grad_cuda_kernel<T, int64_t><<<blocks, threads, 0, stream>>>(
+index_select_grad_cuda_kernel<T,
+int64_t><<<grid_dim, block_dim, 0, stream>>>(
output_grad_data,
in_grad_data,
index_data,
Expand All @@ -118,7 +115,7 @@ void IndexSelectGradKernel(const Context& ctx,
delta);
} else {
const int* index_data = index.data<int>();
-index_select_grad_cuda_kernel<T, int><<<blocks, threads, 0, stream>>>(
+index_select_grad_cuda_kernel<T, int><<<grid_dim, block_dim, 0, stream>>>(
output_grad_data,
in_grad_data,
index_data,
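The deterministic branch above swaps the old blocks = 1; threads = 1 for block_dim = 1; grid_dim.x = 1, preserving the single-thread launch. The fallback is needed presumably because the gradient kernel accumulates with atomic float adds (the kernel body is not shown in this diff), and floating-point addition is not associative, so thread scheduling order can change the rounded sum. A tiny host-side demonstration:

#include <cstdio>

int main() {
  float a = (1e20f + -1e20f) + 1.0f;  // == 1.0f
  float b = 1e20f + (-1e20f + 1.0f);  // == 0.0f, the 1.0f is absorbed
  std::printf("%g vs %g\n", a, b);    // same operands, different order
  return 0;
}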
35 changes: 15 additions & 20 deletions paddle/phi/kernels/gpu/index_select_kernel.cu
@@ -14,6 +14,7 @@

#include "paddle/phi/kernels/index_select_kernel.h"

#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/core/kernel_registry.h"
@@ -31,16 +32,14 @@ __global__ void index_select_cuda_kernel(const T* input,
int64_t stride,
int64_t size,
int64_t delta) {
-int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
-if (idx >= N) {
-return;
+CUDA_KERNEL_LOOP(idx, N) {
+int64_t pre_idx = idx / (stride * size);
+int64_t dim_idx = idx % (stride * size) / stride;
+IndexT src_dim_idx = index[dim_idx];
+int64_t input_idx =
+idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride;
+output[idx] = input[input_idx];
}

-int64_t pre_idx = idx / (stride * size);
-int64_t dim_idx = idx % (stride * size) / stride;
-IndexT src_dim_idx = index[dim_idx];
-int64_t input_idx = idx + (delta * pre_idx + src_dim_idx - dim_idx) * stride;
-output[idx] = input[input_idx];
}

template <typename T, typename Context>
@@ -75,21 +74,17 @@ void IndexSelectKernel(const Context& ctx,
int64_t numel = output->numel();
auto stream = ctx.stream();

+int block_dim = 256;
+dim3 grid_dim = dim3((numel + block_dim - 1) / block_dim);
+paddle::platform::LimitGridDim(ctx, &grid_dim);

if (index_type == phi::DataType::INT64) {
const int64_t* index_data = index.data<int64_t>();
-index_select_cuda_kernel<T, int64_t><<<
-(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
-PADDLE_CUDA_NUM_THREADS,
-0,
-stream>>>(in_data, out_data, index_data, numel, stride, size, delta);
+index_select_cuda_kernel<T, int64_t><<<grid_dim, block_dim, 0, stream>>>(
+in_data, out_data, index_data, numel, stride, size, delta);
} else {
const int* index_data = index.data<int>();
-index_select_cuda_kernel<
-T,
-int><<<(numel + PADDLE_CUDA_NUM_THREADS - 1) / PADDLE_CUDA_NUM_THREADS,
-PADDLE_CUDA_NUM_THREADS,
-0,
-stream>>>(
+index_select_cuda_kernel<T, int><<<grid_dim, block_dim, 0, stream>>>(
in_data, out_data, index_data, numel, stride, size, delta);
}
}
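The index_select kernels above trade the one-shot bounds check for CUDA_KERNEL_LOOP, which in Paddle (as in Caffe-style codebases) expands to a grid-stride loop, roughly like this sketch (the actual macro may differ in index type and naming):

// Rough expansion of CUDA_KERNEL_LOOP(idx, N):
for (int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < (N);
     idx += static_cast<int64_t>(blockDim.x) * gridDim.x)

This is what makes the clamped launch safe: when LimitGridDim caps grid_dim.x, each thread simply processes more than one element.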