Commit 3f69ed4

[TIR] Finer predicate handling in cross-thread reduction (#15374)
This PR fixes the predicate handling logic of the cross-thread reduction lowering pass.

Prior to this PR, the predicate of the cross-thread reduction write-back block was the conjunction of `t == 0` for each reduction thread dim of the cross-thread reduction. This is problematic when the write-back buffer is stored in local memory: each thread is supposed to have a copy of the final value, yet the final value was only stored by the first thread. In this PR, the predicate is changed to the conjunction of the clauses from two parts:

* each clause of the original reduction block's predicate that contains a spatial loop var, and
* `t == 0` for each reduction thread dim, **only when the write-back buffer is global or shared**.

The first part ensures that the write-back will not go out of bounds, and the second part ensures that when the write-back buffer is local, every thread gets a value, while when the write-back buffer is non-local, only one thread writes the value out.

Meanwhile, this PR fixes the cross-thread broadcasting detection so that it is aware of the storage scope of the write buffer of the broadcasting block. Specifically, for each consumer block of a buffer produced by cross-thread reduction under the same kernel (i.e., the same set of `blockIdx`) as the cross-thread reduction block, when the write buffer of this consumer block is in local memory, we do not treat it as broadcasting and will not add a predicate to it. Otherwise, we add the predicate according to the broadcasting handling introduced by #15192.
1 parent e4af302 commit 3f69ed4
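To make the new write-back rule concrete, below is a minimal Python sketch of it (illustrative only, not the pass itself; the clause encoding and names are assumptions for the example):

def writeback_predicate(clauses, reduction_thread_vars, wb_scope):
    """Build the write-back predicate from the original block's predicate.

    clauses: list of (clause_text, vars_used) conjuncts of the original
    reduction block's predicate; reduction_thread_vars: thread-bound
    reduction loop vars; wb_scope: storage scope of the write-back buffer.
    """
    # Keep only the spatial clauses (those free of reduction thread vars);
    # they guard against out-of-bound write-back.
    kept = [
        text for text, used in clauses
        if not set(used) & set(reduction_thread_vars)
    ]
    if wb_scope != "local":
        # Only the first thread of each reduction dim writes to
        # shared/global memory; a local buffer gets a copy per thread.
        kept += [f"{t} == 0" for t in reduction_thread_vars]
    return " and ".join(kept) if kept else "True"

# Mirrors the `spatial_reduction_loop_predicate` test case below:
clauses = [("i_0 * 16 + i_1 < 2", ["i_0", "i_1"]),
           ("k_0 * 64 + k_1 < 32", ["k_0", "k_1"])]
print(writeback_predicate(clauses, ["k_1"], "global"))  # i_0 * 16 + i_1 < 2 and k_1 == 0
print(writeback_predicate(clauses, ["k_1"], "local"))   # i_0 * 16 + i_1 < 2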

File tree

2 files changed: +237 −18 lines changed

src/tir/transforms/lower_cross_thread_reduction.cc

Lines changed: 104 additions & 16 deletions
@@ -426,12 +426,48 @@ Stmt TransformReductionBlock(const BlockRealizeNode* realize, //
         BufferStore(wb_buffers[i], BufferLoad(ct_buffers[i], {Integer(0)}), wb_indices));
     wb_regions.push_back(BufferRegion(wb_buffers[i], region));
   }
+
+  // Construct the predicate of the write-back block. It is the conjunction of
+  // - each predicate clause of the original block which contains spatial loop var, and
+  // - `t == 0` for each reduction thread dim when the write-back buffer is not local.
   PrimExpr wb_predicate = const_true();
-  for (const ForNode* loop : reduction_loops) {
-    if (loop->thread_binding.defined()) {
-      wb_predicate = wb_predicate && (loop->loop_var == IntImm(loop->loop_var->dtype, 0));
+  std::unordered_set<const VarNode*> reduction_loop_vars;
+  reduction_loop_vars.reserve(reduction_loops.size());
+  for (const ForNode* reduction_loop : reduction_loops) {
+    reduction_loop_vars.insert(reduction_loop->loop_var.get());
+  }
+  PostOrderVisit(realize->predicate, [&wb_predicate, &reduction_loop_vars](const ObjectRef& obj) {
+    if (const auto* and_node = obj.as<AndNode>()) {
+      Array<PrimExpr> sub_exprs = {and_node->a, and_node->b};
+      for (PrimExpr sub_expr : sub_exprs) {
+        if (sub_expr->IsInstance<AndNode>()) {
+          continue;
+        }
+        bool is_reduction = [sub_expr, &reduction_loop_vars]() {
+          Array<Var> vars = UndefinedVars(sub_expr);
+          for (Var var : vars) {
+            if (reduction_loop_vars.find(var.get()) != reduction_loop_vars.end()) {
+              return true;
+            }
+          }
+          return false;
+        }();
+        if (!is_reduction) {
+          wb_predicate = wb_predicate && sub_expr;
+        }
+      }
+      return true;
+    }
+    return false;
+  });
+  if (wb_buffers[0].scope() != "local") {
+    for (const ForNode* loop : reduction_loops) {
+      if (loop->thread_binding.defined()) {
+        wb_predicate = wb_predicate && (loop->loop_var == IntImm(loop->loop_var->dtype, 0));
+      }
     }
   }
+
   stmts.push_back(BlockRealize(
       /*iter_values=*/std::move(bindings),
       /*predicate=*/wb_predicate,
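A side note on the traversal above: the visitor reaches every `And` node of the predicate and skips sub-expressions that are themselves `And`, so exactly the leaf conjuncts of a nested conjunction are inspected. A tiny Python sketch of that decomposition (a ("and", lhs, rhs) tuple stands in for TIR's `AndNode`; this encoding is an assumption for illustration):

def leaf_conjuncts(expr):
    # Yield the non-And leaves of a nested conjunction, mirroring how the
    # visitor above only collects sub-expressions that are not AndNode.
    if isinstance(expr, tuple) and expr[0] == "and":
        yield from leaf_conjuncts(expr[1])
        yield from leaf_conjuncts(expr[2])
    else:
        yield expr

# ((p and q) and r) decomposes into the clauses p, q, r:
pred = ("and", ("and", "i_0 * 16 + i_1 < 2", "k_0 * 64 + k_1 < 32"), "k_1 < 64")
assert list(leaf_conjuncts(pred)) == [
    "i_0 * 16 + i_1 < 2", "k_0 * 64 + k_1 < 32", "k_1 < 64"
]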
@@ -498,21 +534,45 @@ class CrossThreadReductionTransformer : public StmtMutator {
   }

   // Check if the input block needs thread broadcast rewrite.
-  // One block needs broadcast rewrite when there exists one or more thread
-  // vars which vars free variables to this block.
+  // One block needs broadcast rewrite when
+  // 1. it consumes a buffer produced by cross-thread reduction under
+  //    the same kernel (i.e., same group of blockIdx),
+  // 2. it writes to non-local memory,
+  // 3. at least one of the reduction thread vars of the cross-thread reduction
+  //    is free to this block (i.e., not bound to the block).
   std::vector<std::pair<ThreadScope, Range>> NeedCrossThreadBroadcast(
       const BlockRealizeNode* realize) {
-    std::unordered_map<ThreadScope, Range, ThreadScopeHash, ThreadScopeEqual> unbound_thread2range =
-        thread2range_;
+    Block block = realize->block;
+
+    // If the block writes to local memory, no rewrite is needed.
+    for (BufferRegion write_region : block->writes) {
+      if (write_region->buffer.scope() == "local") {
+        return {};
+      }
+    }
+
+    // Find out the reduction threads for the read-buffers which are produced by
+    // cross-thread reduction.
+    std::unordered_map<ThreadScope, Range, ThreadScopeHash, ThreadScopeEqual> thread2range;
+    for (BufferRegion read_region : block->reads) {
+      auto buf_it = crt_buf2threads_.find(read_region->buffer.get());
+      if (buf_it == crt_buf2threads_.end()) {
+        continue;
+      }
+      for (auto [scope, range] : buf_it->second) {
+        thread2range[scope] = range;
+      }
+    }
+
+    // Erase those threads which are not free to this block.
     for (const ForNode* loop : loop_stack_) {
       if (loop->thread_binding.defined()) {
         ThreadScope scope = ThreadScope::Create(loop->thread_binding.value()->thread_tag);
-        unbound_thread2range.erase(scope);
+        thread2range.erase(scope);
       }
     }
-
     std::vector<std::pair<ThreadScope, Range>> unbound_thread2range_list;
-    for (auto [scope, range] : unbound_thread2range) {
+    for (auto [scope, range] : thread2range) {
       unbound_thread2range_list.emplace_back(scope, range);
     }
     return unbound_thread2range_list;
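The detection above boils down to map/set bookkeeping; a minimal Python sketch under assumed plain-data inputs (buffer names and thread tags instead of TIR nodes):

def need_cross_thread_broadcast(reads, writes, crt_buf2threads, bound_threads):
    # A block writing to local memory keeps a per-thread copy; no rewrite.
    if any(scope == "local" for _, scope in writes):
        return {}
    # Collect the reduction threads of every consumed buffer that was
    # produced by cross-thread reduction under the same kernel.
    thread2range = {}
    for buf, _ in reads:
        for scope, rng in crt_buf2threads.get(buf, []):
            thread2range[scope] = rng
    # Threads bound on the block's loop stack are not free to the block.
    for scope in bound_threads:
        thread2range.pop(scope, None)
    return thread2range

# A consumer writing to global memory with threadIdx.x free needs broadcast:
print(need_cross_thread_broadcast(
    reads=[("temp_local", "local")],
    writes=[("B", "global")],
    crt_buf2threads={"temp_local": [("threadIdx.x", (0, 256))]},
    bound_threads=["blockIdx.x"],
))  # -> {'threadIdx.x': (0, 256)}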
@@ -582,13 +642,28 @@ class CrossThreadReductionTransformer : public StmtMutator {
     std::tie(reducer, combiner_lhs, combiner_rhs) =
         GetReducerAndCombinerLhsRhs(NullOpt, init_values, updates);

+    // Condition 4. All reduction buffers should be all local or all non-local.
+    int is_local_buf = -1;
     Array<Buffer> reduction_buffers;
     reduction_buffers.reserve(updates.size());
     for (const BufferStore& buf_store : updates) {
       reduction_buffers.push_back(buf_store->buffer);
+      if (buf_store->buffer.scope() == "local") {
+        CHECK_NE(is_local_buf, 0)
+            << "ValueError: Cross-thread reduction requires all reduction buffers to be all "
+               "local or all non-local. However, here some buffer is local while some buffer is "
+               "shared or global.";
+        is_local_buf = 1;
+      } else {
+        CHECK_NE(is_local_buf, 1)
+            << "ValueError: Cross-thread reduction requires all reduction buffers to be all "
+               "local or all non-local. However, here some buffer is local while some buffer is "
+               "shared or global.";
+        is_local_buf = 0;
+      }
     }

-    // Condition 4. The block should be the last block under the first reduction-related loop.
+    // Condition 5. The block should be the last block under the first reduction-related loop.
     bool visit = false;
     PreOrderVisit(GetRef<For>(reduction_loops[0]), [block, &visit](const ObjectRef& obj) {
       if (const auto* realize = obj.as<BlockRealizeNode>()) {
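Condition 4 amounts to checking that the reduction buffer scopes are homogeneous; a small Python sketch of the same check (names assumed for illustration):

def check_reduction_buffer_scopes(scopes):
    # Buffers must be all "local" or all non-"local" (shared/global).
    if len({scope == "local" for scope in scopes}) > 1:
        raise ValueError(
            "Cross-thread reduction requires all reduction buffers to be "
            "all local or all non-local."
        )

check_reduction_buffer_scopes(["local", "local"])    # OK
check_reduction_buffer_scopes(["shared", "global"])  # OK: all non-local
# check_reduction_buffer_scopes(["local", "shared"]) would raise ValueError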
@@ -631,8 +706,6 @@ class CrossThreadReductionTransformer : public StmtMutator {
       if (scope.rank == 1 && scope.dim_index >= 0) {
         is_thread_idx = true;
         ++thread_idx_depth;
-        thread2range_[scope] = Range::FromMinExtent(loop->min, loop->extent);
-        thread_loop_var2scope_[loop->loop_var.get()] = scope;
       } else if (scope.rank == 0) {
         is_block_idx = true;
         ++block_idx_depth;
@@ -649,7 +722,7 @@ class CrossThreadReductionTransformer : public StmtMutator {
       --block_idx_depth;
     }
     if (is_block_idx || (is_thread_idx && thread_idx_depth == 0 && block_idx_depth == 0)) {
-      thread2range_.clear();
+      crt_buf2threads_.clear();
     }

     // Replace `result` with the pre-stored result if `loop` appears as a key in `loop2new_stmt_`.
@@ -716,6 +789,21 @@ class CrossThreadReductionTransformer : public StmtMutator {
     loop2new_stmt_[reduction_loops[0]] =
         TransformReductionBlock(realize, it_buffers, ct_buffers, reduction_buffers, wb_indices,
                                 reducer, combiner_rhs, reduction_loops);
+
+    // Step 5. Record the reduction thread dims for the write-back buffers.
+    // The information is used for consumer block broadcasting detection.
+    std::vector<std::pair<ThreadScope, Range>> reduction_threads;
+    reduction_threads.reserve(reduction_loops.size());
+    for (const ForNode* loop : reduction_loops) {
+      if (loop->thread_binding.defined()) {
+        reduction_threads.emplace_back(
+            ThreadScope::Create(loop->thread_binding.value()->thread_tag),
+            Range::FromMinExtent(loop->min, loop->extent));
+      }
+    }
+    for (const Buffer& reduction_buf : reduction_buffers) {
+      crt_buf2threads_[reduction_buf.get()] = reduction_threads;
+    }
   }

   Stmt MakeCrossThreadBroadcast(
@@ -792,8 +880,8 @@ class CrossThreadReductionTransformer : public StmtMutator {

   int block_idx_depth = 0;
   int thread_idx_depth = 0;
-  std::unordered_map<ThreadScope, Range, ThreadScopeHash, ThreadScopeEqual> thread2range_;
-  std::unordered_map<const VarNode*, ThreadScope> thread_loop_var2scope_;
+  std::unordered_map<const BufferNode*, std::vector<std::pair<ThreadScope, Range>>>
+      crt_buf2threads_;
 };

 PrimFunc LowerCrossThreadReduction(PrimFunc f) {

tests/python/unittest/test_tir_transform_lower_cross_thread_reduction.py

Lines changed: 133 additions & 2 deletions
@@ -496,6 +496,64 @@ def lowered_single_reduction_loop_with_block_predicate(
             )


+@T.prim_func
+def spatial_reduction_loop_predicate(A: T.Buffer((2, 32), "float32"), B: T.Buffer((2,), "float32")):
+    for i_0 in range(1):
+        for i_1 in T.thread_binding(16, thread="threadIdx.y"):
+            for k_0 in range(1):
+                for k_1 in T.thread_binding(64, thread="threadIdx.x"):
+                    with T.block("block"):
+                        vi = T.axis.spatial(2, i_0 * 16 + i_1)
+                        vk = T.axis.reduce(32, k_0 * 64 + k_1)
+                        T.where(i_0 * 16 + i_1 < 2 and k_0 * 64 + k_1 < 32)
+                        T.reads(A[vi, vk])
+                        T.writes(B[vi])
+                        with T.init():
+                            B[vi] = T.float32(0)
+                        B[vi] = B[vi] + A[vi, vk]
+
+
+@T.prim_func
+def lowered_reduction_spatial_loop_predicate(
+    A: T.Buffer((2, 32), "float32"), B: T.Buffer((2,), "float32")
+):
+    cross_thread_B = T.alloc_buffer((1,), strides=(1,), scope="local")
+    in_thread_B = T.alloc_buffer((1,), strides=(1,), scope="local")
+    for i_0 in range(1):
+        for i_1 in T.thread_binding(16, thread="threadIdx.y"):
+            for k_1 in T.thread_binding(64, thread="threadIdx.x"):
+                with T.block("block_in_thread_init"):
+                    T.reads()
+                    T.writes(in_thread_B[0])
+                    in_thread_B[0] = T.float32(0)
+                for k_0 in range(1):
+                    with T.block("block_in_thread"):
+                        vi = T.axis.spatial(2, i_0 * 16 + i_1)
+                        vk = T.axis.reduce(32, k_0 * 64 + k_1)
+                        T.where(i_0 * 16 + i_1 < 2 and k_0 * 64 + k_1 < 32)
+                        T.reads(A[vi, vk])
+                        T.writes(in_thread_B[0])
+                        in_thread_B[0] = in_thread_B[0] + A[vi, vk]
+                with T.block("block_cross_thread"):
+                    T.reads(in_thread_B[0])
+                    T.writes(cross_thread_B[0])
+                    T.attr(
+                        T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
+                        "reduce_scope",
+                        T.reinterpret("handle", T.uint64(0)),
+                    )
+                    T.tvm_thread_allreduce(
+                        T.uint32(1), in_thread_B[0], T.bool(True), cross_thread_B[0], k_1
+                    )
+                k_0 = T.int32()
+                with T.block("block_write_back"):
+                    vi = T.axis.spatial(2, i_0 * 16 + i_1)
+                    T.where(i_0 * 16 + i_1 < 2 and k_1 == 0)
+                    T.reads(cross_thread_B[0])
+                    T.writes(B[vi])
+                    B[vi] = cross_thread_B[0]
+
+
 @T.prim_func
 def single_reduction_loop_with_tensorize(
     input_A: T.Buffer((1, 64, 7, 7, 32), "uint8"),
@@ -1315,7 +1373,6 @@ def lowered_thread_broadcast_1(A: T.Buffer((256, 256), "float32"), B: T.Buffer((
             )
         with T.block("sum_write_back"):
             vi = T.axis.spatial(256, i)
-            T.where(k == 0)
             T.reads(cross_thread_temp_local[0])
             T.writes(temp_local[vi])
             temp_local[vi] = cross_thread_temp_local[0]
@@ -1428,7 +1485,7 @@ def lowered_thread_broadcast_2(lv1605: T.Buffer((T.int64(1), T.int64(32), T.int6
         with T.block("NT_matmul_write_back"):
             v0 = T.axis.spatial(T.int64(32), ax0_ax1_fused // n)
             v1 = T.axis.spatial(n, ax0_ax1_fused % n)
-            T.where(ax0_fused == T.int64(0))
+            T.where(T.int64(0) <= ax0_ax1_fused // n and ax0_ax1_fused // n < T.int64(32) and T.int64(0) <= ax0_ax1_fused % n and ax0_ax1_fused % n < n)
             T.reads(cross_thread_var_NT_matmul_intermediate_local[0])
             T.writes(var_NT_matmul_intermediate_local[T.int64(0), v0, T.int64(0), v1])
             var_NT_matmul_intermediate_local[T.int64(0), v0, T.int64(0), v1] = cross_thread_var_NT_matmul_intermediate_local[0]
@@ -1442,6 +1499,72 @@ def lowered_thread_broadcast_2(lv1605: T.Buffer((T.int64(1), T.int64(32), T.int6
             var_compute_intermediate[T.int64(0), v0, T.int64(0), v1] = T.Cast("float32", T.min(T.max(var_NT_matmul_intermediate_local[T.int64(0), v0, T.int64(0), v1] * T.float16(0.088397790055248615), T.float16(-65504)), lv1582[T.int64(0), T.int64(0), T.int64(0), v1]))
 # fmt: on

+
+@T.prim_func
+def no_thread_broadcast(A: T.Buffer((256, 256), "float32"), B: T.Buffer((256, 256), "float32")):
+    temp_1_local = T.alloc_buffer((256,), scope="local")
+    temp_2_local = T.alloc_buffer((1,), scope="local")
+    for i in T.thread_binding(256, thread="blockIdx.x"):
+        for k in T.thread_binding(256, thread="threadIdx.x"):
+            with T.block("sum"):
+                vi, vk = T.axis.remap("SR", [i, k])
+                T.reads(A[vi, vk])
+                T.writes(temp_1_local[vi])
+                with T.init():
+                    temp_1_local[vi] = T.float32(0)
+                temp_1_local[vi] = temp_1_local[vi] + A[vi, vk]
+        with T.block("add"):
+            vi = T.axis.spatial(256, i)
+            T.reads(temp_1_local[vi])
+            T.writes(temp_2_local[0])
+            temp_2_local[0] = temp_1_local[vi] + T.float32(1)
+        for j in T.thread_binding(256, thread="threadIdx.x"):
+            with T.block("sum"):
+                vi, vj = T.axis.remap("SR", [i, j])
+                T.reads(temp_2_local[0])
+                T.writes(B[vi, vj])
+                B[vi, vj] = A[vi, vj] + temp_2_local[0]
+
+
+@T.prim_func
+def lowered_no_thread_broadcast(
+    A: T.Buffer((256, 256), "float32"), B: T.Buffer((256, 256), "float32")
+):
+    temp_1_local = T.alloc_buffer((256,), scope="local")
+    temp_2_local = T.alloc_buffer((1,), scope="local")
+    cross_thread_temp_1_local = T.alloc_buffer((1,), strides=(1,), scope="local")
+    for i in T.thread_binding(256, thread="blockIdx.x"):
+        for k in T.thread_binding(256, thread="threadIdx.x"):
+            with T.block("sum_cross_thread"):
+                vi, vk = T.axis.remap("SR", [i, k])
+                T.reads(A[vi, vk])
+                T.writes(cross_thread_temp_1_local[0])
+                T.attr(
+                    T.comm_reducer(lambda x0, y0: x0 + y0, [T.float32(0)]),
+                    "reduce_scope",
+                    T.reinterpret("handle", T.uint64(0)),
+                )
+                T.tvm_thread_allreduce(
+                    T.uint32(1), A[vi, vk], T.bool(True), cross_thread_temp_1_local[0], k
+                )
+            with T.block("sum_write_back"):
+                vi = T.axis.spatial(256, i)
+                T.reads(cross_thread_temp_1_local[0])
+                T.writes(temp_1_local[vi])
+                temp_1_local[vi] = cross_thread_temp_1_local[0]
+        with T.block("add"):
+            vi = T.axis.spatial(256, i)
+            T.reads(temp_1_local[vi])
+            T.writes(temp_2_local[0])
+            temp_2_local[0] = temp_1_local[vi] + T.float32(1)
+        for j in T.thread_binding(256, thread="threadIdx.x"):
+            with T.block("sum"):
+                vi, vj = T.axis.remap("SR", [i, j])
+                T.reads(temp_2_local[0])
+                T.writes(B[vi, vj])
+                B[vi, vj] = A[vi, vj] + temp_2_local[0]
+
+
 # pylint: enable=no-member,invalid-name,unused-variable,unexpected-keyword-arg
@@ -1472,6 +1595,10 @@ def test_single_reduction_loop_with_block_predicate():
     )


+def test_spatial_reduction_loop_predicate():
+    _check(spatial_reduction_loop_predicate, lowered_reduction_spatial_loop_predicate)
+
+
 def test_single_reduction_loop_with_tensorize():
     _check(
         single_reduction_loop_with_tensorize,
@@ -1534,6 +1661,10 @@ def test_thread_broadcast_rewrite_2():
     _check(thread_broadcast_2, lowered_thread_broadcast_2)


+def test_no_thread_broadcast_rewrite():
+    _check(no_thread_broadcast, lowered_no_thread_broadcast)
+
+
 def test_lower_te():
     a = te.placeholder((32, 2, 2))
     k1 = te.reduce_axis((0, 2), "k1")

0 commit comments
