98 changes: 90 additions & 8 deletions paddle/phi/kernels/funcs/gather_scatter_functor.cu
@@ -371,6 +371,78 @@ __global__ void ScatterMeanGPUKernel(tensor_t* self_data,
}
}

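// Map a flat thread id onto (i, j, k):
//   tid = i * (select_dim_size * outer_dim_size) + j * outer_dim_size + k
// i indexes the dims before the scatter axis, j the scatter axis of
// index/src, and k the dims after it.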
__device__ __forceinline__ void decompose_tid(int64_t tid,
int64_t select_dim_size,
int64_t outer_dim_size,
int64_t* i,
int64_t* j,
int64_t* k) {
const int64_t ij_span = select_dim_size * outer_dim_size;
*i = tid / ij_span;
const int64_t r = tid % ij_span;
*j = r / outer_dim_size;
*k = r % outer_dim_size;
}

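// Stage 1: each thread computes the destination offset of its index element
// in self and uses atomicMax so the slot remembers the largest tid targeting
// it. The winner is the last logical writer, which reproduces the sequential
// semantics of scatter-assign. Note that winners stores tids as int, so n is
// assumed to fit in 32 bits.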
template <typename index_t>
__global__ void PickWinnersScatterKernel(const index_t* __restrict__ index_data,
int64_t select_dim_size,
int64_t self_select_dim_size,
int64_t /*src_select_dim_size*/,
int64_t /*inner_dim_size*/,
int64_t outer_dim_size,
int64_t outer_dim_size_self,
int64_t /*outer_dim_size_src*/,
int64_t n,
int* __restrict__ winners) {
const int64_t tid = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
if (tid >= n) return;

int64_t i, j, k;
decompose_tid(tid, select_dim_size, outer_dim_size, &i, &j, &k);

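// Wrap negative indices, then form the flat offset of the destination
// element in self.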
index_t idx = index_data[tid];
if (idx < 0) idx += static_cast<index_t>(self_select_dim_size);
const int64_t dst = k + static_cast<int64_t>(idx) * outer_dim_size_self +
i * outer_dim_size_self * self_select_dim_size;

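// Keep the largest tid competing for this destination slot.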
atomicMax(&winners[dst], static_cast<int>(tid));
}

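// Stage 2: recompute the destination offset and let only the stage-1 winner
// perform the write, so each self element is assigned exactly once and the
// result is deterministic.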
template <typename tensor_t, typename index_t, typename func_t>
__global__ void ScatterWriteByWinnersKernel(
tensor_t* __restrict__ self_data,
const index_t* __restrict__ index_data,
tensor_t* __restrict__ src_data,
int64_t select_dim_size,
int64_t self_select_dim_size,
int64_t src_select_dim_size,
int64_t /*inner_dim_size*/,
int64_t outer_dim_size,
int64_t outer_dim_size_self,
int64_t outer_dim_size_src,
int64_t n,
func_t reduce_op,
const int* __restrict__ winners) {
const int64_t tid = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
if (tid >= n) return;

int64_t i, j, k;
decompose_tid(tid, select_dim_size, outer_dim_size, &i, &j, &k);

index_t idx = index_data[tid];
if (idx < 0) idx += static_cast<index_t>(self_select_dim_size);

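// dst mirrors stage 1; src_off addresses the same (i, j, k) element of src.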
const int64_t dst = k + static_cast<int64_t>(idx) * outer_dim_size_self +
i * outer_dim_size_self * self_select_dim_size;

const int64_t src_off =
k + j * outer_dim_size_src + i * outer_dim_size_src * src_select_dim_size;
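// Only the stage-1 winner writes, giving a deterministic assignment.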
if (static_cast<int>(tid) == winners[dst]) {
reduce_op(self_data + dst, src_data + src_off);
}
}

template <typename tensor_t,
typename index_t = int64_t,
bool is_scatter_like = true>
@@ -422,25 +494,35 @@ struct gpu_gather_scatter_functor {
DenseTensor shared_mem_tensor;
if (method_name == "scatter_assign_gpu") {
shared_mem_tensor.Resize({self_size});
auto* winners = dev_ctx.Alloc<int>(&shared_mem_tensor);
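// Zero-init is safe: tids are non-negative, so atomicMax records every
// writer, and slots that no tid targets are never read back in stage 2.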
phi::funcs::set_constant(dev_ctx, &shared_mem_tensor, 0);

// Stage 1: for each destination slot, record the last tid that scatters to it.
PickWinnersScatterKernel<index_t>
<<<grid, block, 0, stream>>>(index_data,
select_dim_size,
self_select_dim_size,
src_select_dim_size,
inner_dim_size,
outer_dim_size,
outer_dim_size_self,
outer_dim_size_src,
n,
winners);
// Stage 2: only the winning (max) tid from stage 1 writes its src value to dst.
ScatterWriteByWinnersKernel<tensor_t, index_t, func_t>
<<<grid, block, 0, stream>>>(self_data,
index_data,
src_data,
select_dim_size,
self_select_dim_size,
src_select_dim_size,
inner_dim_size,
outer_dim_size,
outer_dim_size_self,
outer_dim_size_src,
n,
reduce_op,
winners);
} else if (method_name == "scatter_mean_gpu") {
shared_mem_tensor.Resize({self_size * 2});
dev_ctx.Alloc<int>(&shared_mem_tensor);