Commit 6ab06c5
[TIR] ThreadAllreduce warp-level primitive support with multi-warp (apache#15327)
This PR enhances the implementation of the LowerThreadAllreduce pass.
Prior to this PR, for CUDA backend we will leverage warp-level
primitives only when
* the reducing threads are a sub-warp (i.e., size 16, 8, 4, 2), or
* the number of reducing threads is less then 32, and equals the
reduction extent.
Under the requirement above, for reductions that have large number
of reducing threads (e.g., reducing over 128, 256 or larger number
or threads), the generated code is inefficient.
This PR improves the LowerThreadAllreduce pass, so that we now generate
more efficient CUDA code in such cases, when the number of reducing
threads is a multiple of warp size, with the help of warp-level
primitives.
Specifically, in such cases, we first reducing 32 elements within
each warp, getting the results of each warp stored in shared memory.
We then trigger a second round of warp-level primitive reduction
within the first warp, and get the final reduction results.
In addition to using warp-level primitives, by doing this we also
reduce the size of the shared memory. For example, even when reducing
over 1024 threads, we now only require shared memory of size 32,
compared with 1024 prior to this PR.
Tests are added to ensure correctness.1 parent 5eb420a commit 6ab06c5
File tree
4 files changed
+613
-137
lines changed- python/tvm/tir
- src
- te/operation
- tir/transforms
- tests/python/unittest
4 files changed
+613
-137
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
616 | 616 | | |
617 | 617 | | |
618 | 618 | | |
619 | | - | |
| 619 | + | |
620 | 620 | | |
621 | 621 | | |
622 | 622 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
181 | 181 | | |
182 | 182 | | |
183 | 183 | | |
| 184 | + | |
| 185 | + | |
| 186 | + | |
| 187 | + | |
| 188 | + | |
| 189 | + | |
184 | 190 | | |
185 | 191 | | |
186 | 192 | | |
187 | 193 | | |
188 | 194 | | |
189 | 195 | | |
| 196 | + | |
190 | 197 | | |
191 | 198 | | |
192 | 199 | | |
193 | 200 | | |
194 | | - | |
195 | | - | |
196 | | - | |
197 | | - | |
198 | | - | |
199 | | - | |
200 | 201 | | |
201 | 202 | | |
202 | 203 | | |
| |||
0 commit comments