2626#include < tvm/tir/expr_functor.h>
2727
2828#include < algorithm>
29+ #include < optional>
2930
3031#include " constraint_extract.h"
3132#include " int_operator.h"
@@ -81,6 +82,16 @@ struct ConstIntBoundAnalyzer::Entry {
8182 bool operator ==(const Entry& other) const {
8283 return min_value == other.min_value && max_value == other.max_value ;
8384 }
85+
86+ friend std::ostream& operator <<(std::ostream& os, const Entry& entry) {
87+ os << " Entry[" ;
88+ PrintBoundValue (os, entry.min_value );
89+ os << " , " ;
90+ PrintBoundValue (os, entry.max_value );
91+ os << " ]" ;
92+
93+ return os;
94+ }
8495};
8596
8697class ConstIntBoundAnalyzer ::Impl
@@ -228,6 +239,11 @@ class ConstIntBoundAnalyzer::Impl
228239 Entry ret;
229240 ret.min_value = InfAwareAdd (a.min_value , b.min_value );
230241 ret.max_value = InfAwareAdd (a.max_value , b.max_value );
242+
243+ if (auto bound = BoundUsingReciprocal (GetRef<PrimExpr>(op))) {
244+ ret = Intersect (ret, bound.value ());
245+ }
246+
231247 return ret;
232248 }
233249
@@ -237,6 +253,13 @@ class ConstIntBoundAnalyzer::Impl
237253 Entry ret;
238254 ret.min_value = InfAwareAdd (a.min_value , -b.max_value );
239255 ret.max_value = InfAwareAdd (a.max_value , -b.min_value );
256+
257+ if (auto bound = BoundUsingReciprocal (GetRef<Sub>(op))) {
258+ ret = Intersect (ret, bound.value ());
259+ }
260+ if (auto bound = BoundUsingReciprocal (Sub (op->b , op->a ))) {
261+ ret = Intersect (ret, Negative (bound.value ()));
262+ }
240263 return ret;
241264 }
242265
@@ -628,6 +651,25 @@ class ConstIntBoundAnalyzer::Impl
628651 ret.max_value = std::min (a.max_value , b.max_value );
629652 return ret;
630653 }
654+ /* !
655+ * \brief Flip the sign of a set.
656+ * \param entry The set of values
657+ */
658+ static Entry Negative (Entry entry) {
659+ Entry ret;
660+ if (entry.max_value == kPosInf ) {
661+ ret.min_value = kNegInf ;
662+ } else {
663+ ret.min_value = -entry.max_value ;
664+ }
665+ if (entry.min_value == kNegInf ) {
666+ ret.max_value = kPosInf ;
667+ } else {
668+ ret.max_value = -entry.min_value ;
669+ }
670+
671+ return ret;
672+ }
631673 /* !
632674 * \brief return everything dtype can represent.
633675 * \param dtype The data type.
@@ -733,6 +775,164 @@ class ConstIntBoundAnalyzer::Impl
733775 std::ceil (std::log2 (arg_bounds.max_value )));
734776 }
735777 }
778+
779+ std::optional<Entry> BoundUsingReciprocal (PrimExpr expr) {
780+ // Match expressions of the form `(A+B)*C - (A*B)*D`. Depending on
781+ // previous simplifications, the exact form of the expression may vary.
782+ auto opt_special_case = [&]() -> std::optional<std::tuple<Entry, Entry, Entry, Entry>> {
783+ PVar<PrimExpr> A, B, C, D;
784+
785+ if (PMatchesOneOf{
786+ (A + B) * C - (A * B) * D,
787+ (A + B) * C - (B * A) * D,
788+ }
789+ .Match (expr)) {
790+ return std::tuple{VisitExpr (A.Eval ()), VisitExpr (B.Eval ()), VisitExpr (C.Eval ()),
791+ VisitExpr (D.Eval ())};
792+ } else if (PMatchesOneOf{
793+ (A + B) * C - A * B,
794+ (A + B) * C - B * A,
795+ }
796+ .Match (expr)) {
797+ return std::tuple{VisitExpr (A.Eval ()), VisitExpr (B.Eval ()), VisitExpr (C.Eval ()),
798+ MakeBound (1 , 1 )};
799+ } else if (PMatchesOneOf{
800+ (A * B) * D - (A + B) * C,
801+ (B * A) * D - (A + B) * C,
802+ }
803+ .Match (expr)) {
804+ return std::tuple{Negative (VisitExpr (A.Eval ())), Negative (VisitExpr (B.Eval ())),
805+ Negative (VisitExpr (C.Eval ())), Negative (VisitExpr (D.Eval ()))};
806+ } else if (PMatchesOneOf{
807+ A * B - (A + B) * C,
808+ B * A - (A + B) * C,
809+ }
810+ .Match (expr)) {
811+ return std::tuple{Negative (VisitExpr (A.Eval ())), Negative (VisitExpr (B.Eval ())),
812+ Negative (VisitExpr (C.Eval ())), MakeBound (-1 , -1 )};
813+ } else if (PMatchesOneOf{
814+ (A * B) * D + (A + B) * C,
815+ (B * A) * D + (A + B) * C,
816+ (A + B) * C + (A * B) * D,
817+ (A + B) * C + (B * A) * D,
818+ }
819+ .Match (expr)) {
820+ return std::tuple{Negative (VisitExpr (A.Eval ())), Negative (VisitExpr (B.Eval ())),
821+ VisitExpr (C.Eval ()), Negative (VisitExpr (D.Eval ()))};
822+ } else if (PMatchesOneOf{
823+ (A * B) + (A + B) * C,
824+ (B * A) + (A + B) * C,
825+ (A + B) * C + (A * B),
826+ (A + B) * C + (B * A),
827+ }
828+ .Match (expr)) {
829+ return std::tuple{Negative (VisitExpr (A.Eval ())), Negative (VisitExpr (B.Eval ())),
830+ VisitExpr (C.Eval ()), MakeBound (-1 , -1 )};
831+ } else {
832+ return std::nullopt ;
833+ }
834+ }();
835+
836+ if (!opt_special_case.has_value ()) {
837+ return std::nullopt ;
838+ }
839+ // Unpacking the tuple would be cleaner with a structured binding.
840+ // However, until C++20, structured bindings cannot be captured for
841+ // use in a lambda function.
842+ auto A_bound = std::get<0 >(*opt_special_case);
843+ auto B_bound = std::get<1 >(*opt_special_case);
844+ auto C_bound = std::get<2 >(*opt_special_case);
845+ auto D_bound = std::get<3 >(*opt_special_case);
846+
847+ // If C and D have different signs, flip the signs of A/B/C so
848+ // that C will match the sign of D.
849+ if ((D_bound.max_value < 0 && C_bound.min_value > 0 ) ||
850+ (D_bound.min_value > 0 && C_bound.max_value < 0 )) {
851+ A_bound = Negative (A_bound);
852+ B_bound = Negative (B_bound);
853+ C_bound = Negative (C_bound);
854+ }
855+
856+ // If all terms are negative, then we'll be providing an upper bound
857+ // rather than a lower bound. To avoid code duplication, flip all the
858+ // signs here, find a lower bound, then flip the sign to produce the
859+ // upper bound of the original expression.
860+ bool all_terms_negative = (A_bound.max_value < 0 && B_bound.max_value < 0 &&
861+ C_bound.max_value < 0 && D_bound.max_value < 0 );
862+ if (all_terms_negative) {
863+ A_bound = Negative (A_bound);
864+ B_bound = Negative (B_bound);
865+ C_bound = Negative (C_bound);
866+ D_bound = Negative (D_bound);
867+ }
868+
869+ bool all_terms_positive = (A_bound.min_value > 0 && B_bound.min_value > 0 &&
870+ C_bound.min_value > 0 && D_bound.min_value > 0 );
871+ if (!all_terms_positive) {
872+ return std::nullopt ;
873+ }
874+
875+ // (A + B) * C - (A * B) * D
876+ // (A*B*C*D) * ( (A+B)/(A*B*D) - 1/C )
877+ // (A*B*C*D) * ( (1/A + 1/B)/D - 1/C )
878+ // (A*B*C*D) * (1/(A*D) + 1/(B*D) - 1/C)
879+ //
880+ // The constant (A*B*C*D) is positive, and its minimum value is the
881+ // product of the minimum values of A, B, C, and D. If the reciprocal
882+ // term (1/(A*D) + 1/(B*D) - 1/C) is positive, then this constant can
883+ // be used to provide a lower bound on the expression.
884+
885+ bool reciprocal_term_is_positive = [&]() {
886+ if (D_bound.max_value == ConstIntBound::kPosInf ) {
887+ // If D can grow without bound, the `1/(A*D)` and `1/(B*D)`
888+ // terms will approach zero, at which point the `-1/C` term
889+ // will determine the sign the sign.
890+ return false ;
891+ }
892+
893+ if (std::min (A_bound.max_value , B_bound.max_value ) * D_bound.max_value <= C_bound.min_value ) {
894+ // 1/(A*D) + 1/(B*D) - 1/C is positive if 1/C < 1/(A*D) + 1/(B*D).
895+ // Since each term is positive, this condition can hold if either
896+ // A*D <= C or B*D <= C.
897+ return true ;
898+ }
899+ if (A_bound.max_value != ConstIntBound::kPosInf &&
900+ B_bound.max_value != ConstIntBound::kPosInf ) {
901+ // Even if neither term is sufficient on its own, if both A and B
902+ // have known upper bounds, the inequality 1/C < 1/(A*D) + 1/(B*D)
903+ // may still be provable.
904+ //
905+ // The maximum value of the LHS is found when C is minimized. The
906+ // minimum value of the RHS is found when A, B, and D are
907+ // maximized. If the condition holds in this case, then it holds
908+ // in all cases.
909+ //
910+ // 1/C_min < 1/(A_max * D_max) + 1/(B_max*D_max)
911+ // A_max*B_max*D_max < C_min*B_max + C_min*A_max
912+ // A_max*B_max*D_max < C_min*(A_max + B_max)
913+ //
914+ if (A_bound.max_value * B_bound.max_value * D_bound.max_value <
915+ C_bound.min_value * (A_bound.max_value + B_bound.max_value )) {
916+ return true ;
917+ }
918+ }
919+ return false ;
920+ }();
921+
922+ if (!reciprocal_term_is_positive) {
923+ return std::nullopt ;
924+ }
925+
926+ auto ret = Everything (expr->dtype );
927+ ret.min_value = A_bound.min_value * B_bound.min_value * C_bound.min_value * D_bound.min_value ;
928+
929+ // If we flipped the sign of the original expression, flip the sign of
930+ // the resulting set of possible values.
931+ if (all_terms_negative) {
932+ ret = Negative (ret);
933+ }
934+ return ret;
935+ }
736936};
737937
738938ConstIntBound ConstIntBoundAnalyzer::operator ()(const PrimExpr& expr) const {
0 commit comments