@@ -240,10 +240,6 @@ class ConstIntBoundAnalyzer::Impl
240240 ret.min_value = InfAwareAdd (a.min_value , b.min_value );
241241 ret.max_value = InfAwareAdd (a.max_value , b.max_value );
242242
243- if (auto bound = BoundUsingReciprocal (GetRef<PrimExpr>(op))) {
244- ret = Intersect (ret, bound.value ());
245- }
246-
247243 return ret;
248244 }
249245
@@ -254,12 +250,6 @@ class ConstIntBoundAnalyzer::Impl
254250 ret.min_value = InfAwareAdd (a.min_value , -b.max_value );
255251 ret.max_value = InfAwareAdd (a.max_value , -b.min_value );
256252
257- if (auto bound = BoundUsingReciprocal (GetRef<Sub>(op))) {
258- ret = Intersect (ret, bound.value ());
259- }
260- if (auto bound = BoundUsingReciprocal (Sub (op->b , op->a ))) {
261- ret = Intersect (ret, Negative (bound.value ()));
262- }
263253 return ret;
264254 }
265255
@@ -775,164 +765,6 @@ class ConstIntBoundAnalyzer::Impl
775765 std::ceil (std::log2 (arg_bounds.max_value )));
776766 }
777767 }
778-
779- std::optional<Entry> BoundUsingReciprocal (PrimExpr expr) {
780- // Match expressions of the form `(A+B)*C - (A*B)*D`. Depending on
781- // previous simplifications, the exact form of the expression may vary.
782- auto opt_special_case = [&]() -> std::optional<std::tuple<Entry, Entry, Entry, Entry>> {
783- PVar<PrimExpr> A, B, C, D;
784-
785- if (PMatchesOneOf{
786- (A + B) * C - (A * B) * D,
787- (A + B) * C - (B * A) * D,
788- }
789- .Match (expr)) {
790- return std::tuple{VisitExpr (A.Eval ()), VisitExpr (B.Eval ()), VisitExpr (C.Eval ()),
791- VisitExpr (D.Eval ())};
792- } else if (PMatchesOneOf{
793- (A + B) * C - A * B,
794- (A + B) * C - B * A,
795- }
796- .Match (expr)) {
797- return std::tuple{VisitExpr (A.Eval ()), VisitExpr (B.Eval ()), VisitExpr (C.Eval ()),
798- MakeBound (1 , 1 )};
799- } else if (PMatchesOneOf{
800- (A * B) * D - (A + B) * C,
801- (B * A) * D - (A + B) * C,
802- }
803- .Match (expr)) {
804- return std::tuple{Negative (VisitExpr (A.Eval ())), Negative (VisitExpr (B.Eval ())),
805- Negative (VisitExpr (C.Eval ())), Negative (VisitExpr (D.Eval ()))};
806- } else if (PMatchesOneOf{
807- A * B - (A + B) * C,
808- B * A - (A + B) * C,
809- }
810- .Match (expr)) {
811- return std::tuple{Negative (VisitExpr (A.Eval ())), Negative (VisitExpr (B.Eval ())),
812- Negative (VisitExpr (C.Eval ())), MakeBound (-1 , -1 )};
813- } else if (PMatchesOneOf{
814- (A * B) * D + (A + B) * C,
815- (B * A) * D + (A + B) * C,
816- (A + B) * C + (A * B) * D,
817- (A + B) * C + (B * A) * D,
818- }
819- .Match (expr)) {
820- return std::tuple{Negative (VisitExpr (A.Eval ())), Negative (VisitExpr (B.Eval ())),
821- VisitExpr (C.Eval ()), Negative (VisitExpr (D.Eval ()))};
822- } else if (PMatchesOneOf{
823- (A * B) + (A + B) * C,
824- (B * A) + (A + B) * C,
825- (A + B) * C + (A * B),
826- (A + B) * C + (B * A),
827- }
828- .Match (expr)) {
829- return std::tuple{Negative (VisitExpr (A.Eval ())), Negative (VisitExpr (B.Eval ())),
830- VisitExpr (C.Eval ()), MakeBound (-1 , -1 )};
831- } else {
832- return std::nullopt ;
833- }
834- }();
835-
836- if (!opt_special_case.has_value ()) {
837- return std::nullopt ;
838- }
839- // Unpacking the tuple would be cleaner with a structured binding.
840- // However, until C++20, structured bindings cannot be captured for
841- // use in a lambda function.
842- auto A_bound = std::get<0 >(*opt_special_case);
843- auto B_bound = std::get<1 >(*opt_special_case);
844- auto C_bound = std::get<2 >(*opt_special_case);
845- auto D_bound = std::get<3 >(*opt_special_case);
846-
847- // If C and D have different signs, flip the signs of A/B/C so
848- // that C will match the sign of D.
849- if ((D_bound.max_value < 0 && C_bound.min_value > 0 ) ||
850- (D_bound.min_value > 0 && C_bound.max_value < 0 )) {
851- A_bound = Negative (A_bound);
852- B_bound = Negative (B_bound);
853- C_bound = Negative (C_bound);
854- }
855-
856- // If all terms are negative, then we'll be providing an upper bound
857- // rather than a lower bound. To avoid code duplication, flip all the
858- // signs here, find a lower bound, then flip the sign to produce the
859- // upper bound of the original expression.
860- bool all_terms_negative = (A_bound.max_value < 0 && B_bound.max_value < 0 &&
861- C_bound.max_value < 0 && D_bound.max_value < 0 );
862- if (all_terms_negative) {
863- A_bound = Negative (A_bound);
864- B_bound = Negative (B_bound);
865- C_bound = Negative (C_bound);
866- D_bound = Negative (D_bound);
867- }
868-
869- bool all_terms_positive = (A_bound.min_value > 0 && B_bound.min_value > 0 &&
870- C_bound.min_value > 0 && D_bound.min_value > 0 );
871- if (!all_terms_positive) {
872- return std::nullopt ;
873- }
874-
875- // (A + B) * C - (A * B) * D
876- // (A*B*C*D) * ( (A+B)/(A*B*D) - 1/C )
877- // (A*B*C*D) * ( (1/A + 1/B)/D - 1/C )
878- // (A*B*C*D) * (1/(A*D) + 1/(B*D) - 1/C)
879- //
880- // The constant (A*B*C*D) is positive, and its minimum value is the
881- // product of the minimum values of A, B, C, and D. If the reciprocal
882- // term (1/(A*D) + 1/(B*D) - 1/C) is positive, then this constant can
883- // be used to provide a lower bound on the expression.
884-
885- bool reciprocal_term_is_positive = [&]() {
886- if (D_bound.max_value == ConstIntBound::kPosInf ) {
887- // If D can grow without bound, the `1/(A*D)` and `1/(B*D)`
888- // terms will approach zero, at which point the `-1/C` term
889- // will determine the sign the sign.
890- return false ;
891- }
892-
893- if (std::min (A_bound.max_value , B_bound.max_value ) * D_bound.max_value <= C_bound.min_value ) {
894- // 1/(A*D) + 1/(B*D) - 1/C is positive if 1/C < 1/(A*D) + 1/(B*D).
895- // Since each term is positive, this condition can hold if either
896- // A*D <= C or B*D <= C.
897- return true ;
898- }
899- if (A_bound.max_value != ConstIntBound::kPosInf &&
900- B_bound.max_value != ConstIntBound::kPosInf ) {
901- // Even if neither term is sufficient on its own, if both A and B
902- // have known upper bounds, the inequality 1/C < 1/(A*D) + 1/(B*D)
903- // may still be provable.
904- //
905- // The maximum value of the LHS is found when C is minimized. The
906- // minimum value of the RHS is found when A, B, and D are
907- // maximized. If the condition holds in this case, then it holds
908- // in all cases.
909- //
910- // 1/C_min < 1/(A_max * D_max) + 1/(B_max*D_max)
911- // A_max*B_max*D_max < C_min*B_max + C_min*A_max
912- // A_max*B_max*D_max < C_min*(A_max + B_max)
913- //
914- if (A_bound.max_value * B_bound.max_value * D_bound.max_value <
915- C_bound.min_value * (A_bound.max_value + B_bound.max_value )) {
916- return true ;
917- }
918- }
919- return false ;
920- }();
921-
922- if (!reciprocal_term_is_positive) {
923- return std::nullopt ;
924- }
925-
926- auto ret = Everything (expr->dtype );
927- ret.min_value = A_bound.min_value * B_bound.min_value * C_bound.min_value * D_bound.min_value ;
928-
929- // If we flipped the sign of the original expression, flip the sign of
930- // the resulting set of possible values.
931- if (all_terms_negative) {
932- ret = Negative (ret);
933- }
934- return ret;
935- }
936768};
937769
938770ConstIntBound ConstIntBoundAnalyzer::operator ()(const PrimExpr& expr) const {
0 commit comments