Skip to content

Commit ab0e91c

Browse files
MasterJH5574tqchen
andcommitted
[Cherry-Pick][ARITH] Enhance CanonicalSimplify to Simplify ProdDiv (apache#174)
This PR enhances canonical simplify to simplify the product-division pattern where both sides involve symbolic variables. Test cases are added. Co-authored-by: Tianqi Chen <[email protected]>
1 parent 2bf8bb6 commit ab0e91c

File tree

3 files changed

+156
-6
lines changed

3 files changed

+156
-6
lines changed

src/arith/canonical_simplify.cc

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -574,6 +574,23 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
574574
p->stream << ')';
575575
});
576576

577+
/*!
578+
* \brief Unpack reduction by calling each leaf via fleaf
579+
*
580+
* \param value The expression value.
581+
* \tparam TNode the reduction node to match.
582+
* \tparam FLeaf The callback function at leaf.
583+
*/
584+
template <typename TNode, typename FLeaf>
585+
void UnpackReduction(const PrimExpr& value, FLeaf fleaf) {
586+
if (const TNode* node = value.as<TNode>()) {
587+
UnpackReduction<TNode, FLeaf>(node->a, fleaf);
588+
UnpackReduction<TNode, FLeaf>(node->b, fleaf);
589+
} else {
590+
fleaf(value);
591+
}
592+
}
593+
577594
// Sub-class RewriteSimplifier::Impl to take benefit of
578595
// rewriter for condition simplification etc.
579596
class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
@@ -633,6 +650,27 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
633650
*/
634651
void SeparateDivisibleParts(const SumExprNode* psum, int64_t coeff, SumExpr* out_divisible,
635652
SumExpr* out_non_divisible);
653+
/*!
654+
* \brief Pattern match and check whether lhs is fully divisible by
655+
* rhs using prod pattern simplification expressions.
656+
*
657+
* The following two relations hold for floordiv/mod and truncdiv/mod.
658+
* Note that the relations do not hold for euclidean divide and mod.
659+
*
660+
* This is because the floordiv/mod and truncdiv/mod result can be
661+
* uniquely determined by the value of the realdiv result and the
662+
* relation holds for realdiv.
663+
*
664+
* - div((a0 * a1 * c), (b0 * b1 * c)) = div((a0 * a1), (b0 * b1))
665+
* - mod((a0 * a1 * c), (b0 * b1 * c)) = mod((a0 * a1), (b0 * b1)) * c
666+
*
667+
* \param lhs The left operand to be updated.
668+
* \param rhs The right operand to be updated.
669+
* \param common_scale The common scale between lhs and rhs.
670+
* \returns Whether the simplification succeeded; lhs, rhs and common_scale are only updated on success.
671+
* \note This simplification mainly targets the case where rhs is symbolic.
672+
*/
673+
bool ProdDivSimplify(PrimExpr* lhs, PrimExpr* rhs, PrimExpr* common_scale);
636674
/*!
637675
* \brief Normalize expr to normal expr.
638676
* \param expr The input expression.
@@ -862,6 +900,66 @@ SplitExpr CanonicalSimplifier::Impl::SplitDivConst(SplitExpr lhs, int64_t cval,
862900
return lhs;
863901
}
864902

903+
/*!
 * \brief Cancel the factors shared by two product expressions.
 *
 * Both operands are unpacked as products (nested MulNode). Integer
 * constant factors are folded into one constant scale per side and
 * reduced by their gcd; every non-constant factor of the rhs that is
 * deep-equal to a not-yet-matched factor of the lhs is removed from
 * both sides and multiplied into the common scale.
 *
 * \param plhs In: the left operand. Out (on success): the remaining lhs product.
 * \param prhs In: the right operand. Out (on success): the remaining rhs product.
 * \param common_scale Out (on success): product of all cancelled factors,
 *        including the gcd of the constant scales.
 * \return true if something was cancelled and the outputs were written;
 *         false otherwise, in which case no argument is modified.
 */
bool CanonicalSimplifier::Impl::ProdDivSimplify(PrimExpr* plhs, PrimExpr* prhs,
                                                PrimExpr* common_scale) {
  // the constant rhs case is covered by other simplifier so
  // we just skip to save the time
  if (prhs->as<IntImmNode>()) return false;
  // collect lhs products and try to eliminate by matching them to prod in rhs
  // (a matched entry is replaced by NullOpt so it cannot be matched twice)
  Array<Optional<PrimExpr>> lhs_prods;
  PrimExpr new_rhs = make_const(prhs->dtype(), 1);
  PrimExpr new_common_scale = make_const(prhs->dtype(), 1);
  int64_t lhs_cscale = 1, rhs_cscale = 1;
  int num_elimination = 0;

  // collect lhs product and constant scale.
  auto fcollect_lhs = [&](PrimExpr value) {
    if (auto* intimm = value.as<tir::IntImmNode>()) {
      lhs_cscale *= intimm->value;
    } else {
      lhs_prods.push_back(value);
    }
  };
  UnpackReduction<tir::MulNode>(*plhs, fcollect_lhs);

  // collect rhs product and try to eliminate when possible
  PEqualChecker<PrimExpr> deep_equal;
  auto fcollect_rhs = [&](PrimExpr value) {
    if (auto* intimm = value.as<tir::IntImmNode>()) {
      rhs_cscale *= intimm->value;
    } else {
      // try eliminate from lhs
      for (size_t i = 0; i < lhs_prods.size(); ++i) {
        if (lhs_prods[i].defined() && deep_equal(value, lhs_prods[i].value())) {
          lhs_prods.Set(i, NullOpt);
          ++num_elimination;
          new_common_scale = new_common_scale * value;
          return;
        }
      }
      // if elimination is not possible then construct the expression.
      new_rhs = new_rhs * value;
    }
  };
  UnpackReduction<tir::MulNode>(*prhs, fcollect_rhs);
  // find gcd of const scales.
  int64_t cscale_gcd = ZeroAwareGCD(lhs_cscale, rhs_cscale);
  // A zero gcd would make the two divisions below divide by zero (undefined
  // behaviour), so give up and leave the operands untouched.
  // NOTE(review): ZeroAwareGCD presumably yields 0 only when both scales are 0,
  // i.e. the rhs carries a constant zero factor -- confirm against its definition.
  if (cscale_gcd == 0) return false;
  lhs_cscale /= cscale_gcd;
  rhs_cscale /= cscale_gcd;
  // if no elimination is possible
  if (num_elimination == 0 && cscale_gcd == 1) return false;

  // construct prod via canonical form
  PrimExpr new_lhs = make_const(plhs->dtype(), 1);
  for (Optional<PrimExpr> val : lhs_prods) {
    // skip the factors that were cancelled against the rhs
    if (val.defined()) new_lhs = new_lhs * val.value();
  }
  *plhs = new_lhs * make_const(plhs->dtype(), lhs_cscale);
  *prhs = new_rhs * make_const(prhs->dtype(), rhs_cscale);
  *common_scale = new_common_scale * make_const(prhs->dtype(), cscale_gcd);
  return true;
}
962+
865963
PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const DivNode* op) {
866964
if (!IsIndexType(op->dtype)) {
867965
return Rewriter::VisitExpr_(op);
@@ -913,6 +1011,12 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const DivNode* op) {
9131011
// normal path
9141012
a = Normalize(a);
9151013
b = Normalize(b);
1014+
PrimExpr scale;
1015+
// note this is the case where b is not constant
1016+
if (ProdDivSimplify(&a, &b, &scale)) {
1017+
// use operator ver so it can constant fold if b == 1
1018+
return truncdiv(a, b);
1019+
}
9161020
if (op->a.same_as(a) && op->b.same_as(b)) {
9171021
return GetRef<PrimExpr>(op);
9181022
} else {
@@ -967,6 +1071,11 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
9671071
// normal path
9681072
a = Normalize(a);
9691073
b = Normalize(b);
1074+
PrimExpr scale;
1075+
if (ProdDivSimplify(&a, &b, &scale)) {
1076+
// use operator ver so it can const fold.
1077+
return floordiv(a, b);
1078+
}
9701079
if (op->a.same_as(a) && op->b.same_as(b)) {
9711080
return GetRef<PrimExpr>(op);
9721081
} else {
@@ -1088,6 +1197,13 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ModNode* op) {
10881197
// normal path
10891198
a = Normalize(a);
10901199
b = Normalize(b);
1200+
1201+
PrimExpr scale;
1202+
if (ProdDivSimplify(&a, &b, &scale)) {
1203+
// use operator version here so it can const fold b == 1
1204+
return truncmod(a, b) * scale;
1205+
}
1206+
10911207
if (op->a.same_as(a) && op->b.same_as(b)) {
10921208
return GetRef<PrimExpr>(op);
10931209
} else {
@@ -1146,6 +1262,13 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorModNode* op) {
11461262
// normal path
11471263
a = Normalize(a);
11481264
b = Normalize(b);
1265+
1266+
PrimExpr scale;
1267+
if (ProdDivSimplify(&a, &b, &scale)) {
1268+
// use operator version here so it can const fold b == 1
1269+
return floormod(a, b) * scale;
1270+
}
1271+
11491272
if (op->a.same_as(a) && op->b.same_as(b)) {
11501273
return GetRef<PrimExpr>(op);
11511274
} else {

tests/python/relax/test_op_manipulate.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def test_reshape_infer_struct_info_shape_symbolic():
134134
_check_inference(
135135
bb,
136136
relax.op.reshape(x, (d, c, b, -1)),
137-
relax.TensorStructInfo((d, c, b, tir.floordiv(a * b * c * d, d * c * b)), "float32"),
137+
relax.TensorStructInfo((d, c, b, a), "float32"),
138138
)
139139
_check_inference(
140140
bb,
@@ -144,12 +144,12 @@ def test_reshape_infer_struct_info_shape_symbolic():
144144
_check_inference(
145145
bb,
146146
relax.op.reshape(x, (2, -1, a)),
147-
relax.TensorStructInfo((2, tir.floordiv(a * b * c * d, a * 2), a), "float32"),
147+
relax.TensorStructInfo((2, tir.floordiv(b * c * d, 2), a), "float32"),
148148
)
149149
_check_inference(
150150
bb,
151151
relax.op.reshape(x, (c, -1, d, b)),
152-
relax.TensorStructInfo((c, tir.floordiv(a * b * c * d, c * d * b), d, b), "float32"),
152+
relax.TensorStructInfo((c, a, d, b), "float32"),
153153
)
154154
_check_inference(
155155
bb,
@@ -159,9 +159,7 @@ def test_reshape_infer_struct_info_shape_symbolic():
159159
_check_inference(
160160
bb,
161161
relax.op.reshape(x, (c, a * b * d, -1)),
162-
relax.TensorStructInfo(
163-
(c, a * b * d, tir.floordiv(a * b * c * d, c * (a * b * d))), "float32"
164-
),
162+
relax.TensorStructInfo((c, a * b * d, 1), "float32"),
165163
)
166164
# Remove Var from StructInfo when we can
167165
_check_inference(bb, relax.op.reshape(x, s0), relax.TensorStructInfo((c, a, d, b), "float32"))

tests/python/unittest/test_arith_canonical_simplify.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,5 +386,34 @@ def test_simplify_normalize_min_value_expr():
386386
ck.verify(0 == x + te.min_value("int32"), False)
387387

388388

389+
def test_proddiv_simplify():
    """Check canonical simplification of div/mod whose operands are products.

    Covers floormod, truncmod, floordiv and truncdiv where factors shared by
    the dividend and the (symbolic) divisor are expected to cancel.
    """
    ck = CanonicalChecker()
    flm = tvm.te.floormod
    fld = tvm.te.floordiv
    tdiv = tvm.te.truncdiv
    tmod = tvm.te.truncmod

    # Give every variable its own name hint so printed expressions and
    # failure messages are unambiguous.
    x, y, z = te.var("x"), te.var("y"), te.var("z")

    # floormod
    ck.verify(flm(x * 32 * x, x), 0)
    ck.verify(flm(z * x * 32 * x * y, x * z), 0)
    ck.verify(flm(z * x * 32 * x * y, x * z * y * 8 * x), 0)
    ck.verify(flm(z * x * 32 * (x * y), 6 * x * z), flm(x * y * 16, 3) * (x * z * 2))
    ck.verify(flm(x * 32 * x, x * z), flm(x * 32, z) * x)

    # truncmod
    ck.verify(tmod(x * 32 * x, x), 0)
    ck.verify(tmod(z * x * 32 * x * y, x * z), 0)
    ck.verify(tmod(z * x * 32 * (x * y), 6 * x * z), tmod(x * y * 16, 3) * (x * z * 2))
    ck.verify(tmod(x * 32 * x, x * z), tmod(x * 32, z) * x)

    # floordiv
    ck.verify(fld(x * 2 * x * z, 4 * x * x * x), fld(z, x * 2))
    ck.verify(fld(x * (2 * y) * 3, 3 * y), x * 2)
    ck.verify(fld(x * (2 * y) * 3, 3 * y * z), fld(x * 2, z))

    # truncdiv
    ck.verify(tdiv(x * 2 * x * z, 4 * x * x * x), tdiv(z, x * 2))
    ck.verify(tdiv(x * (2 * y) * 3, 3 * y), x * 2)
    ck.verify(tdiv(x * (2 * y) * 3, 3 * y * z), tdiv(x * 2, z))
416+
417+
389418
if __name__ == "__main__":
390419
tvm.testing.main()

0 commit comments

Comments
 (0)