-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Optimization for StackGradCUDAKernel for last dimension stack case. #48992
Changes from all commits
c9cf60d
697577c
3c12404
22020d4
da02e2d
dfdd57c
3cd9507
654dba8
704e2fa
834973d
2e2ed67
bd16efe
b79b243
1f2665c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,15 +13,13 @@ | |
// limitations under the License. | ||
|
||
#include "paddle/phi/kernels/stack_grad_kernel.h" | ||
|
||
#include "paddle/fluid/memory/memory.h" | ||
#include "paddle/phi/backends/gpu/gpu_context.h" | ||
#include "paddle/phi/backends/gpu/gpu_launch_config.h" | ||
#include "paddle/phi/core/kernel_registry.h" | ||
|
||
namespace phi { | ||
|
||
template <typename T, typename IntType> | ||
template <typename T, typename IndexT> | ||
__global__ void UnStackHelperCUDAKernel(const T* __restrict__ input, | ||
int pre_dim_size, | ||
int split_dim_size, | ||
|
@@ -33,101 +31,152 @@ __global__ void UnStackHelperCUDAKernel(const T* __restrict__ input, | |
// In this case they are equal | ||
assert(split_dim_size % num_split == 0); | ||
|
||
IntType size = pre_dim_size * split_dim_size * suf_dim_size; | ||
IntType each_dim_size = split_dim_size / num_split; | ||
IndexT size = pre_dim_size * split_dim_size * suf_dim_size; | ||
IndexT each_dim_size = split_dim_size / num_split; | ||
|
||
for (IntType offset = blockIdx.x * blockDim.x + threadIdx.x; offset < size; | ||
for (IndexT offset = blockIdx.x * blockDim.x + threadIdx.x; offset < size; | ||
offset += blockDim.x * gridDim.x) { | ||
IntType i = offset / (split_dim_size * suf_dim_size); | ||
IntType j = (offset % (split_dim_size * suf_dim_size)) / suf_dim_size; | ||
IntType k = offset % suf_dim_size; | ||
IndexT i = offset / (split_dim_size * suf_dim_size); | ||
IndexT j = (offset % (split_dim_size * suf_dim_size)) / suf_dim_size; | ||
IndexT k = offset % suf_dim_size; | ||
|
||
T* output = output_ptrs[j / each_dim_size]; | ||
if (output == nullptr) { | ||
return; | ||
} | ||
IntType output_ind = i * each_dim_size * suf_dim_size + | ||
(j % each_dim_size) * suf_dim_size + k; | ||
IndexT output_ind = i * each_dim_size * suf_dim_size + | ||
(j % each_dim_size) * suf_dim_size + k; | ||
*(output + output_ind) = input[offset]; | ||
} | ||
} | ||
|
||
template <typename T, typename Context> | ||
void StackGradKernel(const Context& dev_ctx, | ||
const DenseTensor& out, | ||
int axis, | ||
std::vector<DenseTensor*> x_grad) { | ||
if (axis < 0) axis += out.dims().size(); | ||
|
||
int n = out.dims()[axis]; | ||
PADDLE_ENFORCE_EQ(n, | ||
x_grad.size(), | ||
phi::errors::InvalidArgument( | ||
"Output x_grad size should be equal to n, but" | ||
" received n is:%d x_grad size is:%d.", | ||
n, | ||
x_grad.size())); | ||
|
||
// x_grad is output, so save each data address, then copy each dy into dx_data | ||
std::vector<T*> outputs(n); | ||
for (size_t j = 0; j < x_grad.size(); ++j) { | ||
if (x_grad[j] == nullptr) { | ||
outputs[j] = nullptr; | ||
continue; | ||
template <typename T, typename IndexT> | ||
__global__ void StackGradKernelForLastDim(const T* __restrict__ in_data, | ||
const IndexT cols, | ||
const IndexT rows, | ||
const IndexT tile_x_num, | ||
T** out_datas) { | ||
constexpr int buffer_size = 512; | ||
__shared__ T s_buf[buffer_size]; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 这里 block 的线程数最大是 512?如果 block 大小是可变的,可以考虑使用动态 shared memory size。 There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 设置的 block 大小是可变的,最大值为 512,所以在 kernel 中把 `buffer_size` 固定为 512。 There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. ok |
||
|
||
for (IndexT tile_x = blockIdx.x; tile_x < tile_x_num; tile_x += gridDim.x) { | ||
IndexT row_idx = tile_x * blockDim.x + threadIdx.x; | ||
IndexT col_idx = blockIdx.y * blockDim.y + threadIdx.y; | ||
int s_idx = threadIdx.y * blockDim.x + threadIdx.x; | ||
bool is_valid = (col_idx < cols && row_idx < rows); | ||
|
||
if (is_valid) { | ||
T data = in_data[row_idx * cols + col_idx]; | ||
s_buf[s_idx] = data; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 这里可以参考 transpose 的写法:连续读,写 shared memory 的时候进行转置。 There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 这部分的优化是顺带完成的,读取部分还没有实现连续访存,需要进一步做性能优化;不过这暂时不是影响模型的核心原因,所以在这一阶段不会继续做进一步优化。 There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 好的 There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. auto col_divmod = phi::funcs::FastDivMod(out_num);
// Variant of the last-dimension stack-grad kernel that stages each tile in
// shared memory: every thread whose linear index is below `numel` stores one
// element of `in_data` into `s_buf`, the block synchronizes, and every thread
// whose (col_idx, row_idx) is in range copies one staged element to
// out_datas[col_idx][row_idx].
//
// Parameters:
//   in_data    - source buffer, read at the linear index `tid`.
//   cols       - exclusive upper bound for col_idx, which selects the output
//                pointer inside out_datas.
//   rows       - exclusive upper bound for row_idx, which indexes into the
//                selected output buffer.
//   tile_x_num - number of tiles walked by the outer loop.
//   numel      - exclusive upper bound for the linear read index `tid`.
//   divmoder   - splits the in-block thread id into the two values used to
//                build the staging index (presumably quotient and remainder
//                of a division by the divisor it was constructed with --
//                TODO confirm against phi::funcs::FastDivMod).
//   out_datas  - array of output pointers; null entries are skipped.
template <typename T, typename IndexT>
__global__ void StackGradKernelForLastDim(const T* __restrict__ in_data,
const int cols,
const IndexT rows,
const IndexT tile_x_num,
const IndexT numel,
const phi::funcs::FastDivMod divmoder,
T** out_datas) {
// NOTE(review): the staging buffer is a fixed 1024 elements; both share_idx
// computations below must stay under that bound, which depends on the launch
// configuration (not visible here) -- confirm.
__shared__ T s_buf[1024];
// Row stride of the staging buffer is one element wider than blockDim.x
// (presumably padding against shared-memory bank conflicts -- TODO confirm).
int share_stride = blockDim.x + 1;
// Each block starts at tile blockIdx.x and advances by gridDim.x tiles.
for (IndexT tile_x = blockIdx.x; tile_x < tile_x_num; tile_x += gridDim.x) {
int tid_in_block = threadIdx.y * blockDim.x + threadIdx.x;
auto result = divmoder.Divmod(tid_in_block);
// Linear index into in_data: one element per thread, tiles laid out
// back to back.
IndexT tid = tile_x * blockDim.x * blockDim.y + tid_in_block;
if (tid < numel) {
// Store position is built from the Divmod result, not from
// (threadIdx.y, threadIdx.x), so the slot a thread writes generally differs
// from the slot it reads back below.
int share_idx = result[1] * share_stride + result[0];
s_buf[share_idx] = in_data[tid];
}
IndexT row_idx = tile_x * blockDim.x + threadIdx.x;
int col_idx = blockIdx.y * blockDim.y + threadIdx.y;
// Barrier between the staging stores above and the loads below.
// NOTE(review): there is no second barrier before the next loop iteration
// overwrites s_buf; because threads read slots written by other threads, a
// block that processes more than one tile may overwrite data that has not
// been read yet -- confirm whether a trailing __syncthreads() is needed.
__syncthreads();
if (col_idx < cols && row_idx < rows) {
int share_idx = threadIdx.y * share_stride + threadIdx.x;
// Outputs whose pointer is null are skipped.
if (out_datas[col_idx] != nullptr) {
out_datas[col_idx][row_idx] = s_buf[share_idx];
}
}
}
} 完成了这里的修改后,性能并没有明显的提升,反而相对合入的kernel的性能有所下降。但是这个kernel 已经实现了读写的访存连续。
两者叠加造成了目前的测试结果。当然,我后面也需要用 ncu 深入分析,把这个问题搞清楚,感谢 @zkh2016 指教。 |
||
} | ||
if (x_grad[j]->numel() != 0UL) { | ||
T* ptr = dev_ctx.template Alloc<T>(x_grad[j]); | ||
outputs[j] = ptr; | ||
} else { | ||
outputs[j] = nullptr; | ||
__syncthreads(); | ||
if (is_valid) { | ||
if (out_datas[col_idx] != nullptr) { | ||
out_datas[col_idx][row_idx] = s_buf[s_idx]; | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. s_buf[s_idx] 好像没啥作用呀,为啥不直接将 data 赋值给 out_data?如果直接赋值,那 sync 也需要去掉吧。 There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 这块想着是做访存合并:把数据全部连续读取进来,转置后存入 shared memory 中,然后再连续读出来。 |
||
} | ||
} | ||
auto dy_data = out.data<T>(); | ||
// each x_grad should have same shape | ||
int dy_pre = 1, dy_suf = 1; | ||
auto dy_dims = out.dims(); | ||
int split_dim = n; | ||
for (int i = 0; i < axis; ++i) { | ||
dy_pre *= dy_dims[i]; | ||
} | ||
|
||
template <typename Context, typename T, typename IndexT> | ||
void LaunchStackGradCUDAKernel(const Context& ctx, | ||
const DenseTensor& out, | ||
std::vector<DenseTensor*>* x_grad_ptr, | ||
const int axis, | ||
const int64_t dy_pre) { | ||
auto x_grad = *x_grad_ptr; | ||
int out_num = out.dims()[axis]; | ||
PADDLE_ENFORCE_EQ( | ||
out_num, | ||
x_grad.size(), | ||
phi::errors::InvalidArgument( | ||
"Output x_grad size shall be equal to output num, but output num " | ||
"received in stack_grad op is:%d, and x_grad size is:%d.", | ||
out_num, | ||
x_grad.size())); | ||
std::vector<T*> outputs(out_num); | ||
for (size_t j = 0; j < out_num; ++j) { | ||
if (x_grad[j] == nullptr || x_grad[j]->numel() == 0UL) { | ||
outputs[j] = nullptr; | ||
} else { | ||
outputs[j] = ctx.template Alloc<T>(x_grad[j]); | ||
} | ||
} | ||
dy_suf = out.numel() / (split_dim * dy_pre); | ||
|
||
auto tmp_out_data = paddle::memory::Alloc( | ||
dev_ctx.GetPlace(), | ||
outputs.size() * sizeof(T*), | ||
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream()))); | ||
paddle::memory::Copy(dev_ctx.GetPlace(), | ||
ctx.GetPlace(), | ||
out_num * sizeof(T*), | ||
phi::Stream(reinterpret_cast<phi::StreamId>(ctx.stream()))); | ||
paddle::memory::Copy(ctx.GetPlace(), | ||
tmp_out_data->ptr(), | ||
phi::CPUPlace(), | ||
reinterpret_cast<void*>(outputs.data()), | ||
outputs.size() * sizeof(T*), | ||
dev_ctx.stream()); | ||
|
||
auto config = phi::backends::gpu::GetGpuLaunchConfig1D( | ||
dev_ctx, dy_pre * split_dim * dy_suf); | ||
|
||
if (out.numel() < std::numeric_limits<int32_t>::max()) { | ||
UnStackHelperCUDAKernel<T, int32_t> | ||
<<<config.block_per_grid.x, | ||
config.thread_per_block.x, | ||
0, | ||
dev_ctx.stream()>>>(dy_data, | ||
dy_pre, | ||
split_dim, | ||
dy_suf, | ||
split_dim, | ||
reinterpret_cast<T**>(tmp_out_data->ptr())); | ||
out_num * sizeof(T*), | ||
ctx.stream()); | ||
|
||
if (axis == (out.dims().size() - 1)) { | ||
constexpr int kThreads = 512; | ||
constexpr int kWarpSize = 32; | ||
constexpr int kMaxOut = 16; | ||
int tid_x = 0, tid_y = 0, bid_x = 0, bid_y = 1; | ||
bool is_small_num = out_num < kMaxOut; | ||
|
||
if (is_small_num) { | ||
tid_y = out_num; | ||
tid_x = | ||
std::min(backends::gpu::RoundToNextHighPowOfTwo(dy_pre, kWarpSize), | ||
kThreads / backends::gpu::RoundToNextHighPowOfTwo(tid_y)); | ||
} else { | ||
tid_y = kMaxOut; | ||
tid_x = kWarpSize; | ||
bid_y = backends::gpu::DivUp<int>(out_num, kMaxOut); | ||
} | ||
int tile_x_num = backends::gpu::DivUp<int>(dy_pre, tid_x); | ||
bid_x = std::min(tile_x_num, backends::gpu::kMultiDimslimit); | ||
dim3 blocks(tid_x, tid_y, 1); | ||
dim3 grids(bid_x, bid_y, 1); | ||
|
||
StackGradKernelForLastDim<T, IndexT><<<grids, blocks, 0, ctx.stream()>>>( | ||
out.data<T>(), | ||
out_num, | ||
dy_pre, | ||
tile_x_num, | ||
reinterpret_cast<T**>(tmp_out_data->ptr())); | ||
} else { | ||
int dy_suf = out.numel() / (out_num * dy_pre); | ||
auto config = | ||
backends::gpu::GetGpuLaunchConfig1D(ctx, dy_pre * out_num * dy_suf); | ||
|
||
UnStackHelperCUDAKernel<T, IndexT> | ||
<<<config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>( | ||
out.data<T>(), | ||
dy_pre, | ||
out_num, | ||
dy_suf, | ||
out_num, | ||
reinterpret_cast<T**>(tmp_out_data->ptr())); | ||
} | ||
} | ||
|
||
template <typename T, typename Context> | ||
void StackGradKernel(const Context& dev_ctx, | ||
const DenseTensor& out, | ||
int axis, | ||
std::vector<DenseTensor*> x_grad) { | ||
const auto& dy_dims = out.dims(); | ||
int actual_axis = axis < 0 ? axis + dy_dims.size() : axis; | ||
bool use_int32 = out.numel() < std::numeric_limits<int32_t>::max(); | ||
|
||
int64_t dy_pre = 1; | ||
for (int i = 0; i < actual_axis; ++i) { | ||
dy_pre *= dy_dims[i]; | ||
} | ||
if (use_int32) { | ||
LaunchStackGradCUDAKernel<Context, T, int32_t>( | ||
dev_ctx, out, &x_grad, actual_axis, dy_pre); | ||
} else { | ||
UnStackHelperCUDAKernel<T, int64_t> | ||
<<<config.block_per_grid.x, | ||
config.thread_per_block.x, | ||
0, | ||
dev_ctx.stream()>>>(dy_data, | ||
dy_pre, | ||
split_dim, | ||
dy_suf, | ||
split_dim, | ||
reinterpret_cast<T**>(tmp_out_data->ptr())); | ||
LaunchStackGradCUDAKernel<Context, T, int64_t>( | ||
dev_ctx, out, &x_grad, actual_axis, dy_pre); | ||
} | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个到底是什么的 limit,从变量命名上看不出来,后面再优化下,建议如 `kMaxGridSize`。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的,这个是想表达在多维线程设置的情况下,每个维度上线程设置值的上限。