Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add block and grid loop for index_sample kernel to deal with a large-shape tensor #37816

Merged
merged 4 commits into from
Jan 21, 2022
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 41 additions & 21 deletions paddle/fluid/operators/index_sample_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,16 @@ template <typename T, typename IndexT = int>
__global__ void IndexSampleForward(const IndexT* index, const T* in_data,
T* out_data, size_t index_length,
size_t input_length, size_t batch_size) {
int index_i = blockDim.x * blockIdx.x + threadIdx.x;
int index_j = blockDim.y * blockIdx.y + threadIdx.y;
int index_idx = index_j * index_length + index_i;
int in_idx = index_j * input_length + index_i;

if (index_i < index_length & index_j < batch_size) {
IndexT sample_idx = index[index_idx];
out_data[index_idx] = in_data[in_idx - index_i + sample_idx];
unsigned int index_i = blockDim.x * blockIdx.x + threadIdx.x;
unsigned int index_j = blockDim.y * blockIdx.y + threadIdx.y;
for (; index_j < batch_size; index_j += blockDim.y * gridDim.y) {
index_i = blockDim.x * blockIdx.x + threadIdx.x;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个确定不是冗余的么😂完全没必要重新计算一遍吧?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已删除,感谢~

for (; index_i < index_length; index_i += blockDim.x * gridDim.x) {
unsigned int index_idx = index_j * index_length + index_i;
unsigned int in_idx = index_j * input_length + index_i;
IndexT sample_idx = index[index_idx];
out_data[index_idx] = in_data[in_idx - index_i + sample_idx];
}
}
}

Expand All @@ -44,18 +46,21 @@ __global__ void IndexSampleGrad(const IndexT* index, T* in_grad,
const T* out_grad, size_t index_length,
size_t input_length, size_t batch_size,
bool same_data_in_row = true) {
int index_i = blockDim.x * blockIdx.x + threadIdx.x;
int index_j = blockDim.y * blockIdx.y + threadIdx.y;
int index_idx = index_j * index_length + index_i;
int in_idx = index_j * input_length + index_i;

if (index_i < index_length & index_j < batch_size) {
IndexT sample_idx = index[index_idx];
if (same_data_in_row) {
platform::CudaAtomicAdd(&(in_grad[in_idx - index_i + sample_idx]),
out_grad[sample_idx]);
} else {
in_grad[in_idx - index_i + sample_idx] = out_grad[index_idx];
unsigned int index_i = blockDim.x * blockIdx.x + threadIdx.x;
unsigned int index_j = blockDim.y * blockIdx.y + threadIdx.y;

for (; index_j < batch_size; index_j += blockDim.y * gridDim.y) {
index_i = blockDim.x * blockIdx.x + threadIdx.x;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上,index_i没必要重计算一遍

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已删除,感谢~

for (; index_i < index_length; index_i += blockDim.x * gridDim.x) {
unsigned int index_idx = index_j * index_length + index_i;
unsigned int in_idx = index_j * input_length + index_i;
IndexT sample_idx = index[index_idx];
if (same_data_in_row) {
platform::CudaAtomicAdd(&(in_grad[in_idx - index_i + sample_idx]),
out_grad[sample_idx]);
} else {
in_grad[in_idx - index_i + sample_idx] = out_grad[index_idx];
}
}
}
}
Expand Down Expand Up @@ -97,8 +102,16 @@ class IndexSampleKernel<platform::CUDADeviceContext, T>
platform::RoundToPowerOfTwo(index_length * batch_size) / block_width;

dim3 block_dim(block_width, block_height);
unsigned int threads = 512;
block_dim.x = block_dim.x < threads ? block_dim.x : threads;
block_dim.y = block_dim.y < threads ? block_dim.y : threads;
dim3 grid_dim((index_length + block_dim.x - 1) / block_dim.x,
(batch_size + block_dim.y - 1) / block_dim.y);
dim3 max_grid_dim =
ctx.template device_context<platform::CUDADeviceContext>()
.GetCUDAMaxGridDimSize();
grid_dim.x = grid_dim.x < max_grid_dim.x ? grid_dim.x : max_grid_dim.x;
grid_dim.y = grid_dim.y < max_grid_dim.y ? grid_dim.y : max_grid_dim.y;

if (index_type == framework::proto::VarType::INT64) {
const int64_t* index_data = index->data<int64_t>();
Expand Down Expand Up @@ -153,9 +166,16 @@ class IndexSampleGradKernel<platform::CUDADeviceContext, T>
auto block_height =
platform::RoundToPowerOfTwo(index_length * batch_size) / block_width;
dim3 block_dim(block_width, block_height);
unsigned int threads = 512;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

重复代码,可以提取出来:

void CheckLaunchParamValid(const framework::ExecutionContext& ctx, dim3* block_dim,  dim3* grid_dim) {
  unsigned int threads = 512;
  block_dim->x = block_dim->x < threads ? block_dim->x : threads;
  block_dim->y = block_dim->y < threads ? block_dim->y : threads;

  dim3 max_grid_dim =
        ctx.template device_context<platform::CUDADeviceContext>()
            .GetCUDAMaxGridDimSize();
  grid_dim->x = grid_dim->x < max_grid_dim.x ? grid_dim->x : max_grid_dim.x;
  grid_dim->y = grid_dim->y < max_grid_dim.y ? grid_dim->y : max_grid_dim.y;
}

然后调用

CheckLaunchParamValid(ctx, &block_dim, &grid_dim);

而非重复写两次。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

定义了函数MIN检查block dim,函数LimitGridDim检查grid dim。感谢~

block_dim.x = block_dim.x < threads ? block_dim.x : threads;
block_dim.y = block_dim.y < threads ? block_dim.y : threads;
dim3 grid_dim((index_length + block_dim.x - 1) / block_dim.x,
(batch_size + block_dim.y - 1) / block_dim.y);

dim3 max_grid_dim =
ctx.template device_context<platform::CUDADeviceContext>()
.GetCUDAMaxGridDimSize();
grid_dim.x = grid_dim.x < max_grid_dim.x ? grid_dim.x : max_grid_dim.x;
grid_dim.y = grid_dim.y < max_grid_dim.y ? grid_dim.y : max_grid_dim.y;
math::SetConstant<platform::CUDADeviceContext, T> set_zero;
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
set_zero(dev_ctx, input_grad, static_cast<T>(0));
Expand Down