@@ -371,6 +371,78 @@ __global__ void ScatterMeanGPUKernel(tensor_t* self_data,
  }
}

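+// Decomposes the flat thread id into the (i, j, k) loop indices of the
+// scatter layout: i runs over the dims before the scatter dim, j along the
+// scatter dim itself, and k over the trailing dims (stride 1).
+// E.g. with select_dim_size = 4 and outer_dim_size = 3, tid = 17 gives
+// i = 1, j = 1, k = 2.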
+__device__ __forceinline__ void decompose_tid(int64_t tid,
+                                               int64_t select_dim_size,
+                                               int64_t outer_dim_size,
+                                               int64_t* i,
+                                               int64_t* j,
+                                               int64_t* k) {
+  const int64_t ij_span = select_dim_size * outer_dim_size;
+  *i = tid / ij_span;
+  const int64_t r = tid % ij_span;
+  *j = r / outer_dim_size;
+  *k = r % outer_dim_size;
+}
+
+template <typename index_t>
+__global__ void PickWinnersScatterKernel(const index_t* __restrict__ index_data,
+                                         int64_t select_dim_size,
+                                         int64_t self_select_dim_size,
+                                         int64_t /*src_select_dim_size*/,
+                                         int64_t /*inner_dim_size*/,
+                                         int64_t outer_dim_size,
+                                         int64_t outer_dim_size_self,
+                                         int64_t /*outer_dim_size_src*/,
+                                         int64_t n,
+                                         int* __restrict__ winners) {
+  const int64_t tid = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
+  if (tid >= n) return;
+
+  int64_t i, j, k;
+  decompose_tid(tid, select_dim_size, outer_dim_size, &i, &j, &k);
+
+  index_t idx = index_data[tid];
+  if (idx < 0) idx += static_cast<index_t>(self_select_dim_size);
+  const int64_t dst = k + static_cast<int64_t>(idx) * outer_dim_size_self +
+                      i * outer_dim_size_self * self_select_dim_size;
+
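+  // Record the largest tid targeting each destination; the thread with the
+  // highest tid "wins", matching the last-write-wins order of a sequential
+  // scatter assign.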
+  atomicMax(&winners[dst], static_cast<int>(tid));
+}
+
+template <typename tensor_t, typename index_t, typename func_t>
+__global__ void ScatterWriteByWinnersKernel(
+    tensor_t* __restrict__ self_data,
+    const index_t* __restrict__ index_data,
+    tensor_t* __restrict__ src_data,
+    int64_t select_dim_size,
+    int64_t self_select_dim_size,
+    int64_t src_select_dim_size,
+    int64_t /*inner_dim_size*/,
+    int64_t outer_dim_size,
+    int64_t outer_dim_size_self,
+    int64_t outer_dim_size_src,
+    int64_t n,
+    func_t reduce_op,
+    const int* __restrict__ winners) {
+  const int64_t tid = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
+  if (tid >= n) return;
+
+  int64_t i, j, k;
+  decompose_tid(tid, select_dim_size, outer_dim_size, &i, &j, &k);
+
+  index_t idx = index_data[tid];
+  if (idx < 0) idx += static_cast<index_t>(self_select_dim_size);
+
+  const int64_t dst = k + static_cast<int64_t>(idx) * outer_dim_size_self +
+                      i * outer_dim_size_self * self_select_dim_size;
+
+  const int64_t src_off =
+      k + j * outer_dim_size_src + i * outer_dim_size_src * src_select_dim_size;
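+  // Only the winning thread recorded in stage 1 performs the write, so each
+  // destination element is updated by exactly one src value.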
+  if (static_cast<int>(tid) == winners[dst]) {
+    reduce_op(self_data + dst, src_data + src_off);
+  }
+}
+
template <typename tensor_t,
          typename index_t = int64_t,
          bool is_scatter_like = true>
@@ -422,25 +494,35 @@ struct gpu_gather_scatter_functor {
    DenseTensor shared_mem_tensor;
    if (method_name == "scatter_assign_gpu") {
      shared_mem_tensor.Resize({self_size});
-      dev_ctx.Alloc<int>(&shared_mem_tensor);
+      auto* winners = dev_ctx.Alloc<int>(&shared_mem_tensor);
      phi::funcs::set_constant(dev_ctx, &shared_mem_tensor, 0);
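+      // winners[d] will hold the largest tid that scatters into self element
+      // d (0 if no thread targets it); it is filled by stage 1 below.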
-
-      int* shared_mem = shared_mem_tensor.data<int>();
-      ScatterAssignGPUKernel<tensor_t, index_t, func_t, is_scatter_like>
+      // Stage 1: for every dst element, record the last (largest) tid that
+      // scatters to it.
+      PickWinnersScatterKernel<index_t>
+          <<<grid, block, 0, stream>>>(index_data,
+                                       select_dim_size,
+                                       self_select_dim_size,
+                                       src_select_dim_size,
+                                       inner_dim_size,
+                                       outer_dim_size,
+                                       outer_dim_size_self,
+                                       outer_dim_size_src,
+                                       n,
+                                       winners);
+      // Stage 2: only the winning tid from stage 1 writes its src value to
+      // dst.
+      ScatterWriteByWinnersKernel<tensor_t, index_t, func_t>
          <<<grid, block, 0, stream>>>(self_data,
-                                       dim,
                                        index_data,
                                        src_data,
                                        select_dim_size,
                                        self_select_dim_size,
                                        src_select_dim_size,
+                                       inner_dim_size,
                                        outer_dim_size,
                                        outer_dim_size_self,
                                        outer_dim_size_src,
-                                       index_size,
-                                       self_size,
+                                       n,
                                        reduce_op,
-                                       shared_mem);
+                                       winners);
    } else if (method_name == "scatter_mean_gpu") {
      shared_mem_tensor.Resize({self_size * 2});
      dev_ctx.Alloc<int>(&shared_mem_tensor);