Commit 6ab06c5

MasterJH5574 authored and junrushao committed
[TIR] ThreadAllreduce warp-level primitive support with multi-warp (apache#15327)
This PR enhances the implementation of the LowerThreadAllreduce pass. Prior to this PR, for the CUDA backend we leveraged warp-level primitives only when

* the reducing threads form a sub-warp (i.e., of size 16, 8, 4, or 2), or
* the number of reducing threads is less than 32 and equals the reduction extent.

Under these requirements, reductions over a large number of threads (e.g., 128, 256, or more) produced inefficient generated code.

This PR improves the LowerThreadAllreduce pass so that it now generates more efficient CUDA code in such cases, whenever the number of reducing threads is a multiple of the warp size, with the help of warp-level primitives. Specifically, we first reduce the 32 elements within each warp and store each warp's result in shared memory. We then trigger a second round of warp-level primitive reduction within the first warp to obtain the final reduction result.

Besides using warp-level primitives, this also reduces the required shared memory: even when reducing over 1024 threads, we now only need shared memory of size 32, compared with 1024 prior to this PR.

Tests are added to ensure correctness.
1 parent 5eb420a · commit 6ab06c5
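To illustrate the two-stage scheme described in the commit message, below is a minimal hand-written CUDA sketch of a block-wide sum allreduce. This is an illustrative assumption, not the code emitted by LowerThreadAllreduce: the kernel name, the use of a sum reduction, the data layout, and the host harness are all hypothetical; the sketch only assumes that the number of reducing threads (blockDim.x) is a multiple of the warp size (32).

#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical illustration of the two-stage warp-level allreduce:
// stage 1 reduces 32 values inside each warp with shuffles, stage 2
// reduces the per-warp partials inside the first warp. Shared memory
// holds one partial per warp, so even 1024 reducing threads need only
// 32 shared-memory slots.
__global__ void block_allreduce_sum(const float* in, float* out) {
  __shared__ float warp_partials[32];   // at most 32 warps per block
  __shared__ float block_result;

  const int tid = threadIdx.x;
  const int lane = tid & 31;
  const int warp = tid >> 5;
  const int num_warps = blockDim.x >> 5;

  float val = in[blockIdx.x * blockDim.x + tid];

  // Stage 1: reduce 32 elements within each warp using warp shuffles.
  for (int offset = 16; offset > 0; offset >>= 1) {
    val += __shfl_down_sync(0xffffffff, val, offset);
  }
  if (lane == 0) warp_partials[warp] = val;  // one partial per warp
  __syncthreads();

  // Stage 2: the first warp reduces the per-warp partials.
  if (warp == 0) {
    val = (lane < num_warps) ? warp_partials[lane] : 0.0f;
    for (int offset = 16; offset > 0; offset >>= 1) {
      val += __shfl_down_sync(0xffffffff, val, offset);
    }
    if (lane == 0) block_result = val;
  }
  __syncthreads();

  // block_result is now visible to every thread (allreduce semantics);
  // here only thread 0 writes it back.
  if (tid == 0) out[blockIdx.x] = block_result;
}

int main() {
  const int threads = 256, blocks = 1, n = threads * blocks;
  float *in, *out;
  cudaMallocManaged(&in, n * sizeof(float));
  cudaMallocManaged(&out, blocks * sizeof(float));
  for (int i = 0; i < n; ++i) in[i] = 1.0f;
  block_allreduce_sum<<<blocks, threads>>>(in, out);
  cudaDeviceSynchronize();
  printf("sum = %f\n", out[0]);  // expect 256.0
  cudaFree(in);
  cudaFree(out);
  return 0;
}

With this layout the shared-memory footprint is one element per warp (at most 32 entries), which matches the size reduction described in the commit message.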

File tree

4 files changed: +613 −137 lines

python/tvm/tir/op.py

Lines changed: 1 addition & 1 deletion
@@ -616,7 +616,7 @@ def tvm_storage_sync(storage_scope):
     call : PrimExpr
         The call expression.
     """
-    return call_intrin("handle", "tir.tvm_storage_sync", storage_scope)
+    return call_intrin("int32", "tir.tvm_storage_sync", storage_scope)
 
 
 def tvm_warp_shuffle(mask, value, warp_id, width, warp_size):

src/te/operation/cross_thread_reduction.cc

Lines changed: 7 additions & 6 deletions
@@ -181,22 +181,23 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage,
     freduce_args.push_back(dummy_load);
   }
 
+  // Checks for the thread.
+  std::vector<PrimExpr> output_preds;
+  if (stage->store_predicate.defined()) {
+    output_preds.emplace_back(stage->store_predicate);
+  }
+
   for (IterVar iv : stage->leaf_iter_vars) {
     if (iv->iter_type == kCommReduce) {
       auto it = stage->iter_var_attrs.find(iv);
       if (it != stage->iter_var_attrs.end() && (*it).second->bind_thread.defined()) {
         IterVar tv = (*it).second->bind_thread;
         freduce_args.push_back(tv->var);
+        output_preds.push_back(tv->var == make_const(tv->var->dtype, 0));
       }
     }
   }
 
-  // Checks for the thread.
-  std::vector<PrimExpr> output_preds;
-  if (stage->store_predicate.defined()) {
-    output_preds.emplace_back(stage->store_predicate);
-  }
-
   // Apply the existing input predicate if any.
   output_preds.push_back(input_pred);
