@@ -574,6 +574,23 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
574574 p->stream << ' )' ;
575575 });
576576
577+ /* !
578+ * \brief Unpack reduction by calling each leaf via fleaf
579+ *
580+ * \param value The expression value.
581+ * \tparam TNode the reduction node to match.
582+ * \tparam FLeaf The callback function at leaf.
583+ */
584+ template <typename TNode, typename FLeaf>
585+ void UnpackReduction (const PrimExpr& value, FLeaf fleaf) {
586+ if (const TNode* node = value.as <TNode>()) {
587+ UnpackReduction<TNode, FLeaf>(node->a , fleaf);
588+ UnpackReduction<TNode, FLeaf>(node->b , fleaf);
589+ } else {
590+ fleaf (value);
591+ }
592+ }
593+
577594// Sub-class RewriteSimplifier::Impl to take benefit of
578595// rewriter for condition simplification etc.
579596class CanonicalSimplifier ::Impl : public RewriteSimplifier::Impl {
@@ -633,6 +650,27 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
633650 */
634651 void SeparateDivisibleParts (const SumExprNode* psum, int64_t coeff, SumExpr* out_divisible,
635652 SumExpr* out_non_divisible);
653+ /* !
654+ * \brief Pattern match and check whether lhs is fully divisible by
655+ * rhs using prod pattern simiplification expressions.
656+ *
657+ * The following two relations holds for floordiv/mod and truncdiv/mod
658+ * Note that the relation do not hold for euclidean divide and mod.
659+ *
660+ * This is because the floordiv/mod and truncdiv/mod result can be
661+ * uniquely determined by the value of the realdiv result and the
662+ * relation holds for realdiv.
663+ *
664+ * - div((a0 * a1 * c), (b0 * b1 * c)) = div((a0 * a1), (b0 * b1))
665+ * - mod((a0 * a1 * c), (b0 * b1 * c)) = mod((a0 * a1), (b0 * b1)) * c
666+ *
667+ * \param lhs The left operand to be updated.
668+ * \param rhs The right operand to be updated.
669+ * \param common_scale The common scale between lhs and rhs.
670+ * \returns The simplified result if it is successful.
671+ * \note This simplification mainly target when rhs is symbolic.
672+ */
673+ bool ProdDivSimplify (PrimExpr* lhs, PrimExpr* rhs, PrimExpr* common_scale);
636674 /* !
637675 * \brief Normalize expr to normal expr.
638676 * \param expr The input expression.
@@ -862,6 +900,66 @@ SplitExpr CanonicalSimplifier::Impl::SplitDivConst(SplitExpr lhs, int64_t cval,
862900 return lhs;
863901}
864902
903+ bool CanonicalSimplifier::Impl::ProdDivSimplify (PrimExpr* plhs, PrimExpr* prhs,
904+ PrimExpr* common_scale) {
905+ // the constant rhs case is covered by other simplifier so
906+ // we just skip to save the time
907+ if (prhs->as <IntImmNode>()) return false ;
908+ // collect lhs products and try to eliminate by matching them to prod in rhs
909+ Array<Optional<PrimExpr>> lhs_prods;
910+ PrimExpr new_rhs = make_const (prhs->dtype (), 1 );
911+ PrimExpr new_common_scale = make_const (prhs->dtype (), 1 );
912+ int64_t lhs_cscale = 1 , rhs_cscale = 1 ;
913+ int num_elimination = 0 ;
914+
915+ // collect lhs product and constant scale.
916+ auto fcollect_lhs = [&](PrimExpr value) {
917+ if (auto * intimm = value.as <tir::IntImmNode>()) {
918+ lhs_cscale *= intimm->value ;
919+ } else {
920+ lhs_prods.push_back (value);
921+ }
922+ };
923+ UnpackReduction<tir::MulNode>(*plhs, fcollect_lhs);
924+
925+ // collect rhs product and try to eliminate when possible
926+ PEqualChecker<PrimExpr> deep_equal;
927+ auto fcollect_rhs = [&](PrimExpr value) {
928+ if (auto * intimm = value.as <tir::IntImmNode>()) {
929+ rhs_cscale *= intimm->value ;
930+ } else {
931+ // try eliminate from lhs
932+ for (size_t i = 0 ; i < lhs_prods.size (); ++i) {
933+ if (lhs_prods[i].defined () && deep_equal (value, lhs_prods[i].value ())) {
934+ lhs_prods.Set (i, NullOpt);
935+ ++num_elimination;
936+ new_common_scale = new_common_scale * value;
937+ return ;
938+ }
939+ }
940+ // if elimination is not possible then construct the expression.
941+ new_rhs = new_rhs * value;
942+ }
943+ };
944+ UnpackReduction<tir::MulNode>(*prhs, fcollect_rhs);
945+ // find gcd of const scales.
946+ int64_t cscale_gcd = ZeroAwareGCD (lhs_cscale, rhs_cscale);
947+ lhs_cscale /= cscale_gcd;
948+ rhs_cscale /= cscale_gcd;
949+ // if no elimination is possible
950+ if (num_elimination == 0 && cscale_gcd == 1 ) return false ;
951+
952+ // construct prod via canonical form
953+ PrimExpr new_lhs = make_const (plhs->dtype (), 1 );
954+ for (Optional<PrimExpr> val : lhs_prods) {
955+ if (val.defined ()) new_lhs = new_lhs * val.value ();
956+ }
957+ *plhs = new_lhs * make_const (plhs->dtype (), lhs_cscale);
958+ *prhs = new_rhs * make_const (prhs->dtype (), rhs_cscale);
959+ *common_scale = new_common_scale * make_const (prhs->dtype (), cscale_gcd);
960+ return true ;
961+ }
962+
865963PrimExpr CanonicalSimplifier::Impl::VisitExpr_ (const DivNode* op) {
866964 if (!IsIndexType (op->dtype )) {
867965 return Rewriter::VisitExpr_ (op);
@@ -913,6 +1011,12 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const DivNode* op) {
9131011 // normal path
9141012 a = Normalize (a);
9151013 b = Normalize (b);
1014+ PrimExpr scale;
1015+ // note this is the case where b is not constant
1016+ if (ProdDivSimplify (&a, &b, &scale)) {
1017+ // use operator ver so it can constant fold if b == 1
1018+ return truncdiv (a, b);
1019+ }
9161020 if (op->a .same_as (a) && op->b .same_as (b)) {
9171021 return GetRef<PrimExpr>(op);
9181022 } else {
@@ -967,6 +1071,11 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
9671071 // normal path
9681072 a = Normalize (a);
9691073 b = Normalize (b);
1074+ PrimExpr scale;
1075+ if (ProdDivSimplify (&a, &b, &scale)) {
1076+ // use operator ver so it can const fold.
1077+ return floordiv (a, b);
1078+ }
9701079 if (op->a .same_as (a) && op->b .same_as (b)) {
9711080 return GetRef<PrimExpr>(op);
9721081 } else {
@@ -1088,6 +1197,13 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ModNode* op) {
10881197 // normal path
10891198 a = Normalize (a);
10901199 b = Normalize (b);
1200+
1201+ PrimExpr scale;
1202+ if (ProdDivSimplify (&a, &b, &scale)) {
1203+ // use operator version here so it can const fold b == 1
1204+ return truncmod (a, b) * scale;
1205+ }
1206+
10911207 if (op->a .same_as (a) && op->b .same_as (b)) {
10921208 return GetRef<PrimExpr>(op);
10931209 } else {
@@ -1146,6 +1262,13 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorModNode* op) {
11461262 // normal path
11471263 a = Normalize (a);
11481264 b = Normalize (b);
1265+
1266+ PrimExpr scale;
1267+ if (ProdDivSimplify (&a, &b, &scale)) {
1268+ // use operator version here so it can const fold b == 1
1269+ return floormod (a, b) * scale;
1270+ }
1271+
11491272 if (op->a .same_as (a) && op->b .same_as (b)) {
11501273 return GetRef<PrimExpr>(op);
11511274 } else {
0 commit comments