Skip to content

Conversation

@ggggxm
Copy link
Contributor

@ggggxm ggggxm commented May 14, 2025

PR Category

Operator Mechanism

PR Types

Bug fixes

Description

修改了bincount的CPU和GPU实现

  • CPU
    • 将min_length参数改为int_64类型,paddle.bincount的out_size为max(max(input) + 1, min_length)。
  • GPU
    • 将min_length参数改为int_64类型
    • 原本的求输入数据最大值最小值操作调用第三方库Eigen实现,现改为一个Cuda kernel,减少launch开销。
    • 修改原本的BinCountKernel,增加内循环,在numel特别大时,一个线程处理多个元素
  • 性能分析
    • 在测试数据size为1e8级别时,修改后实现比修改前实现提升超100倍。

@paddle-bot
Copy link

paddle-bot bot commented May 14, 2025

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@paddle-bot paddle-bot bot added the contributor External developers label May 14, 2025
input_max_t.Resize({1});
auto* input_max_data = dev_ctx.template Alloc<InputT>(&input_max_t);
input_min_t.Resize({1});
auto* input_min_data = dev_ctx.template Alloc<InputT>(&input_min_t);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

input_min/max_data是不是没赋初值,导致kernel里面atomic没有初值;input_x是不是可以删了;还有input_min/max_data可以合在一个tensor里,减少显存分配次数,虽然影响不大

@ggggxm ggggxm force-pushed the cuda_error700 branch 2 times, most recently from bc0caca to e26970b Compare May 16, 2025 02:42
@lshpku lshpku merged commit 4a03621 into PaddlePaddle:develop May 20, 2025
49 of 50 checks passed
co63oc pushed a commit to co63oc/Paddle that referenced this pull request May 22, 2025
* fix bincount kernel for big tensor

* use HostAlloc to alloc memory

* add cpu test case
wanghuancoder pushed a commit to wanghuancoder/Paddle that referenced this pull request May 27, 2025
* fix bincount kernel for big tensor

* use HostAlloc to alloc memory

* add cpu test case
wanghuancoder added a commit that referenced this pull request Jun 3, 2025
* refine forrange (#72360)

* refine forrange

* refine forrange

* reduce support big tensor (#71970)

* reduce support big tensor

* [PHI] Fix gridDim limit for reduce kernel (#72507)

* [API] isclose support bigtensor (#72516)

* isclose support bigtensor

* refine

* [API] isnan isinf isfinite support bigtensor (#72517)

* isnan isinf isfinite support bigtensor

* refine

* [PHI] Fix cum kernel for big tensor (#72562)

* [PHI] Preliminary fix for elementwise broadcast int32 shape overflow (#72584)

* [PHI] Align linalg.solve kernel with torch (#72608)

* Update strided copy kernel (#72662)

* [PHI] Fix grid sample kernel for big tensor (#72628)

* [PHI] Fix argsort big tensor bug (#72712)

* [PHI] Fixed argsort big tensor bug

* [PHI] Fixed shape mismatch problem.

* [PHI] Fix contiguous kernel for big tensor (#72705)

* [PHI] Fix flatten and split kernel for big tensor (#72634)

* [PHI] Fix out-of-bound issue of paddle.take_along_axis (#72757)

* [PHI] fix paddle.diag with big tensor (#72638)

* [API] fix paddle.cross with big tensor (#72652)

* [PHI] Fix paddle.where api for big tensor (#72717)

* [PHI] Fix bincount kernel for big tensor (#72706)

* fix bincount kernel for big tensor

* use HostAlloc to alloc memory

* add cpu test case

* [PHI] Fix full_like kernel for big tensor (#72831)

* [API] Fix int overflow and float16 support for paddle.frac (#72815)

* [PHI] Align paddle.inner with torch in matmul logic (#72843)

* [PHI] Fix paddle.var & paddle.std float16 overflow (#72650)

* [PHI] Fix logsumexp precision problem (#72681)

* [PHI] Debug for logsumexp, bug source found

* [PHI] Removed GetNumBlocks func to get correct logsumexp

* [PHI] Removed redundant debug VLOG

* [PHI] Elegant grid bounded solution

* [Accuracy diff No.55-56、76-77] Fix accuracy diff for var&std API (#72879)

* [Accuracy diff No.21] Fix accuracy diff for heaviside API (#72894)

---------

Co-authored-by: Shuhao Liang <[email protected]>
Co-authored-by: Qianyue He <[email protected]>
Co-authored-by: Lei Ding <[email protected]>
Co-authored-by: ggggxm <[email protected]>
Co-authored-by: xkkkkkk23 <[email protected]>
Co-authored-by: Zx <[email protected]>
Co-authored-by: huangjiyi <[email protected]>
Co-authored-by: ooo oo <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

contributor External developers

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants