
[TOPI] GPU scatter_add using atomic #7044

Merged: 5 commits into apache:main on Dec 7, 2020
Conversation

masahi (Member) commented Dec 7, 2020

This updates the scatter_add CUDA schedule for more parallelism.

The scatter_add GPU implementation introduced in #6856 uses the same schedule as scatter. From what I can tell, the scatter GPU schedule sacrifices performance for guaranteed determinism: the axis which is scattered over is updated exclusively by the same thread, so there is no chance that the same output index is updated by different threads. See, for example, scatter 2D with axis = 0:

if axis == 0:
    with ib.new_scope():
        j = te.thread_axis("blockIdx.x")
        ib.scope_attr(j, "thread_extent", ci)
        with ib.for_range(0, ni, name="i") as i:
            idx = i * ci + j
            index = indices_ptr[idx]
            with ib.if_scope(index < 0):
                update_func(out_ptr, (index + n) * c + j, updates_ptr[idx])
            with ib.else_scope():
                update_func(out_ptr, index * c + j, updates_ptr[idx])

Note that the ni dimension, which is the axis being scattered over, is processed sequentially by the same thread. This kills performance when the ni dimension is large. In particular, the scatter 1D implementation is not parallelized at all: it uses a single block with a single thread (!!). Here is an example of scatter 1D kernel code from one of the PyTorch frontend tests (part of bincount, implemented via scatter_add).

extern "C" __global__ void fused_scatter_add_kernel1(int* __restrict__ placeholder, int* __restrict__ scatter_add_gpu, int* __restrict__ placeholder1, int any_dim) {
  for (int i = 0; i < 10000; ++i) {
    if (placeholder[(i)] < 0) {
      scatter_add_gpu[((placeholder[(i)] + any_dim))] = (scatter_add_gpu[((placeholder[(i)] + any_dim))] + placeholder1[(i)]);
    } else {
      scatter_add_gpu[(placeholder[(i)])] = (scatter_add_gpu[(placeholder[(i)])] + placeholder1[(i)]);
    }
  }
}

I can understand that we cannot do better if we want to guarantee determinism. But I realized that for scatter_add the situation is a bit more relaxed: we can use an atomic add instruction to unlock maximum parallelism, while still ensuring (some) determinism. Non-determinism can arise for floats, but this limitation is shared by the other scatter_add GPU implementations I looked at (PyTorch, CuPy, XLA, etc.).
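To make the idea concrete, here is a minimal hand-written CUDA sketch of what an atomic-based scatter_add 1D kernel can look like (an illustration only, not the code this PR generates; the kernel and parameter names are made up):

// Sketch: one thread per update; atomicAdd resolves collisions when several
// indices point at the same output element.
extern "C" __global__ void scatter_add_atomic_sketch(const int* __restrict__ indices,
                                                     const int* __restrict__ updates,
                                                     int* __restrict__ out,
                                                     int num_updates, int out_size) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < num_updates) {
    int index = indices[i];
    if (index < 0) index += out_size;    // wrap negative indices, as in the IR above
    atomicAdd(&out[index], updates[i]);  // parallel-safe read-modify-write
  }
}

A launch along the lines of scatter_add_atomic_sketch<<<(num_updates + 1023) / 1024, 1024>>>(...) then gives every update its own thread, which is roughly the launch shape visible in the nvprof numbers below.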

I've only updated the scatter_add 1D schedule, since its performance is the most terrible and fixing it is my priority. I also don't want to roll out yet another num_dimension x num_axis set of IR patterns. If I find a clean way to write a single implementation that handles all combinations of dimensions and axes, I'll update the other schedules too.

Please review, @mbrookhart @tkonolige @zhiics @kevinthesun, and let me know your thoughts.

mbrookhart (Contributor) left a comment

You're right, we don't actually need to enforce determinism on scatter_add: within floating-point error, a + b is the same as b + a. On scatter, we ended up doing out[i] = a and then out[i] = b, and doing those in the other order yields very different results.
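To make the floating-point caveat concrete: float addition is commutative but not associative, so the final value of many atomically accumulated float adds can depend on the order in which threads commit. A tiny host-side illustration (not from this PR):

#include <stdio.h>

int main(void) {
  float a = 1e8f, b = -1e8f, c = 1.0f;
  // Same three addends, different accumulation order, different float result.
  printf("%f\n", (a + b) + c);  // prints 1.000000
  printf("%f\n", a + (b + c));  // prints 0.000000 (b + c rounds back to -1e8f)
  return 0;
}

Integer atomics, as in the bincount case above, remain exactly deterministic.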

What kind of perf improvement do you see with this?

Review thread on tests/python/frontend/pytorch/test_forward.py (resolved)
masahi (Member, Author) commented Dec 7, 2020

What kind of perf improvement do you see with this?

Well, the comparison is a bit embarrassing, since the current one is the worst GPU kernel ever :) Below is an excerpt from an nvprof log. The result was obtained on my crappy laptop GPU, but the difference is still significant (5 ms vs 20 us).

Current one

Duration  Grid Size  Block Size Name
5.5576ms  (1 1 1)    (1 1 1)    fused_scatter_add_1_kernel1

New one in this PR

Duration  Grid Size  Block Size Name
22.176us  (10 1 1)   (1024 1 1) fused_scatter_add_1_kernel1
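(For reference, the new launch configuration matches a one-thread-per-update mapping: with 10000 updates and 1024 threads per block, ceil(10000 / 1024) = 10 blocks, hence the (10 1 1) grid with (1024 1 1) blocks.)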

tkonolige (Contributor) commented:

Hey @masahi, thanks for this work. I'm wondering if you've looked at a sort-then-lookup approach to scatter (some references: https://www.cse.ust.hk/catalac/papers/scatter_sc07.pdf, https://developer.nvidia.com/gpugems/gpugems2/part-iv-general-purpose-computation-gpus-primer/chapter-32-taking-plunge-gpu)?

You also might want to look at scatter_nd in the codebase. It is a generalization of scatter to arbitrary dimensions. For 1D its performance probably won't be great, so maybe you could improve it too?

masahi (Member, Author) commented Dec 7, 2020

@tkonolige Thanks for the pointers. I've only improved scatter_add; I'm afraid I have no idea how to improve scatter. But yeah, I can see that if we can assume some structure on the indices, we can do some parallelization on scatter too. I find this problem interesting, so I'll put scatter improvement in my backlog.

masahi merged commit 2a2081e into apache:main on Dec 7, 2020
masahi (Member, Author) commented Dec 7, 2020

Thanks @mbrookhart @tkonolige @Laurawly

TusharKanekiDey pushed a commit to TusharKanekiDey/tvm that referenced this pull request Jan 20, 2021
* use atomic add for faster 1d scatter add

* update tests

* run black

* more pylint fix

* remove fp64 bintcount test

Co-authored-by: masa <[email protected]>
trevor-m pushed a commit to neo-ai/tvm that referenced this pull request Jan 21, 2021
electriclilies pushed a commit to electriclilies/tvm that referenced this pull request Feb 18, 2021