Skip to content

Commit 0fc047c

Browse files
wrongtest-intellif and wrongtest
authored
[Compute-inline] Prefer T.where for reverse compute-inlined block with predicate (#17128)
* Prefer T.where for reverse compute-inlined block with predicate
* Update unit-test scripts
---------
Co-authored-by: wrongtest <[email protected]>
1 parent 3e08e70 commit 0fc047c

File tree

5 files changed

+98
-49
lines changed

5 files changed

+98
-49
lines changed

src/tir/schedule/primitive/compute_inline.cc

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -682,11 +682,14 @@ class ReverseComputeInliner : public BaseInliner {
682682
using BaseInliner::VisitStmt_;
683683

684684
/*! \brief Generate the predicate after inlining based on the consumer predicate */
685-
Block BuildInlinedConsumerPredicate(const BlockNode* producer_block) {
685+
BlockRealize BuildInlinedConsumerPredicate(BlockRealize producer_block_realize) {
686686
// Bind the producer block iter domains for simplification
687687
Map<Var, PrimExpr> subst_map;
688+
Block producer_block = producer_block_realize->block;
688689
for (int i = 0, n = producer_block->iter_vars.size(); i < n; ++i) {
689690
const IterVar& iter = producer_block->iter_vars[i];
691+
const PrimExpr& binding = producer_block_realize->iter_values[i];
692+
subst_map.Set(iter->var, binding);
690693
analyzer_.Bind(iter->var, Range::FromMinExtent(iter->dom->min, iter->dom->extent));
691694
}
692695
if (producer_block->annotations.count(tir::attr::auto_copy) != 0) {
@@ -705,30 +708,33 @@ class ReverseComputeInliner : public BaseInliner {
705708
PrimExpr predicate = Substituter(this)(consumer_iter_in_bound_);
706709
// Simplify the predicate using the producer block iter domains
707710
predicate = analyzer_.Simplify(predicate);
708-
ObjectPtr<BlockNode> block = make_object<BlockNode>(*producer_block);
709711
if (is_one(predicate)) {
710-
return Block(block);
711-
}
712-
if (const auto* if_ = producer_block->body.as<tir::IfThenElseNode>()) {
713-
PrimExpr if_predicate = analyzer_.Simplify(if_->condition);
714-
if (!StructuralEqual()(predicate, if_predicate)) {
715-
predicate = analyzer_.Simplify(predicate && if_->condition);
712+
return producer_block_realize;
713+
}
714+
if (const auto* if_ = producer_block->body.as<IfThenElseNode>()) {
715+
if (!if_->else_case.defined()) {
716+
PrimExpr if_predicate = analyzer_.Simplify(if_->condition);
717+
if (!StructuralEqual()(predicate, if_predicate)) {
718+
predicate = analyzer_.Simplify(predicate && if_->condition);
719+
producer_block.CopyOnWrite()->body = if_->then_case;
720+
}
716721
}
717-
block->body = IfThenElse(predicate, if_->then_case);
718-
return Block(block);
719722
}
720-
block->body = IfThenElse(predicate, block->body);
721-
return Block(block);
723+
PrimExpr outer_predicate = Substitute(predicate, subst_map);
724+
auto n = producer_block_realize.CopyOnWrite();
725+
n->block = producer_block;
726+
n->predicate = analyzer_.Simplify(outer_predicate);
727+
return GetRef<BlockRealize>(n);
722728
}
723729

724-
Stmt VisitStmt_(const BlockNode* op) final {
725-
Block src_block = GetRef<Block>(op);
726-
Block tgt_block = Downcast<Block>(BaseInliner::VisitStmt_(op));
727-
if (op == producer_block_) {
728-
tgt_block = BuildInlinedConsumerPredicate(tgt_block.get());
729-
block_reuse.Set(src_block, tgt_block);
730+
Stmt VisitStmt_(const BlockRealizeNode* op) final {
731+
Block src_block = op->block;
732+
BlockRealize tgt_block_realize = Downcast<BlockRealize>(StmtMutator::VisitStmt_(op));
733+
if (src_block.get() == producer_block_) {
734+
tgt_block_realize = BuildInlinedConsumerPredicate(tgt_block_realize);
735+
block_reuse.Set(src_block, tgt_block_realize->block);
730736
}
731-
return std::move(tgt_block);
737+
return std::move(tgt_block_realize);
732738
}
733739

734740
Stmt VisitStmt_(const BufferStoreNode* _store) final {

tests/python/dlight/test_gpu_matmul.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,10 @@ def expected(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)),
113113
v0 = T.axis.spatial(T.int64(1), ax0)
114114
v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1)
115115
v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_2 * T.int64(4) + ax2_0 * T.int64(2) + ax2_1_1)
116+
T.where(ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1 < m)
116117
T.reads(matmul_reindex_pad_local[v0, v1, v2])
117118
T.writes(matmul[T.int64(0), v1, v2])
118-
if v1 < m:
119-
matmul[T.int64(0), v1, v2] = matmul_reindex_pad_local[v0, v1, v2]
119+
matmul[T.int64(0), v1, v2] = matmul_reindex_pad_local[v0, v1, v2]
120120
# fmt: on
121121

122122

@@ -200,10 +200,10 @@ def expected(var_inp0: T.handle, inp1: T.Buffer((4096, 4096), "float32"), var_ma
200200
v0 = T.axis.spatial(1, ax0)
201201
v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1)
202202
v2 = T.axis.spatial(4096, ax0_ax2_0_fused * 64 + ax2_2 * 4 + ax2_0 * 2 + ax2_1_1)
203+
T.where(ax1_0 * 32 + ax1_2 * 4 + ax1 < m)
203204
T.reads(matmul_reindex_pad_local[v0, v1, v2])
204205
T.writes(matmul[0, v1, v2])
205-
if v1 < m:
206-
matmul[0, v1, v2] = matmul_reindex_pad_local[v0, v1, v2]
206+
matmul[0, v1, v2] = matmul_reindex_pad_local[v0, v1, v2]
207207
# fmt: on
208208

209209
mod = tvm.IRModule({"main": func})
@@ -466,10 +466,10 @@ def expected(lv13: T.Buffer((T.int64(4096), T.int64(512)), "uint32"), lv14: T.Bu
466466
v0 = T.axis.spatial(T.int64(1), ax0)
467467
v1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1)
468468
v2 = T.axis.spatial(T.int64(4096), ax0_ax2_0_fused * T.int64(64) + ax2_2 * T.int64(4) + ax2_0 * T.int64(2) + ax2_1_1)
469+
T.where(ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1 < n)
469470
T.reads(var_matmul_intermediate_reindex_pad_local[v0, v1, v2], lv13_1[v2], lv3[T.int64(0), v1, v2])
470471
T.writes(p_output0_intermediate[T.int64(0), v1, v2])
471-
if v1 < n:
472-
p_output0_intermediate[T.int64(0), v1, v2] = T.Cast("float16", var_matmul_intermediate_reindex_pad_local[v0, v1, v2] + T.Cast("float32", lv13_1[v2])) + lv3[T.int64(0), v1, v2]
472+
p_output0_intermediate[T.int64(0), v1, v2] = T.Cast("float16", var_matmul_intermediate_reindex_pad_local[v0, v1, v2] + T.Cast("float32", lv13_1[v2])) + lv3[T.int64(0), v1, v2]
473473

474474
# fmt: on
475475

@@ -596,9 +596,9 @@ def expected(p_lv26: T.handle, lv9: T.Buffer((T.int64(2048), T.int64(2048)), "fl
596596
v1 = T.axis.spatial((n + T.int64(31)) // T.int64(32) * T.int64(32), ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1)
597597
v2 = T.axis.spatial(T.int64(2048), ax0_ax2_0_fused * T.int64(64) + ax2_2 * T.int64(4) + ax2_0 * T.int64(2) + ax2_1_1)
598598
T.reads(lv52[T.int64(0), v1, v2], var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2])
599+
T.where(ax1_0 * T.int64(32) + ax1_2 * T.int64(4) + ax1 < n)
599600
T.writes(var_T_multiply_intermediate[v1, v2])
600-
if v1 < n:
601-
var_T_multiply_intermediate[v1, v2] = T.Cast("float16", lv52[T.int64(0), v1, v2]) * (var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2] * T.sigmoid(var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2]))
601+
var_T_multiply_intermediate[v1, v2] = T.Cast("float16", lv52[T.int64(0), v1, v2]) * (var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2] * T.sigmoid(var_NT_matmul_intermediate_reindex_pad_local[v0, v1, v2]))
602602

603603
# fmt: on
604604

@@ -666,10 +666,10 @@ def expected(var_inp0: T.handle, inp1: T.Buffer((T.int64(4096), T.int64(4096)),
666666
v0 = T.axis.spatial(T.int64(1), ax0)
667667
v1 = T.axis.spatial((m + T.int64(31)) // T.int64(32) * T.int64(32), ax0_ax1_0_fused * T.int64(32) + ax1_2 * T.int64(2) + ax1)
668668
v2 = T.axis.spatial(T.int64(4096), ax2_0 * T.int64(64) + ax2_2 * T.int64(8) + ax2_0_1 * T.int64(8) + ax2_1_1)
669+
T.where(ax0_ax1_0_fused * T.int64(32) + ax1_2 * T.int64(2) + ax1 < m)
669670
T.reads(matmul_reindex_pad_local[v0, v1, v2])
670671
T.writes(matmul[T.int64(0), v1, v2])
671-
if v1 < m:
672-
matmul[T.int64(0), v1, v2] = matmul_reindex_pad_local[v0, v1, v2]
672+
matmul[T.int64(0), v1, v2] = matmul_reindex_pad_local[v0, v1, v2]
673673
# fmt: on
674674

675675

tests/python/dlight/test_gpu_matmul_tensorize.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -254,10 +254,10 @@ def expected(var_X: T.handle, W: T.Buffer((15, 256), "float16"), var_compute: T.
254254
v0 = T.axis.spatial(1, ax0)
255255
v1 = T.axis.spatial((m + 31) // 32 * 32, ax1_0 * 32 + ax1_2 * 4 + ax1)
256256
v2 = T.axis.spatial(64, ax2_2 * 4 + ax2_0 * 2 + ax2_1_1)
257+
T.where(ax1_0 * 32 + ax1_2 * 4 + ax1 < m and ax2_2 * 4 + ax2_0 * 2 + ax2_1_1 < 15)
257258
T.reads(compute_reindex_pad_local[v0, v1, v2])
258259
T.writes(compute[v1, v2])
259-
if v1 < m and v2 < 15:
260-
compute[v1, v2] = compute_reindex_pad_local[v0, v1, v2]
260+
compute[v1, v2] = compute_reindex_pad_local[v0, v1, v2]
261261
# fmt: on
262262

263263

@@ -417,11 +417,11 @@ def expected(lv686: T.Buffer((4096, 256), "uint32"), lv687: T.Buffer((4096, 64),
417417
v0 = T.axis.spatial(1, 0)
418418
v1 = T.axis.spatial((n + 127) // 128 * 128, ax1_0_0_ax2_0_0_fused * 128 + ax2_0_2_ax1_0_2_fused % 4 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) // 32)
419419
v2 = T.axis.spatial(4096, ax1_0_1_ax2_0_1_fused * 128 + ax2_0_2_ax1_0_2_fused // 4 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) % 32)
420+
T.where(ax1_0_0_ax2_0_0_fused * 128 + ax2_0_2_ax1_0_2_fused % 4 * 32 + ((ax0_ax1_fused_0 * 32 + ax0_ax1_fused_1) * 4 + ax0_ax1_fused_2) // 32 < n)
420421
T.reads(lv3[0, v1, v2], var_NT_matmul_intermediate_reindex_pad_shared_dyn[v0, v1, v2])
421422
T.writes(p_output0_intermediate[0, v1, v2])
422423
T.block_attr({"buffer_dim_align": [[0, 1, 16, 4]]})
423-
if v1 < n:
424-
p_output0_intermediate[0, v1, v2] = lv3[0, v1, v2] * T.float16(0.5) + var_NT_matmul_intermediate_reindex_pad_shared_dyn[v0, v1, v2]
424+
p_output0_intermediate[0, v1, v2] = lv3[0, v1, v2] * T.float16(0.5) + var_NT_matmul_intermediate_reindex_pad_shared_dyn[v0, v1, v2]
425425
# fmt: on
426426

427427

@@ -690,11 +690,11 @@ def expected(var_A: T.handle, B: T.Buffer((4096, 22016), "int8"), var_matmul: T.
690690
v0 = T.axis.spatial(1, 0)
691691
v1 = T.axis.spatial((m + 127) // 128 * 128, ax1_0_0_ax2_0_0_fused * 128 + ax2_0_2_ax1_0_2_fused % 4 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) // 32)
692692
v2 = T.axis.spatial(4096, ax1_0_1_ax2_0_1_fused * 128 + ax2_0_2_ax1_0_2_fused // 4 * 32 + (ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) % 32)
693+
T.where(ax1_0_0_ax2_0_0_fused * 128 + ax2_0_2_ax1_0_2_fused % 4 * 32 + ((ax0_ax1_fused_0 * 32 + ax0_ax1_fused_1) * 4 + ax0_ax1_fused_2) // 32 < m)
693694
T.reads(matmul_1_reindex_pad_shared_dyn[v0, v1, v2])
694695
T.writes(matmul_1[0, v1, v2])
695696
T.block_attr({"buffer_dim_align": [[0, 1, 16, 4]]})
696-
if v1 < m:
697-
matmul_1[0, v1, v2] = matmul_1_reindex_pad_shared_dyn[v0, v1, v2]
697+
matmul_1[0, v1, v2] = matmul_1_reindex_pad_shared_dyn[v0, v1, v2]
698698
# fmt: on
699699

700700

@@ -831,10 +831,10 @@ def expected(var_A: T.handle, B: T.Buffer((28672, 4096), "float16"), var_C: T.ha
831831
v0 = T.axis.spatial(1, ax0_1)
832832
v1 = T.axis.spatial((batch_size + 15) // 16 * 16, ax1_0 * 16 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) // 64)
833833
v2 = T.axis.spatial(28672, ax2_0 * 64 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) % 64)
834+
T.where(ax1_0 * 16 + (((ax1_ax2_fused_0 * 4 + ax1_ax2_fused_1 + ax1_ax2_fused_2) * 32 + ax1_ax2_fused_3) * 4 + ax1_ax2_fused_4) // 64 < batch_size)
834835
T.reads(C_reindex_pad_shared[v0, v1, v2])
835836
T.writes(C[v1, 0, v2])
836-
if v1 < batch_size:
837-
C[v1, 0, v2] = C_reindex_pad_shared[v0, v1, v2]
837+
C[v1, 0, v2] = C_reindex_pad_shared[v0, v1, v2]
838838
# fmt: on
839839

840840

@@ -971,10 +971,10 @@ def expected(B0: T.Buffer((28672, 512), "uint32"), B1: T.Buffer((28672, 128), "f
971971
v0 = T.axis.spatial(1, ax0_1)
972972
v1 = T.axis.spatial((batch_size + 15) // 16 * 16, ax1_0 * 16 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) // 64)
973973
v2 = T.axis.spatial(28672, ax2_0 * 64 + (ax1_ax2_fused_0 * 512 + ax1_ax2_fused_1 * 128 + ax1_ax2_fused_2 * 128 + ax1_ax2_fused_3 * 4 + ax1_ax2_fused_4) % 64)
974+
T.where(ax1_0 * 16 + (((ax1_ax2_fused_0 * 4 + ax1_ax2_fused_1 + ax1_ax2_fused_2) * 32 + ax1_ax2_fused_3) * 4 + ax1_ax2_fused_4) // 64 < batch_size)
974975
T.reads(C_reindex_pad_shared[v0, v1, v2])
975976
T.writes(C[v1, 0, v2])
976-
if v1 < batch_size:
977-
C[v1, 0, v2] = C_reindex_pad_shared[v0, v1, v2]
977+
C[v1, 0, v2] = C_reindex_pad_shared[v0, v1, v2]
978978

979979

980980
if __name__ == "__main__":

tests/python/meta_schedule/test_meta_schedule_schedule_rule_mlt_tc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -856,11 +856,11 @@ def padded_matmul_relu_0(A: T.Buffer((127, 127), "float16"), B: T.Buffer((127, 1
856856
v3 = T.axis.spatial(1, 0)
857857
v4 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 256 // 16)
858858
v5 = T.axis.spatial(16, ax0_ax1_ax3_ax4_ax5_fused % 16)
859+
T.where(ax0_0_0_ax1_0_0_fused // 2 * 32 + ax2 * 16 + ax0_ax1_ax3_ax4_ax5_fused % 256 // 16 < 127 and ax0_0_0_ax1_0_0_fused % 2 * 64 + ax0_0_1_ax1_0_1_fused * 32 + ax0_ax1_ax3_ax4_ax5_fused // 256 * 16 + ax0_ax1_ax3_ax4_ax5_fused % 16 < 127)
859860
T.reads(C_reindex_shared[v0, v1, v2, v3, v4, v5])
860861
T.writes(compute[v4 + v2 * 16 + v0 * 32, v5 + v1 * 16])
861862
T.block_attr({"meta_schedule.cooperative_fetch": 4})
862-
if v0 * 32 + v2 * 16 + v4 < 127 and v1 * 16 + v5 < 127:
863-
compute[v4 + v2 * 16 + v0 * 32, v5 + v1 * 16] = T.max(C_reindex_shared[v0, v1, v2, v3, v4, v5], T.float32(0))
863+
compute[v4 + v2 * 16 + v0 * 32, v5 + v1 * 16] = T.max(C_reindex_shared[v0, v1, v2, v3, v4, v5], T.float32(0))
864864
# fmt: on
865865

866866
decision_0 = [

0 commit comments

Comments
 (0)