@@ -75,13 +75,15 @@ inline std::vector<const PrimExpr*> ExprSplitAddition(const PrimExpr& expr) {
7575}
7676
7777// Searches for the following types of expr:
78- // mult_expr = (a1 + a2 + ... + aj + c / (k1 * k2 * ... * ki) * k1 * ... * kt-1 ) * kt * ... * ki
79- // mod_l_expr = c
78+ // mult_expr = (a1 + a2 + ... + aj + c1 / (k1 * k2 * ... * ki) * k1 * ... * kt-1 ) * kt * ... * ki
79+ // mod_l_expr = c2
8080// mod_r_expr = k1 * k2 * ... * ki
81- // If it can be optimized, returns (true, (a1 + a2 + ... + aj) * kt * ... * ki + c)
81+ // where c1 ~= c2 mod k1 * k2 * ... * ki
82+ // If it can be optimized, returns (true, (a1 + a2 + ... + aj) * kt * ... * ki + c1)
8283// Currently the we will not search the add/mult combinations exhaustively
8384// as it will take too much computation.
84- inline std::pair<bool , PrimExpr> MergeMulModInner (const PrimExpr& mult_expr,
85+ inline std::pair<bool , PrimExpr> MergeMulModInner (arith::Analyzer* analyzer,
86+ const PrimExpr& mult_expr,
8587 const PrimExpr& mod_l_expr,
8688 const PrimExpr& mod_r_expr) {
8789 using namespace tir ;
@@ -119,9 +121,10 @@ inline std::pair<bool, PrimExpr> MergeMulModInner(const PrimExpr& mult_expr,
119121 } else if (inner_div_ptr) {
120122 PrimExpr overall_mult = mult_inner.get () ? mult_inner * mult_outer : mult_outer;
121123 if (expr_equal (overall_mult, inner_div_ptr->b ) && expr_equal (overall_mult, mod_r_expr) &&
122- expr_equal ( inner_div_ptr->a , mod_l_expr)) {
124+ analyzer-> CanProveEqual ( floormod ( inner_div_ptr->a - mod_l_expr, mod_r_expr), 0 )) {
123125 // Found!
124- PrimExpr ret = no_opt_sum.get () ? no_opt_sum * mult_outer + mod_l_expr : mod_l_expr;
126+ PrimExpr ret =
127+ no_opt_sum.get () ? no_opt_sum * mult_outer + inner_div_ptr->a : inner_div_ptr->a ;
125128 return std::make_pair (true , ret);
126129 } else {
127130 return std::make_pair (false , PrimExpr ());
@@ -204,7 +207,7 @@ inline PrimExpr MergeMulMod(arith::Analyzer* analyzer, const PrimExpr& base) {
204207 bool inner_find_opt = false ;
205208 while (mult_it != mult_exprs.end ()) {
206209 std::pair<bool , PrimExpr> ret =
207- MergeMulModInner (*mult_it, search_mod_it->first , search_mod_it->second );
210+ MergeMulModInner (analyzer, *mult_it, search_mod_it->first , search_mod_it->second );
208211 if (ret.first ) {
209212 inner_find_opt = true ;
210213 auto temp_mod_it = search_mod_it;
0 commit comments