Skip to content

Commit 8ae12ac

Browse files
authored
【Hackathon 9th No.19】Fix put_along_axis assign with last index to same dst (#74854)
* Fix put_along_axis assign with last index with same dst * Fix typos
1 parent 1b44b2b commit 8ae12ac

File tree

1 file changed

+90
-8
lines changed

1 file changed

+90
-8
lines changed

paddle/phi/kernels/funcs/gather_scatter_functor.cu

Lines changed: 90 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,78 @@ __global__ void ScatterMeanGPUKernel(tensor_t* self_data,
371371
}
372372
}
373373

374+
__device__ __forceinline__ void decompose_tid(int64_t tid,
375+
int64_t select_dim_size,
376+
int64_t outer_dim_size,
377+
int64_t* i,
378+
int64_t* j,
379+
int64_t* k) {
380+
const int64_t ij_span = select_dim_size * outer_dim_size;
381+
*i = tid / ij_span;
382+
const int64_t r = tid % ij_span;
383+
*j = r / outer_dim_size;
384+
*k = r % outer_dim_size;
385+
}
386+
387+
template <typename index_t>
388+
__global__ void PickWinnersScatterKernel(const index_t* __restrict__ index_data,
389+
int64_t select_dim_size,
390+
int64_t self_select_dim_size,
391+
int64_t /*src_select_dim_size*/,
392+
int64_t /*inner_dim_size*/,
393+
int64_t outer_dim_size,
394+
int64_t outer_dim_size_self,
395+
int64_t /*outer_dim_size_src*/,
396+
int64_t n,
397+
int* __restrict__ winners) {
398+
const int64_t tid = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
399+
if (tid >= n) return;
400+
401+
int64_t i, j, k;
402+
decompose_tid(tid, select_dim_size, outer_dim_size, &i, &j, &k);
403+
404+
index_t idx = index_data[tid];
405+
if (idx < 0) idx += static_cast<index_t>(self_select_dim_size);
406+
const int64_t dst = k + static_cast<int64_t>(idx) * outer_dim_size_self +
407+
i * outer_dim_size_self * self_select_dim_size;
408+
409+
atomicMax(&winners[dst], static_cast<int>(tid));
410+
}
411+
412+
template <typename tensor_t, typename index_t, typename func_t>
413+
__global__ void ScatterWriteByWinnersKernel(
414+
tensor_t* __restrict__ self_data,
415+
const index_t* __restrict__ index_data,
416+
tensor_t* __restrict__ src_data,
417+
int64_t select_dim_size,
418+
int64_t self_select_dim_size,
419+
int64_t src_select_dim_size,
420+
int64_t /*inner_dim_size*/,
421+
int64_t outer_dim_size,
422+
int64_t outer_dim_size_self,
423+
int64_t outer_dim_size_src,
424+
int64_t n,
425+
func_t reduce_op,
426+
const int* __restrict__ winners) {
427+
const int64_t tid = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
428+
if (tid >= n) return;
429+
430+
int64_t i, j, k;
431+
decompose_tid(tid, select_dim_size, outer_dim_size, &i, &j, &k);
432+
433+
index_t idx = index_data[tid];
434+
if (idx < 0) idx += static_cast<index_t>(self_select_dim_size);
435+
436+
const int64_t dst = k + static_cast<int64_t>(idx) * outer_dim_size_self +
437+
i * outer_dim_size_self * self_select_dim_size;
438+
439+
const int64_t src_off =
440+
k + j * outer_dim_size_src + i * outer_dim_size_src * src_select_dim_size;
441+
if (static_cast<int>(tid) == winners[dst]) {
442+
reduce_op(self_data + dst, src_data + src_off);
443+
}
444+
}
445+
374446
template <typename tensor_t,
375447
typename index_t = int64_t,
376448
bool is_scatter_like = true>
@@ -422,25 +494,35 @@ struct gpu_gather_scatter_functor {
422494
DenseTensor shared_mem_tensor;
423495
if (method_name == "scatter_assign_gpu") {
424496
shared_mem_tensor.Resize({self_size});
425-
dev_ctx.Alloc<int>(&shared_mem_tensor);
497+
auto* winners = dev_ctx.Alloc<int>(&shared_mem_tensor);
426498
phi::funcs::set_constant(dev_ctx, &shared_mem_tensor, 0);
427-
428-
int* shared_mem = shared_mem_tensor.data<int>();
429-
ScatterAssignGPUKernel<tensor_t, index_t, func_t, is_scatter_like>
499+
// Stage 1: Get the last index to be assigned the same dst.
500+
PickWinnersScatterKernel<index_t>
501+
<<<grid, block, 0, stream>>>(index_data,
502+
select_dim_size,
503+
self_select_dim_size,
504+
src_select_dim_size,
505+
inner_dim_size,
506+
outer_dim_size,
507+
outer_dim_size_self,
508+
outer_dim_size_src,
509+
n,
510+
winners);
511+
// Stage 2: Only the max tid in stage 1 can write src to dst.
512+
ScatterWriteByWinnersKernel<tensor_t, index_t, func_t>
430513
<<<grid, block, 0, stream>>>(self_data,
431-
dim,
432514
index_data,
433515
src_data,
434516
select_dim_size,
435517
self_select_dim_size,
436518
src_select_dim_size,
519+
inner_dim_size,
437520
outer_dim_size,
438521
outer_dim_size_self,
439522
outer_dim_size_src,
440-
index_size,
441-
self_size,
523+
n,
442524
reduce_op,
443-
shared_mem);
525+
winners);
444526
} else if (method_name == "scatter_mean_gpu") {
445527
shared_mem_tensor.Resize({self_size * 2});
446528
dev_ctx.Alloc<int>(&shared_mem_tensor);

0 commit comments

Comments
 (0)