Skip to content

Commit 95ec38b

Browse files
authored
[Arith] Provide tighter ConstIntBounds for special cases (#16588)
* [Arith] Provide tighter ConstIntBounds for special cases Expressions of the form `(A+B)*C < (A*B)*D` can occur when comparing the number of operations required for two different orderings in which matrix multiplications can be performed. Proving or disproving this conditional allows an optimal order of execution to be selected, even for dynamic argument shapes. The default behavior of `ConstIntBounds` assumes that each term in an expression is independent. For example, the maximum value of `(A+B)*C - (A*B)*D` is determined by taking the maximum value of `(A+B)*C` and subtracting the minimum value of `(A*B)*D`. This algorithm can be applied in all cases, but can provide a bound that is looser than strictly required. This commit adds a check for this case in `ConstIntBounds`, to provide a tighter bound of possible values. When `A`, `B`, `C`, and `D` are all positive values, as is the case for tensor shapes, the inequality can be written as `1/A + 1/B < D/C`. If this inequality holds for the minimum values of `A`, `B`, and `D`, along with the maximum value of `C`, then it holds for all values. * Parametrize ConstIntBound tests * Benchmark with/without the BoundUsingReciprocal function * Revert "Benchmark with/without the BoundUsingReciprocal function" This reverts commit 47a1fbd.
1 parent 5cbcaf4 commit 95ec38b

File tree

4 files changed

+434
-289
lines changed

4 files changed

+434
-289
lines changed

src/arith/const_int_bound.cc

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include <tvm/tir/expr_functor.h>
2727

2828
#include <algorithm>
29+
#include <optional>
2930

3031
#include "constraint_extract.h"
3132
#include "int_operator.h"
@@ -81,6 +82,16 @@ struct ConstIntBoundAnalyzer::Entry {
8182
bool operator==(const Entry& other) const {
8283
return min_value == other.min_value && max_value == other.max_value;
8384
}
85+
86+
friend std::ostream& operator<<(std::ostream& os, const Entry& entry) {
87+
os << "Entry[";
88+
PrintBoundValue(os, entry.min_value);
89+
os << ", ";
90+
PrintBoundValue(os, entry.max_value);
91+
os << "]";
92+
93+
return os;
94+
}
8495
};
8596

8697
class ConstIntBoundAnalyzer::Impl
@@ -228,6 +239,11 @@ class ConstIntBoundAnalyzer::Impl
228239
Entry ret;
229240
ret.min_value = InfAwareAdd(a.min_value, b.min_value);
230241
ret.max_value = InfAwareAdd(a.max_value, b.max_value);
242+
243+
if (auto bound = BoundUsingReciprocal(GetRef<PrimExpr>(op))) {
244+
ret = Intersect(ret, bound.value());
245+
}
246+
231247
return ret;
232248
}
233249

@@ -237,6 +253,13 @@ class ConstIntBoundAnalyzer::Impl
237253
Entry ret;
238254
ret.min_value = InfAwareAdd(a.min_value, -b.max_value);
239255
ret.max_value = InfAwareAdd(a.max_value, -b.min_value);
256+
257+
if (auto bound = BoundUsingReciprocal(GetRef<Sub>(op))) {
258+
ret = Intersect(ret, bound.value());
259+
}
260+
if (auto bound = BoundUsingReciprocal(Sub(op->b, op->a))) {
261+
ret = Intersect(ret, Negative(bound.value()));
262+
}
240263
return ret;
241264
}
242265

@@ -628,6 +651,25 @@ class ConstIntBoundAnalyzer::Impl
628651
ret.max_value = std::min(a.max_value, b.max_value);
629652
return ret;
630653
}
654+
/*!
655+
* \brief Flip the sign of a set.
656+
* \param entry The set of values
657+
*/
658+
static Entry Negative(Entry entry) {
659+
Entry ret;
660+
if (entry.max_value == kPosInf) {
661+
ret.min_value = kNegInf;
662+
} else {
663+
ret.min_value = -entry.max_value;
664+
}
665+
if (entry.min_value == kNegInf) {
666+
ret.max_value = kPosInf;
667+
} else {
668+
ret.max_value = -entry.min_value;
669+
}
670+
671+
return ret;
672+
}
631673
/*!
632674
* \brief return everything dtype can represent.
633675
* \param dtype The data type.
@@ -733,6 +775,164 @@ class ConstIntBoundAnalyzer::Impl
733775
std::ceil(std::log2(arg_bounds.max_value)));
734776
}
735777
}
778+
779+
/*!
 * \brief Tighter bound for expressions of the form `(A+B)*C - (A*B)*D`.
 *
 * The default analysis treats `(A+B)*C` and `(A*B)*D` as independent,
 * which loses the correlation through `A` and `B`.  When the operands
 * have known, compatible signs, the expression can be written as
 * `+/- f(a, b, c, d)` with `f(a, b, c, d) = (a+b)*c - (a*b)*d` and
 * `a, b, c, d` all strictly positive.  If `f` can be proven strictly
 * positive, a one-sided bound is returned.
 *
 * \param expr The expression to analyze.
 *
 * \return A one-sided bound on `expr` (the other side is whatever
 * `Everything(expr->dtype)` provides), or `std::nullopt` if the
 * expression does not match one of the recognized forms, the operand
 * signs are unknown or incompatible, or positivity of `f` cannot be
 * proven without overflowing 64-bit arithmetic.
 */
std::optional<Entry> BoundUsingReciprocal(PrimExpr expr) {
  // Match expressions of the form `(A+B)*C - (A*B)*D`.  Depending on
  // previous simplifications, the exact form of the expression may vary.
  //
  // Every match is normalized such that
  //
  //   expr == (negated ? -1 : +1) * ((A+B)*C - (A*B)*D)
  //
  // holds as an algebraic identity, where `negated` is the first
  // element of the tuple and the bounds of A/B/C/D follow.
  auto opt_special_case = [&]() -> std::optional<std::tuple<bool, Entry, Entry, Entry, Entry>> {
    PVar<PrimExpr> A, B, C, D;

    if (PMatchesOneOf{
            (A + B) * C - (A * B) * D,
            (A + B) * C - (B * A) * D,
        }
            .Match(expr)) {
      return std::tuple{false, VisitExpr(A.Eval()), VisitExpr(B.Eval()), VisitExpr(C.Eval()),
                        VisitExpr(D.Eval())};
    } else if (PMatchesOneOf{
                   (A + B) * C - A * B,
                   (A + B) * C - B * A,
               }
                   .Match(expr)) {
      return std::tuple{false, VisitExpr(A.Eval()), VisitExpr(B.Eval()), VisitExpr(C.Eval()),
                        MakeBound(1, 1)};
    } else if (PMatchesOneOf{
                   (A * B) * D - (A + B) * C,
                   (B * A) * D - (A + B) * C,
               }
                   .Match(expr)) {
      // The operands are swapped, so the whole expression is negated.
      return std::tuple{true, VisitExpr(A.Eval()), VisitExpr(B.Eval()), VisitExpr(C.Eval()),
                        VisitExpr(D.Eval())};
    } else if (PMatchesOneOf{
                   A * B - (A + B) * C,
                   B * A - (A + B) * C,
               }
                   .Match(expr)) {
      return std::tuple{true, VisitExpr(A.Eval()), VisitExpr(B.Eval()), VisitExpr(C.Eval()),
                        MakeBound(1, 1)};
    } else if (PMatchesOneOf{
                   (A * B) * D + (A + B) * C,
                   (B * A) * D + (A + B) * C,
                   (A + B) * C + (A * B) * D,
                   (A + B) * C + (B * A) * D,
               }
                   .Match(expr)) {
      // (A+B)*C + (A*B)*D == (A+B)*C - (A*B)*(-D)
      return std::tuple{false, VisitExpr(A.Eval()), VisitExpr(B.Eval()), VisitExpr(C.Eval()),
                        Negative(VisitExpr(D.Eval()))};
    } else if (PMatchesOneOf{
                   (A * B) + (A + B) * C,
                   (B * A) + (A + B) * C,
                   (A + B) * C + (A * B),
                   (A + B) * C + (B * A),
               }
                   .Match(expr)) {
      // (A+B)*C + A*B == (A+B)*C - (A*B)*(-1)
      return std::tuple{false, VisitExpr(A.Eval()), VisitExpr(B.Eval()), VisitExpr(C.Eval()),
                        MakeBound(-1, -1)};
    } else {
      return std::nullopt;
    }
  }();

  if (!opt_special_case.has_value()) {
    return std::nullopt;
  }
  // Unpacking the tuple would be cleaner with a structured binding.
  // However, until C++20, structured bindings cannot be captured for
  // use in a lambda function.
  bool negated = std::get<0>(*opt_special_case);
  auto A_bound = std::get<1>(*opt_special_case);
  auto B_bound = std::get<2>(*opt_special_case);
  auto C_bound = std::get<3>(*opt_special_case);
  auto D_bound = std::get<4>(*opt_special_case);

  // Classify a bound as strictly positive (+1), strictly negative (-1),
  // or of unknown sign (0).
  auto sign_of = [](const Entry& bound) -> int {
    if (bound.min_value > 0) return 1;
    if (bound.max_value < 0) return -1;
    return 0;
  };
  const int sign_AB = sign_of(A_bound);
  const int sign_C = sign_of(C_bound);
  const int sign_D = sign_of(D_bound);

  // A and B must share a known sign (so that A*B > 0 and
  // |A+B| == |A|+|B|), and the signs of C and D must be known.
  if (sign_AB == 0 || sign_of(B_bound) != sign_AB || sign_C == 0 || sign_D == 0) {
    return std::nullopt;
  }

  // `(A+B)*C` has sign `sign_AB*sign_C`, while `(A*B)*D` has sign
  // `sign_D`.  Only when the two terms share a sign is their difference
  // a genuine difference of magnitudes:
  //
  //   (A+B)*C - (A*B)*D == sign_D * ((|A|+|B|)*|C| - |A|*|B|*|D|)
  //
  // Otherwise the expression is a sum of same-signed magnitudes, and
  // this special case does not apply.
  if (sign_AB * sign_C != sign_D) {
    return std::nullopt;
  }
  if (sign_D < 0) {
    negated = !negated;
  }

  // Replace each operand with its absolute value.
  if (sign_AB < 0) {
    A_bound = Negative(A_bound);
    B_bound = Negative(B_bound);
  }
  if (sign_C < 0) {
    C_bound = Negative(C_bound);
  }
  if (sign_D < 0) {
    D_bound = Negative(D_bound);
  }

  // From here on, A/B/C/D are all strictly positive and
  //
  //   expr == (negated ? -1 : +1) * f,    f = (A+B)*C - (A*B)*D
  //
  //   f == (A*B*C*D) * (1/(A*D) + 1/(B*D) - 1/C)
  //
  // The factor (A*B*C*D) is positive, so `f` is positive exactly when
  // the reciprocal term (1/(A*D) + 1/(B*D) - 1/C) is positive.  The
  // reciprocal term may be smaller than one, so the product of the
  // minimum values is NOT a valid lower bound of `f`.

  // Overflow-checked arithmetic for strictly positive operands.  These
  // conservatively return nullopt if either operand is unbounded, or if
  // the result would reach the kPosInf sentinel.
  auto checked_mul = [](int64_t lhs, int64_t rhs) -> std::optional<int64_t> {
    if (lhs >= ConstIntBound::kPosInf / rhs) {
      return std::nullopt;
    }
    return lhs * rhs;
  };
  auto checked_add = [](int64_t lhs, int64_t rhs) -> std::optional<int64_t> {
    if (lhs >= ConstIntBound::kPosInf - rhs) {
      return std::nullopt;
    }
    return lhs + rhs;
  };

  // A proven lower bound of `f`, or nullopt if `f` cannot be proven to
  // be strictly positive.
  std::optional<int64_t> lower_bound = [&]() -> std::optional<int64_t> {
    if (D_bound.max_value == ConstIntBound::kPosInf) {
      // If D can grow without bound, the `1/(A*D)` and `1/(B*D)`
      // terms will approach zero, at which point the `-1/C` term
      // will determine the sign.
      return std::nullopt;
    }

    // 1/(A*D) + 1/(B*D) - 1/C is positive if 1/C < 1/(A*D) + 1/(B*D).
    // Since each term is positive, this holds if either A*D <= C or
    // B*D <= C.  When X*D <= C for X in {A, B}, with Y the other one,
    //
    //   f == X*C + Y*(C - X*D) >= X*C >= X_min * C_min
    std::optional<int64_t> best;
    auto try_single_term = [&](const Entry& x_bound) {
      auto xd_max = checked_mul(x_bound.max_value, D_bound.max_value);
      if (!xd_max.has_value() || xd_max.value() > C_bound.min_value) {
        return;
      }
      // If the product cannot be represented, fall back to the bound
      // `f >= 1`, which holds for any positive integer.
      int64_t bound = checked_mul(x_bound.min_value, C_bound.min_value).value_or(1);
      best = std::max(best.value_or(bound), bound);
    };
    try_single_term(A_bound);
    try_single_term(B_bound);
    if (best.has_value()) {
      return best;
    }

    // Even if neither term is sufficient on its own, if both A and B
    // have known upper bounds, the inequality 1/C < 1/(A*D) + 1/(B*D)
    // may still be provable.
    //
    // The maximum value of the LHS is found when C is minimized.  The
    // minimum value of the RHS is found when A, B, and D are
    // maximized.  If the condition holds in this case, then it holds
    // in all cases.
    //
    //   1/C_min < 1/(A_max*D_max) + 1/(B_max*D_max)
    //   A_max*B_max*D_max < C_min*B_max + C_min*A_max
    //   A_max*B_max*D_max < C_min*(A_max + B_max)
    //
    // Unbounded A or B, and any intermediate overflow, are rejected by
    // the checked arithmetic.
    auto ab_max = checked_mul(A_bound.max_value, B_bound.max_value);
    if (!ab_max.has_value()) {
      return std::nullopt;
    }
    auto lhs = checked_mul(ab_max.value(), D_bound.max_value);
    auto sum_max = checked_add(A_bound.max_value, B_bound.max_value);
    if (!lhs.has_value() || !sum_max.has_value()) {
      return std::nullopt;
    }
    auto rhs = checked_mul(C_bound.min_value, sum_max.value());
    if (rhs.has_value() && lhs.value() < rhs.value()) {
      // `f` is a strictly positive integer, so it is at least one.
      return 1;
    }
    return std::nullopt;
  }();

  if (!lower_bound.has_value()) {
    return std::nullopt;
  }

  auto ret = Everything(expr->dtype);
  ret.min_value = lower_bound.value();

  // If the original expression is the negation of `f`, flip the sign
  // of the resulting set of possible values.
  if (negated) {
    ret = Negative(ret);
  }
  return ret;
}
736936
};
737937

738938
ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr) const {

src/arith/rewrite_simplify.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1768,6 +1768,17 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(LT ret) {
17681768
if (merge_constants) {
17691769
return RecursiveRewrite(merge_constants.value());
17701770
}
1771+
1772+
auto common_factor = [&]() -> int64_t {
1773+
auto modular_a = analyzer_->modular_set(ret->a);
1774+
auto modular_b = analyzer_->modular_set(ret->b);
1775+
auto gcd_lhs = ZeroAwareGCD(modular_a->base, modular_a->coeff);
1776+
auto gcd_rhs = ZeroAwareGCD(modular_b->base, modular_b->coeff);
1777+
return ZeroAwareGCD(gcd_lhs, gcd_rhs);
1778+
}();
1779+
if (common_factor > 1) {
1780+
return RecursiveRewrite(floordiv(ret->a, common_factor) < floordiv(ret->b, common_factor));
1781+
}
17711782
}
17721783
return std::move(ret);
17731784
}

0 commit comments

Comments
 (0)