[TOPI] GPU scatter_add using atomic #7044
Conversation
You're right, we don't actually need to enforce determinism on scatter_add: within floating point error, a + b is the same as b + a. With scatter, we end up doing out[i] = a and then out[i] = b, and doing those in the other order yields a very different result.
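A minimal pure-Python sketch of the distinction above (not TVM code; the helper names are made up for illustration): overwriting is order-dependent, while accumulation commutes, which is what makes scatter_add safe to parallelize with atomics.

```python
def apply_updates(op, indices, updates):
    """Apply each (index, update) pair to a fresh length-4 output via op."""
    out = [0.0] * 4
    for idx, val in zip(indices, updates):
        op(out, idx, val)
    return out

def overwrite(out, idx, val):
    out[idx] = val   # scatter: last write wins, so update order matters

def accumulate(out, idx, val):
    out[idx] += val  # scatter_add: addition commutes (up to fp rounding)

idx, upd = [2, 2], [1.0, 5.0]
# Reversing the update order changes scatter's result...
assert apply_updates(overwrite, idx, upd) != apply_updates(overwrite, idx[::-1], upd[::-1])
# ...but not scatter_add's.
assert apply_updates(accumulate, idx, upd) == apply_updates(accumulate, idx[::-1], upd[::-1])
```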
What kind of perf improvement do you see with this?
Well, the comparison would be a bit embarrassing, since the current one is the worst GPU kernel ever :) Below is an excerpt from the nvprof log. This result was obtained on my crappy laptop GPU, but the difference is still significant (5 ms vs 20 us). Current one
New one in this PR
Hey @masahi, thanks for this work. I'm wondering if you've looked at a sort-then-lookup approach to scatter (some references: https://www.cse.ust.hk/catalac/papers/scatter_sc07.pdf, https://developer.nvidia.com/gpugems/gpugems2/part-iv-general-purpose-computation-gpus-primer/chapter-32-taking-plunge-gpu)? You also might want to look at
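A hedged sketch of the sort-then-lookup idea referenced above (not code from the linked papers or from TVM): sorting the (index, arrival-order) pairs imposes a deterministic update order, after which each output position can be filled independently. The function name and sequential form are illustrative only; a GPU version would use a parallel sort plus a segmented lookup.

```python
def sort_based_scatter(data, indices, updates):
    """Deterministic scatter: sort updates by target index, then apply in
    (index, arrival) order so the last arrival within each run wins."""
    out = list(data)
    # Stable order: primarily by destination index, secondarily by arrival.
    pairs = sorted(zip(indices, range(len(indices))))
    for idx, i in pairs:
        out[idx] = updates[i]  # later arrivals overwrite within each run
    return out
```

This matches the result of a purely sequential scatter, which is the determinism guarantee the current single-thread schedule provides.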
@tkonolige Thanks for the pointers. I've only improved scatter_add; I'm afraid I have no idea how to improve scatter. But yeah, I can see that if we can assume some structure on the indices, we can do some parallelization on scatter too. I find this problem interesting, so I'll put scatter improvement in my backlog.
Thanks @mbrookhart @tkonolige @Laurawly |
* use atomic add for faster 1d scatter add
* update tests
* run black
* more pylint fix
* remove fp64 bintcount test

Co-authored-by: masa <[email protected]>
This updates the `scatter_add` CUDA schedule for more parallelism. The `scatter_add` GPU implementation introduced in #6856 uses the same schedule as `scatter`. From what I can tell, the `scatter` GPU schedule sacrifices performance for guaranteed determinism: the axis being scattered over is updated exclusively by the same thread, so that there is no chance the same output index is updated by different threads. See for example `scatter` 2D with axis = 0:

tvm/python/tvm/topi/cuda/scatter.py
Lines 142 to 152 in 0421efb

Note that the `ni` dimension, which is the axis being scattered over, is processed sequentially by the same thread. This kills performance if the `ni` dimension is large. In particular, the `scatter` 1D implementation is not parallelized at all: it only uses a single block with a single thread (!!). Here is an example of `scatter` 1D kernel code from one of the PyTorch frontend tests (part of `bincount`, implemented via `scatter_add`).

I can understand that we cannot do better if we want to guarantee determinism. But I realized that for `scatter_add`, the situation is a bit more relaxed: we can use an atomic add instruction to unlock maximum parallelism, while still ensuring (some) determinism. Non-determinism can arise for floats, but this limitation is common in the other `scatter_add` GPU implementations I looked at (PyTorch, CuPy, XLA, etc.).

I've only updated the `scatter_add` 1D schedule, since its performance is the worst and fixing it is my priority. Also, I don't want to roll another `num_dimension` x `num_axis` pattern of IR. Maybe if I find a clean way to write a single implementation that handles all combinations of dimensions and axes, I'll update the other schedules too.

please review @mbrookhart @tkonolige @zhiics @kevinthesun let me know your thoughts
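The parallel pattern described above can be sketched in pure Python (hypothetical illustration, not the PR's TIR): each "thread" owns one (index, update) pair, and a lock stands in for the hardware guarantee that CUDA's `atomicAdd` provides, so concurrent updates to the same output slot never race.

```python
import threading

def parallel_scatter_add_1d(out, indices, updates):
    """One thread per update; the lock plays the role of atomicAdd."""
    lock = threading.Lock()

    def worker(i):
        with lock:  # atomic read-modify-write of out[indices[i]]
            out[indices[i]] += updates[i]

    threads = [threading.Thread(target=worker, args=(i,))
               for i in range(len(indices))]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return out
```

For integer updates the result is fully deterministic regardless of thread interleaving; for floats, the accumulation order varies between runs, which is the (tolerated) non-determinism mentioned above.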