@@ -633,6 +633,27 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl {
633633 */
634634 void SeparateDivisibleParts (const SumExprNode* psum, int64_t coeff, SumExpr* out_divisible,
635635 SumExpr* out_non_divisible);
636+ /* !
637+ * \brief Pattern match and check whether lhs is fully divisible by
638+ * rhs using prod pattern simiplification expressions.
639+ *
640+ * The following two relations holds for floordiv/mod and truncdiv/mod
641+ * Note that the relation do not hold for euclidean divide and mod.
642+ *
643+ * This is because the floordiv/mod and truncdiv/mod result can be
644+ * uniquely determined by the value of the realdiv result and the
645+ * relation holds for realdiv.
646+ *
647+ * - div((a0 * a1 * c), (b0 * b1 * c)) = div((a0 * a1), (b0 * b1))
648+ * - mod((a0 * a1 * c), (b0 * b1 * c)) = mod((a0 * a1), (b0 * b1)) * c
649+ *
650+ * \param lhs The left operand to be updated.
651+ * \param rhs The right operand to be updated.
652+ * \param common_scale The common scale between lhs and rhs.
653+ * \returns The simplified result if it is successful.
654+ * \note This simplification mainly target when rhs is symbolic.
655+ */
656+ bool ProdDivSimplify (PrimExpr* lhs, PrimExpr* rhs, PrimExpr* common_scale);
636657 /* !
637658 * \brief Normalize expr to normal expr.
638659 * \param expr The input expression.
@@ -862,6 +883,66 @@ SplitExpr CanonicalSimplifier::Impl::SplitDivConst(SplitExpr lhs, int64_t cval,
862883 return lhs;
863884}
864885
886+ bool CanonicalSimplifier::Impl::ProdDivSimplify (PrimExpr* plhs, PrimExpr* prhs,
887+ PrimExpr* common_scale) {
888+ // the constant rhs case is covered by other simplifier so
889+ // we just skip to save the time
890+ if (prhs->as <IntImmNode>()) return false ;
891+ // collect lhs products and try to eliminate by matching them to prod in rhs
892+ Array<Optional<PrimExpr>> lhs_prods;
893+ PrimExpr new_rhs = make_const (prhs->dtype (), 1 );
894+ PrimExpr new_common_scale = make_const (prhs->dtype (), 1 );
895+ int64_t lhs_cscale = 1 , rhs_cscale = 1 ;
896+ int num_elimination = 0 ;
897+
898+ // collect lhs product and constant scale.
899+ auto fcollect_lhs = [&](PrimExpr value) {
900+ if (auto * intimm = value.as <tir::IntImmNode>()) {
901+ lhs_cscale *= intimm->value ;
902+ } else {
903+ lhs_prods.push_back (value);
904+ }
905+ };
906+ UnpackReduction<tir::MulNode>(*plhs, fcollect_lhs);
907+
908+ // collect rhs product and try to eliminate when possible
909+ PEqualChecker<PrimExpr> deep_equal;
910+ auto fcollect_rhs = [&](PrimExpr value) {
911+ if (auto * intimm = value.as <tir::IntImmNode>()) {
912+ rhs_cscale *= intimm->value ;
913+ } else {
914+ // try eliminate from lhs
915+ for (size_t i = 0 ; i < lhs_prods.size (); ++i) {
916+ if (lhs_prods[i].defined () && deep_equal (value, lhs_prods[i].value ())) {
917+ lhs_prods.Set (i, NullOpt);
918+ ++num_elimination;
919+ new_common_scale = new_common_scale * value;
920+ return ;
921+ }
922+ }
923+ // if elimination is not possible then construct the expression.
924+ new_rhs = new_rhs * value;
925+ }
926+ };
927+ UnpackReduction<tir::MulNode>(*prhs, fcollect_rhs);
928+ // find gcd of const scales.
929+ int64_t cscale_gcd = ZeroAwareGCD (lhs_cscale, rhs_cscale);
930+ lhs_cscale /= cscale_gcd;
931+ rhs_cscale /= cscale_gcd;
932+ // if no elimination is possible
933+ if (num_elimination == 0 && cscale_gcd == 1 ) return false ;
934+
935+ // construct prod via canonical form
936+ PrimExpr new_lhs = make_const (plhs->dtype (), 1 );
937+ for (Optional<PrimExpr> val : lhs_prods) {
938+ if (val.defined ()) new_lhs = new_lhs * val.value ();
939+ }
940+ *plhs = new_lhs * make_const (plhs->dtype (), lhs_cscale);
941+ *prhs = new_rhs * make_const (prhs->dtype (), rhs_cscale);
942+ *common_scale = new_common_scale * make_const (prhs->dtype (), cscale_gcd);
943+ return true ;
944+ }
945+
865946PrimExpr CanonicalSimplifier::Impl::VisitExpr_ (const DivNode* op) {
866947 if (!IsIndexType (op->dtype )) {
867948 return Rewriter::VisitExpr_ (op);
@@ -913,6 +994,12 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const DivNode* op) {
913994 // normal path
914995 a = Normalize (a);
915996 b = Normalize (b);
997+ PrimExpr scale;
998+ // note this is the case where b is not constant
999+ if (ProdDivSimplify (&a, &b, &scale)) {
1000+ // use operator ver so it can constant fold if b == 1
1001+ return truncdiv (a, b);
1002+ }
9161003 if (op->a .same_as (a) && op->b .same_as (b)) {
9171004 return GetRef<PrimExpr>(op);
9181005 } else {
@@ -967,6 +1054,11 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
9671054 // normal path
9681055 a = Normalize (a);
9691056 b = Normalize (b);
1057+ PrimExpr scale;
1058+ if (ProdDivSimplify (&a, &b, &scale)) {
1059+ // use operator ver so it can const fold.
1060+ return floordiv (a, b);
1061+ }
9701062 if (op->a .same_as (a) && op->b .same_as (b)) {
9711063 return GetRef<PrimExpr>(op);
9721064 } else {
@@ -1088,6 +1180,13 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ModNode* op) {
10881180 // normal path
10891181 a = Normalize (a);
10901182 b = Normalize (b);
1183+
1184+ PrimExpr scale;
1185+ if (ProdDivSimplify (&a, &b, &scale)) {
1186+ // use operator version here so it can const fold b == 1
1187+ return truncmod (a, b) * scale;
1188+ }
1189+
10911190 if (op->a .same_as (a) && op->b .same_as (b)) {
10921191 return GetRef<PrimExpr>(op);
10931192 } else {
@@ -1146,6 +1245,13 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorModNode* op) {
11461245 // normal path
11471246 a = Normalize (a);
11481247 b = Normalize (b);
1248+
1249+ PrimExpr scale;
1250+ if (ProdDivSimplify (&a, &b, &scale)) {
1251+ // use operator version here so it can const fold b == 1
1252+ return floormod (a, b) * scale;
1253+ }
1254+
11491255 if (op->a .same_as (a) && op->b .same_as (b)) {
11501256 return GetRef<PrimExpr>(op);
11511257 } else {
0 commit comments