Skip to content

Commit d6c1489

Browse files
wrongtest tqchen
authored and committed
enforcement on loop partition control
1 parent d641354 commit d6c1489

File tree

7 files changed

+145
-17
lines changed

7 files changed

+145
-17
lines changed

include/tvm/tir/builtin.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -970,6 +970,9 @@ TVM_DLL const Op& vscale();
970970
*/
971971
TVM_DLL const Op& get_active_lane_mask();
972972

973+
/*! \brief Annotate a predicate so that it is not considered as a target condition of loop partition. */
974+
TVM_DLL const Op& ignore_loop_partition();
975+
973976
/*! \brief The kind of structure field info used in intrinsic */
974977
enum TVMStructFieldKind : int {
975978
// array head address

python/tvm/script/ir_builder/tir/ir.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1910,6 +1910,7 @@ def wrapped(*args, **kwargs):
19101910
anylist_setitem_call_packed = _op_wrapper(_tir_op.anylist_setitem_call_packed)
19111911
anylist_setitem_call_cpacked = _op_wrapper(_tir_op.anylist_setitem_call_cpacked)
19121912
vscale = _op_wrapper(_tir_op.vscale)
1913+
ignore_loop_partition = _op_wrapper(_tir_op.ignore_loop_partition)
19131914

19141915

19151916
def _dtype_forward(func):
@@ -2262,4 +2263,5 @@ def wrapped(*args, **kwargs):
22622263
"vscale",
22632264
"get_active_lane_mask",
22642265
"call_kernel",
2266+
"ignore_loop_partition",
22652267
]

python/tvm/tir/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@
9797
from .op import start_profile_intrinsic, end_profile_intrinsic
9898
from .op import vscale, get_active_lane_mask, get_vscale_expr
9999
from .op import dp4a
100+
from .op import ignore_loop_partition
100101
from .generic import add, subtract, multiply
101102

102103
from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule, ScheduleError

python/tvm/tir/op.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3581,6 +3581,18 @@ def get_vscale_expr(dtype: Union[str, tvm.DataType], min_size: int = 128) -> Pri
35813581
return min_size // dtype.bits * vscale()
35823582

35833583

3584+
def ignore_loop_partition(predicate) -> PrimExpr:
3585+
"""
3586+
Annotate a predicate so that it is not considered as a target condition of loop partition.
3587+
3588+
Parameters
3589+
----------
3590+
predicate : PrimExpr
3591+
The annotated predicate expression.
3592+
"""
3593+
return call_intrin("bool", "tir.ignore_loop_partition", predicate)
3594+
3595+
35843596
# pylint: disable=unnecessary-lambda
35853597
sum = comm_reducer(lambda x, y: x + y, lambda t: const(0, dtype=t), name="sum")
35863598
min = comm_reducer(lambda x, y: _ffi_api._OpMin(x, y, None), max_value, name="min") # type: ignore

src/tir/op/builtin.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,12 @@ TIR_DEFINE_BUILTIN_FUNC(get_active_lane_mask)
422422
.set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
423423
Integer(ScriptDtypePrintLocation::kFirst));
424424

425+
TIR_DEFINE_BUILTIN_FUNC(ignore_loop_partition)
426+
.set_num_inputs(1)
427+
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
428+
.set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
429+
Integer(ScriptDtypePrintLocation::kNone));
430+
425431
} // namespace builtin
426432
} // namespace tir
427433
} // namespace tvm

src/tir/transforms/loop_partition.cc

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -101,15 +101,15 @@ class CandidateSelector final : public StmtExprVisitor {
101101
: partition_const_loop_(partition_const_loop) {}
102102

103103
void VisitStmt_(const ForNode* op) final {
104+
// always treat var with hint to be partitioned
105+
const VarNode* var = op->loop_var.get();
106+
if (partition_hint_vars.count(var)) {
107+
candidates.insert(GetRef<Stmt>(op));
108+
StmtExprVisitor::VisitStmt_(op);
109+
return;
110+
}
104111
// partition const loop when sets partition_const_loop_
105112
if (!is_const_int(op->min) || !is_const_int(op->extent) || partition_const_loop_) {
106-
// always treat var with hint to be partitioned
107-
const VarNode* var = op->loop_var.get();
108-
if (partition_hint_vars.count(var)) {
109-
candidates.insert(GetRef<Stmt>(op));
110-
StmtExprVisitor::VisitStmt_(op);
111-
return;
112-
}
113113
record_.insert({var, false});
114114
StmtExprVisitor::VisitStmt_(op);
115115
if (record_.at(var) && !no_split_) {
@@ -126,14 +126,14 @@ class CandidateSelector final : public StmtExprVisitor {
126126
const IterVarNode* iv = op->node.as<IterVarNode>();
127127
ICHECK(iv);
128128
Var var = iv->var;
129+
// always treat var with hint to be partitioned
130+
if (partition_hint_vars.count(var.get())) {
131+
candidates.insert(GetRef<Stmt>(op));
132+
StmtExprVisitor::VisitStmt_(op);
133+
return;
134+
}
129135
runtime::ThreadScope scope = runtime::ThreadScope::Create(iv->thread_tag);
130136
if ((scope.rank == 0) && (!is_const_int(op->value) || partition_const_loop_)) {
131-
// always treat var with hint to be partitioned
132-
if (partition_hint_vars.count(var.get())) {
133-
candidates.insert(GetRef<Stmt>(op));
134-
StmtExprVisitor::VisitStmt_(op);
135-
return;
136-
}
137137
record_.insert({var.get(), false});
138138
StmtExprVisitor::VisitStmt_(op);
139139
if (record_.at(var.get()) && !no_split_) {
@@ -262,6 +262,8 @@ class PartitionFinder : public StmtExprVisitor {
262262
void VisitExpr_(const CallNode* op) final {
263263
if (op->op.same_as(builtin::likely())) {
264264
DeduceCondition(op->args[0]);
265+
} else if (op->op.same_as(builtin::ignore_loop_partition())) {
266+
return;
265267
} else {
266268
StmtExprVisitor::VisitExpr_(op);
267269
}
@@ -287,6 +289,22 @@ class PartitionFinder : public StmtExprVisitor {
287289
// cond is true within interval
288290
partitions[{cond, true}] = interval;
289291
}
292+
293+
if (interval.IsNothing()) {
294+
// `DeduceBound` does not support NE now, thus when
295+
// deducing l==r fails, just try (l<=r && l>=r) instead
296+
if (const EQNode* op = cond.as<EQNode>()) {
297+
IntSet part1 = DeduceBound(current_var_, GE(op->a, op->b), hint_map_, relax_map_);
298+
IntSet part2 = DeduceBound(current_var_, LE(op->a, op->b), hint_map_, relax_map_);
299+
interval = arith::Intersect({part1, part2});
300+
if (!interval.IsNothing()) {
301+
// cond is true within interval
302+
partitions[{cond, true}] = interval;
303+
return;
304+
}
305+
}
306+
}
307+
290308
PrimExpr inverse_cond = InverseCond(cond);
291309
if (inverse_cond.defined()) {
292310
IntSet interval = DeduceBound(current_var_, inverse_cond, hint_map_, relax_map_);
@@ -469,6 +487,7 @@ std::pair<IntSet, ExpressionSet> LoopPartitioner::GetIntervalAndCondset(
469487
if (kv.first.second == cond_value) {
470488
arith::IntervalSet interval = Downcast<arith::IntervalSet>(kv.second);
471489
arith::IntervalSet intersection = arith::Intersect(&analyzer_, interval, for_interval);
490+
472491
if (!intersection->IsEmpty()) {
473492
sets.push_back(kv.second);
474493
cond_set.insert(kv.first.first);
@@ -625,8 +644,7 @@ Stmt LoopPartitioner::TryPartition(const Stmt& stmt, Var var, PrimExpr min, Prim
625644
}();
626645

627646
if (middle_interval.IsNothing() && opt_cond_value == false) {
628-
// Return loop directly as it can be simplified.
629-
return stmt;
647+
return Stmt();
630648
}
631649

632650
if (!opt_cond_value.has_value()) {
@@ -750,6 +768,9 @@ class RemoveLikelyTagsAndHints : public StmtExprMutator {
750768
if (op->op.same_as(builtin::likely())) {
751769
ICHECK_EQ(op->args.size(), 1);
752770
return StmtExprMutator::VisitExpr(op->args[0]);
771+
} else if (op->op.same_as(builtin::ignore_loop_partition())) {
772+
ICHECK_EQ(op->args.size(), 1);
773+
return StmtExprMutator::VisitExpr(op->args[0]);
753774
} else {
754775
return StmtExprMutator::VisitExpr_(op);
755776
}

tests/python/tir-transform/test_tir_transform_loop_partition.py

Lines changed: 85 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -570,11 +570,12 @@ def test_explicit_partition_hint():
570570
tvm.ir.assert_structural_equal(mod["main"], partitioned_concat)
571571

572572

573-
def partition_from_scheduled_tir(prim_func, pass_cfg):
573+
def partition_from_scheduled_tir(prim_func, pass_cfg, do_flatten=True):
574574
with tvm.transform.PassContext(config=pass_cfg):
575575
mod = IRModule.from_expr(prim_func.with_attr("global_symbol", "main"))
576576
mod = tvm.tir.transform.LowerOpaqueBlock()(mod)
577-
mod = tvm.tir.transform.FlattenBuffer()(mod)
577+
if do_flatten:
578+
mod = tvm.tir.transform.FlattenBuffer()(mod)
578579
mod = tvm.tir.transform.LoopPartition()(mod)
579580
mod = tvm.tir.transform.Simplify()(mod)
580581
mod = tvm.tir.transform.RemoveNoOp()(mod)
@@ -1037,6 +1038,29 @@ def concat_five_buffers_with_equalities_expected(
10371038
T_concat_1[i0 * 129 + 129] = buffer_e_1[i0]
10381039

10391040

1041+
@T.prim_func
1042+
def nested_partition_with_single_points(A: T.Buffer[(25,), "int32"]):
1043+
for i in T.serial(5, annotations={"pragma_loop_partition_hint": 1}):
1044+
if i == 1:
1045+
for j in T.serial(5, annotations={"pragma_loop_partition_hint": 1}):
1046+
if j > 2:
1047+
A[i * 5 + j] = i * 5 + j
1048+
else:
1049+
for j in T.serial(5, annotations={"pragma_loop_partition_hint": 1}):
1050+
if j > 2:
1051+
A[i * 5 + j] = i * 15 + j
1052+
1053+
1054+
@T.prim_func
1055+
def nested_partition_with_single_points_expected(A: T.Buffer[(25,), "int32"]):
1056+
for j in range(2):
1057+
A[j + 3] = j + 3
1058+
for j in range(2):
1059+
A[j + 8] = j + 8
1060+
for i, j in T.grid(3, 2):
1061+
A[i * 5 + j + 13] = i * 15 + j + 33
1062+
1063+
10401064
@pytest.mark.parametrize(
10411065
"origin,expected",
10421066
[
@@ -1045,6 +1069,7 @@ def concat_five_buffers_with_equalities_expected(
10451069
(concat_func_end_point_equality, concat_func_end_point_equality_expected),
10461070
(concat_func_edge_equalities, concat_func_edge_equalities_expected),
10471071
(concat_five_buffers_with_equalities, concat_five_buffers_with_equalities_expected),
1072+
(nested_partition_with_single_points, nested_partition_with_single_points_expected),
10481073
],
10491074
)
10501075
def test_single_point_partition(origin, expected):
@@ -1062,5 +1087,63 @@ def test_single_point_partition(origin, expected):
10621087
tvm.ir.assert_structural_equal(mod["main"], expected)
10631088

10641089

1090+
def test_equation_on_floordiv():
1091+
@T.prim_func
1092+
def before(A: T.Buffer[(2, 2, 20), "int32"]):
1093+
for i in T.serial(5, annotations={"pragma_loop_partition_hint": 1}):
1094+
if i == 1:
1095+
for vv in T.vectorized(640, annotations={"pragma_loop_partition_hint": 1}):
1096+
if i * 2 + vv // 320 == 3:
1097+
A[i - 1, i * 2 + vv // 320 - 3, vv % 320 // 16] = 1
1098+
1099+
@T.prim_func
1100+
def expected(A: T.Buffer[(2, 2, 20), "int32"]):
1101+
for vv in T.vectorized(320):
1102+
A[0, 0, vv // 16] = 1
1103+
1104+
expected = expected.with_attr({"global_symbol": "main"})
1105+
after = partition_from_scheduled_tir(
1106+
before.with_attr("global_symbol", "main"), {}, do_flatten=False
1107+
)
1108+
tvm.ir.assert_structural_equal(after["main"], expected)
1109+
1110+
1111+
def test_ignore_loop_partition_hint():
1112+
"""Skip unroll body and prologue for pipeline case"""
1113+
1114+
@T.prim_func
1115+
def before(A: T.Buffer[(10), "float32"], D: T.Buffer[(10), "float32"]):
1116+
B = T.decl_buffer([2], "float32")
1117+
C = T.decl_buffer([2], "float32")
1118+
for i in T.serial(12, annotations={"pragma_loop_partition_hint": 1}):
1119+
if T.ignore_loop_partition(i < 10):
1120+
B[i % 2] = A[i] + 1.0
1121+
if T.ignore_loop_partition(1 <= i and i < 11):
1122+
C[(i - 1) % 2] = B[(i - 1) % 2] + 2.0
1123+
if 2 <= i:
1124+
D[i - 2] = C[i % 2] + 3.0
1125+
1126+
@T.prim_func
1127+
def expected(A: T.Buffer[(10), "float32"], D: T.Buffer[(10), "float32"]):
1128+
B = T.decl_buffer([2], "float32")
1129+
C = T.decl_buffer([2], "float32")
1130+
for i in range(2):
1131+
B[i] = A[i] + 1.0
1132+
if i == 1:
1133+
C[i - 1] = B[i - 1] + 2.0
1134+
for i in T.serial(10):
1135+
if i < 8:
1136+
B[i % 2] = A[i + 2] + 1.0
1137+
if i < 9:
1138+
C[(i + 1) % 2] = B[(i + 1) % 2] + 2.0
1139+
D[i] = C[i % 2] + 3.0
1140+
1141+
expected = expected.with_attr({"global_symbol": "main"})
1142+
after = partition_from_scheduled_tir(
1143+
before.with_attr({"global_symbol": "main"}), {}, do_flatten=False
1144+
)
1145+
tvm.ir.assert_structural_equal(after["main"], expected)
1146+
1147+
10651148
if __name__ == "__main__":
10661149
tvm.testing.main()

0 commit comments

Comments
 (0)