Skip to content

Commit 4a919b4

Browse files
committed
[ARITH] Enhance CanonicalSimplify to Simplify ProdDiv
This PR enhances canonical simplify to simplify product division pattern where both side involves symbolic variables. Test cases are added.
1 parent a84a2cb commit 4a919b4

File tree

5 files changed

+160
-6
lines changed

5 files changed

+160
-6
lines changed

src/arith/bound_deducer.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,10 @@ void BoundDeducer::Deduce() {
344344
expr_map_ = EvalSetForEachSubExpr(expr_, hint_map_);
345345

346346
this->VisitExpr(expr_);
347+
348+
if (success_) {
349+
result_ = analyzer_.Simplify(result_);
350+
}
347351
}
348352

349353
void BoundDeducer::Relax() {

src/arith/canonical_simplify.cc

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -633,6 +633,27 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
633633
*/
634634
void SeparateDivisibleParts(const SumExprNode* psum, int64_t coeff, SumExpr* out_divisible,
635635
SumExpr* out_non_divisible);
636+
/*!
637+
* \brief Pattern match and check whether lhs is fully divisible by
638+
* rhs using prod pattern simiplification expressions.
639+
*
640+
* The following two relations holds for floordiv/mod and truncdiv/mod
641+
* Note that the relation do not hold for euclidean divide and mod.
642+
*
643+
* This is because the floordiv/mod and truncdiv/mod result can be
644+
* uniquely determined by the value of the realdiv result and the
645+
* relation holds for realdiv.
646+
*
647+
* - div((a0 * a1 * c), (b0 * b1 * c)) = div((a0 * a1), (b0 * b1))
648+
* - mod((a0 * a1 * c), (b0 * b1 * c)) = mod((a0 * a1), (b0 * b1)) * c
649+
*
650+
* \param lhs The left operand to be updated.
651+
* \param rhs The right operand to be updated.
652+
* \param common_scale The common scale between lhs and rhs.
653+
* \returns The simplified result if it is successful.
654+
* \note This simplification mainly target when rhs is symbolic.
655+
*/
656+
bool ProdDivSimplify(PrimExpr* lhs, PrimExpr* rhs, PrimExpr* common_scale);
636657
/*!
637658
* \brief Normalize expr to normal expr.
638659
* \param expr The input expression.
@@ -862,6 +883,66 @@ SplitExpr CanonicalSimplifier::Impl::SplitDivConst(SplitExpr lhs, int64_t cval,
862883
return lhs;
863884
}
864885

886+
bool CanonicalSimplifier::Impl::ProdDivSimplify(PrimExpr* plhs, PrimExpr* prhs,
887+
PrimExpr* common_scale) {
888+
// the constant rhs case is covered by other simplifier so
889+
// we just skip to save the time
890+
if (prhs->as<IntImmNode>()) return false;
891+
// collect lhs products and try to eliminate by matching them to prod in rhs
892+
Array<Optional<PrimExpr>> lhs_prods;
893+
PrimExpr new_rhs = make_const(prhs->dtype(), 1);
894+
PrimExpr new_common_scale = make_const(prhs->dtype(), 1);
895+
int64_t lhs_cscale = 1, rhs_cscale = 1;
896+
int num_elimination = 0;
897+
898+
// collect lhs product and constant scale.
899+
auto fcollect_lhs = [&](PrimExpr value) {
900+
if (auto* intimm = value.as<tir::IntImmNode>()) {
901+
lhs_cscale *= intimm->value;
902+
} else {
903+
lhs_prods.push_back(value);
904+
}
905+
};
906+
UnpackReduction<tir::MulNode>(*plhs, fcollect_lhs);
907+
908+
// collect rhs product and try to eliminate when possible
909+
PEqualChecker<PrimExpr> deep_equal;
910+
auto fcollect_rhs = [&](PrimExpr value) {
911+
if (auto* intimm = value.as<tir::IntImmNode>()) {
912+
rhs_cscale *= intimm->value;
913+
} else {
914+
// try eliminate from lhs
915+
for (size_t i = 0; i < lhs_prods.size(); ++i) {
916+
if (lhs_prods[i].defined() && deep_equal(value, lhs_prods[i].value())) {
917+
lhs_prods.Set(i, NullOpt);
918+
++num_elimination;
919+
new_common_scale = new_common_scale * value;
920+
return;
921+
}
922+
}
923+
// if elimination is not possible then construct the expression.
924+
new_rhs = new_rhs * value;
925+
}
926+
};
927+
UnpackReduction<tir::MulNode>(*prhs, fcollect_rhs);
928+
// find gcd of const scales.
929+
int64_t cscale_gcd = ZeroAwareGCD(lhs_cscale, rhs_cscale);
930+
lhs_cscale /= cscale_gcd;
931+
rhs_cscale /= cscale_gcd;
932+
// if no elimination is possible
933+
if (num_elimination == 0 && cscale_gcd == 1) return false;
934+
935+
// construct prod via canonical form
936+
PrimExpr new_lhs = make_const(plhs->dtype(), 1);
937+
for (Optional<PrimExpr> val : lhs_prods) {
938+
if (val.defined()) new_lhs = new_lhs * val.value();
939+
}
940+
*plhs = new_lhs * make_const(plhs->dtype(), lhs_cscale);
941+
*prhs = new_rhs * make_const(prhs->dtype(), rhs_cscale);
942+
*common_scale = new_common_scale * make_const(prhs->dtype(), cscale_gcd);
943+
return true;
944+
}
945+
865946
PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const DivNode* op) {
866947
if (!IsIndexType(op->dtype)) {
867948
return Rewriter::VisitExpr_(op);
@@ -913,6 +994,12 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const DivNode* op) {
913994
// normal path
914995
a = Normalize(a);
915996
b = Normalize(b);
997+
PrimExpr scale;
998+
// note this is the case where b is not constant
999+
if (ProdDivSimplify(&a, &b, &scale)) {
1000+
// use operator ver so it can constant fold if b == 1
1001+
return truncdiv(a, b);
1002+
}
9161003
if (op->a.same_as(a) && op->b.same_as(b)) {
9171004
return GetRef<PrimExpr>(op);
9181005
} else {
@@ -967,6 +1054,11 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
9671054
// normal path
9681055
a = Normalize(a);
9691056
b = Normalize(b);
1057+
PrimExpr scale;
1058+
if (ProdDivSimplify(&a, &b, &scale)) {
1059+
// use operator ver so it can const fold.
1060+
return floordiv(a, b);
1061+
}
9701062
if (op->a.same_as(a) && op->b.same_as(b)) {
9711063
return GetRef<PrimExpr>(op);
9721064
} else {
@@ -1088,6 +1180,13 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ModNode* op) {
10881180
// normal path
10891181
a = Normalize(a);
10901182
b = Normalize(b);
1183+
1184+
PrimExpr scale;
1185+
if (ProdDivSimplify(&a, &b, &scale)) {
1186+
// use operator version here so it can const fold b == 1
1187+
return truncmod(a, b) * scale;
1188+
}
1189+
10911190
if (op->a.same_as(a) && op->b.same_as(b)) {
10921191
return GetRef<PrimExpr>(op);
10931192
} else {
@@ -1146,6 +1245,13 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorModNode* op) {
11461245
// normal path
11471246
a = Normalize(a);
11481247
b = Normalize(b);
1248+
1249+
PrimExpr scale;
1250+
if (ProdDivSimplify(&a, &b, &scale)) {
1251+
// use operator version here so it can const fold b == 1
1252+
return floormod(a, b) * scale;
1253+
}
1254+
11491255
if (op->a.same_as(a) && op->b.same_as(b)) {
11501256
return GetRef<PrimExpr>(op);
11511257
} else {

src/arith/pattern_match.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -915,6 +915,23 @@ matches_one_of(const TPattern&... patterns) {
915915
return PMatchesOneOf<TPattern...>(patterns...);
916916
}
917917

918+
/*!
919+
* \brief Unpack reduction by calling each leaf via fleaf.
920+
*
921+
* \param value The expression value.
922+
* \tparam TNode the reduction node to match.
923+
* \tparam FLeaf The callback function at leaf.
924+
*/
925+
template <typename TNode, typename FLeaf>
926+
inline void UnpackReduction(const PrimExpr& value, FLeaf fleaf) {
927+
if (const TNode* node = value.as<TNode>()) {
928+
UnpackReduction<TNode, FLeaf>(node->a, fleaf);
929+
UnpackReduction<TNode, FLeaf>(node->b, fleaf);
930+
} else {
931+
fleaf(value);
932+
}
933+
}
934+
918935
} // namespace arith
919936
} // namespace tvm
920937
#endif // TVM_ARITH_PATTERN_MATCH_H_

tests/python/unittest/test_arith_canonical_simplify.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,5 +386,34 @@ def test_simplify_normalize_min_value_expr():
386386
ck.verify(0 == x + te.min_value("int32"), False)
387387

388388

389+
def test_proddiv_simplify():
390+
ck = CanonicalChecker()
391+
flm = tvm.te.floormod
392+
fld = tvm.te.floordiv
393+
tdiv = tvm.te.truncdiv
394+
tmod = tvm.te.truncmod
395+
396+
x, y, z = te.var("x"), te.var("y"), te.var("y")
397+
398+
ck.verify(flm(x * 32 * x, x), 0)
399+
ck.verify(flm(z * x * 32 * x * y, x * z), 0)
400+
ck.verify(flm(z * x * 32 * x * y, x * z * y * 8 * x), 0)
401+
ck.verify(flm(z * x * 32 * (x * y), 6 * x * z), flm(x * y * 16, 3) * (x * z * 2))
402+
ck.verify(flm(x * 32 * x, x * z), flm(x * 32, z) * x)
403+
404+
ck.verify(tmod(x * 32 * x, x), 0)
405+
ck.verify(tmod(z * x * 32 * x * y, x * z), 0)
406+
ck.verify(tmod(z * x * 32 * (x * y), 6 * x * z), tmod(x * y * 16, 3) * (x * z * 2))
407+
ck.verify(tmod(x * 32 * x, x * z), tmod(x * 32, z) * x)
408+
409+
ck.verify(fld(x * 2 * x * z, 4 * x * x * x), fld(z, x * 2))
410+
ck.verify(fld(x * (2 * y) * 3, 3 * y), x * 2)
411+
ck.verify(fld(x * (2 * y) * 3, 3 * y * z), fld(x * 2, z))
412+
413+
ck.verify(tdiv(x * 2 * x * z, 4 * x * x * x), tdiv(z, x * 2))
414+
ck.verify(tdiv(x * (2 * y) * 3, 3 * y), x * 2)
415+
ck.verify(tdiv(x * (2 * y) * 3, 3 * y * z), tdiv(x * 2, z))
416+
417+
389418
if __name__ == "__main__":
390419
tvm.testing.main()

tests/python/unittest/test_arith_deduce_bound.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -114,12 +114,10 @@ def test_deduce():
114114
assert str(res9.max_value) == "neg_inf"
115115
assert str(res9.min_value) == "pos_inf"
116116

117-
# Unsatisfiable Mul in `EQ`
118-
res10 = tvm.arith.deduce_bound(
119-
a, (b * a == b), {b: b_s}, {}
120-
) # simplifier is not able to prove that (b % b == 0)
121-
assert str(res10.max_value) == "neg_inf"
122-
assert str(res10.min_value) == "pos_inf"
117+
res10 = tvm.arith.deduce_bound(a, (b * a == b), {b: b_s}, {})
118+
# simplifier is now able to prove symbolic relation (b * a % b == 0)
119+
tvm.testing.assert_prim_expr_equal(res10.max_value, 1)
120+
tvm.testing.assert_prim_expr_equal(res10.min_value, 1)
123121

124122

125123
def test_check():

0 commit comments

Comments
 (0)