Skip to content

Commit eaa2e71

Browse files
committed
Move the pattern match to RewriteSimplifier, with feature flag
1 parent 0cc77e2 commit eaa2e71

File tree

9 files changed

+241
-173
lines changed

9 files changed

+241
-173
lines changed

include/tvm/arith/analyzer.h

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,35 @@ class RewriteSimplifier {
334334
* (n < 10) || (n < 5) => (n < 5)
335335
*/
336336
kApplyConstraintsToBooleanBranches = (1 << 2),
337+
338+
/* Special handling for expressions `(A+B)*C < (A*B)*D`
339+
*
340+
* Expressions of the form `(A+B)*C < (A*B)*D` can occur occur
341+
* when comparing the number of operations required for two
342+
* different orderings in which matrix multiplications can be
343+
* performed. Proving or disproving this conditional allows an
344+
* optimal order of execution to be selected, even for dynamic
345+
* argument shapes.
346+
*
347+
* The default behavior of `ConstIntBounds` assumes that each term
348+
* in an expression is independent, and is insufficient to prove
349+
* these inequalities. For example, the maximum value of `(A+B)*C
350+
* - (A*B)*D` is determined by taking the maximum value of
351+
* `(A+B)*C` and subtracting the minimum value of `(A*B)*D`.
352+
* While this algorithm can be applied in all cases, the bound it
353+
* provides is looser than strictly required.
354+
*
355+
* This extension adds a check for this case. When `A`, `B`, `C`,
356+
* and `D` are all positive values, as is the case for tensor
357+
* shapes, the inequality can be written as `1/A + 1/B < D/C`. If
358+
* this inequality holds for the minimum values of `A`, `B`, and
359+
* `D`, along with the maximum value of `C`, then the inequality
360+
* holds for all values.
361+
*
362+
* This extension requires little to no performance overhead, and
363+
* may be enabled by default in future releases.
364+
*/
365+
kComparisonOfProductAndSum = (1 << 3),
337366
};
338367

339368
/*! \brief Enable an optional extension or extensions

python/tvm/arith/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
estimate_region_strict_bound,
2525
estimate_region_upper_bound,
2626
)
27-
from .analyzer import ModularSet, ConstIntBound, Analyzer, ProofStrength
27+
from .analyzer import ModularSet, ConstIntBound, Analyzer, ProofStrength, Extension
2828
from .bound import deduce_bound
2929
from .pattern import detect_linear_equation, detect_clip_bound, detect_common_subexpr
3030
from .int_solver import solve_linear_equations, solve_linear_inequalities

python/tvm/arith/analyzer.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
# under the License.
1717
# pylint: disable=invalid-name
1818
"""Arithmetic data structure and utility"""
19-
from enum import IntEnum
19+
import enum
2020
from typing import Union
2121

2222
import tvm._ffi
@@ -26,13 +26,26 @@
2626
from . import _ffi_api
2727

2828

29-
class ProofStrength(IntEnum):
29+
class ProofStrength(enum.IntEnum):
3030
"""Proof strength of the analysis"""
3131

3232
DEFAULT = 0
3333
SYMBOLIC_BOUND = 1
3434

3535

36+
class Extension(enum.Flag):
37+
"""Extensions enabled for RewriteSimplifier
38+
39+
Values should match `RewriteSimplifier::Extensions`
40+
"""
41+
42+
NoExtensions = 0
43+
TransitivelyProveInequalities = 1 << 0
44+
ConvertBooleanToAndOfOrs = 1 << 1
45+
ApplyConstraintsToBooleanBranches = 1 << 2
46+
ComparisonOfProductAndSum = 1 << 3
47+
48+
3649
@tvm._ffi.register_object("arith.ModularSet")
3750
class ModularSet(Object):
3851
"""Represent range of (coeff * x + base) for x in Z"""
@@ -107,6 +120,8 @@ def __init__(self):
107120
self._enter_constraint_context = _mod("enter_constraint_context")
108121
self._can_prove_equal = _mod("can_prove_equal")
109122
self._can_prove = _mod("can_prove")
123+
self._get_enabled_extensions = _mod("get_enabled_extensions")
124+
self._set_enabled_extensions = _mod("set_enabled_extensions")
110125

111126
def const_int_bound(self, expr):
112127
"""Find constant integer bound for expr.
@@ -311,3 +326,22 @@ def can_prove_equal(self, lhs: "PrimExpr", rhs: "PrimExpr"):
311326
Whether we can prove that lhs == rhs
312327
"""
313328
return self._can_prove_equal(lhs, rhs)
329+
330+
@property
331+
def enabled_extensions(self) -> Extension:
332+
"""Return the currently enabled extensions"""
333+
value = self._get_enabled_extensions()
334+
return Extension(value)
335+
336+
@enabled_extensions.setter
337+
def enabled_extensions(self, flags: Union[int, Extension]):
338+
"""Enable extensions for the analyzer
339+
340+
Parameters
341+
----------
342+
flags: Union[int,Extension]
343+
344+
The extensions to enable.
345+
"""
346+
flags = Extension(flags).value
347+
self._set_enabled_extensions(flags)

src/arith/analyzer.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,16 @@ TVM_REGISTER_GLOBAL("arith.CreateAnalyzer").set_body([](TVMArgs args, TVMRetValu
317317
} else if (name == "can_prove_equal") {
318318
return PackedFunc(
319319
[self](TVMArgs args, TVMRetValue* ret) { *ret = self->CanProveEqual(args[0], args[1]); });
320+
} else if (name == "get_enabled_extensions") {
321+
return PackedFunc([self](TVMArgs args, TVMRetValue* ret) {
322+
*ret = static_cast<std::int64_t>(self->rewrite_simplify.GetEnabledExtensions());
323+
});
324+
} else if (name == "set_enabled_extensions") {
325+
return PackedFunc([self](TVMArgs args, TVMRetValue* ret) {
326+
std::int64_t flags = args[0];
327+
self->rewrite_simplify.SetEnabledExtensions(
328+
static_cast<RewriteSimplifier::Extension>(flags));
329+
});
320330
}
321331
return PackedFunc();
322332
};

src/arith/const_int_bound.cc

Lines changed: 0 additions & 168 deletions
Original file line numberDiff line numberDiff line change
@@ -240,10 +240,6 @@ class ConstIntBoundAnalyzer::Impl
240240
ret.min_value = InfAwareAdd(a.min_value, b.min_value);
241241
ret.max_value = InfAwareAdd(a.max_value, b.max_value);
242242

243-
if (auto bound = BoundUsingReciprocal(GetRef<PrimExpr>(op))) {
244-
ret = Intersect(ret, bound.value());
245-
}
246-
247243
return ret;
248244
}
249245

@@ -254,12 +250,6 @@ class ConstIntBoundAnalyzer::Impl
254250
ret.min_value = InfAwareAdd(a.min_value, -b.max_value);
255251
ret.max_value = InfAwareAdd(a.max_value, -b.min_value);
256252

257-
if (auto bound = BoundUsingReciprocal(GetRef<Sub>(op))) {
258-
ret = Intersect(ret, bound.value());
259-
}
260-
if (auto bound = BoundUsingReciprocal(Sub(op->b, op->a))) {
261-
ret = Intersect(ret, Negative(bound.value()));
262-
}
263253
return ret;
264254
}
265255

@@ -775,164 +765,6 @@ class ConstIntBoundAnalyzer::Impl
775765
std::ceil(std::log2(arg_bounds.max_value)));
776766
}
777767
}
778-
779-
std::optional<Entry> BoundUsingReciprocal(PrimExpr expr) {
780-
// Match expressions of the form `(A+B)*C - (A*B)*D`. Depending on
781-
// previous simplifications, the exact form of the expression may vary.
782-
auto opt_special_case = [&]() -> std::optional<std::tuple<Entry, Entry, Entry, Entry>> {
783-
PVar<PrimExpr> A, B, C, D;
784-
785-
if (PMatchesOneOf{
786-
(A + B) * C - (A * B) * D,
787-
(A + B) * C - (B * A) * D,
788-
}
789-
.Match(expr)) {
790-
return std::tuple{VisitExpr(A.Eval()), VisitExpr(B.Eval()), VisitExpr(C.Eval()),
791-
VisitExpr(D.Eval())};
792-
} else if (PMatchesOneOf{
793-
(A + B) * C - A * B,
794-
(A + B) * C - B * A,
795-
}
796-
.Match(expr)) {
797-
return std::tuple{VisitExpr(A.Eval()), VisitExpr(B.Eval()), VisitExpr(C.Eval()),
798-
MakeBound(1, 1)};
799-
} else if (PMatchesOneOf{
800-
(A * B) * D - (A + B) * C,
801-
(B * A) * D - (A + B) * C,
802-
}
803-
.Match(expr)) {
804-
return std::tuple{Negative(VisitExpr(A.Eval())), Negative(VisitExpr(B.Eval())),
805-
Negative(VisitExpr(C.Eval())), Negative(VisitExpr(D.Eval()))};
806-
} else if (PMatchesOneOf{
807-
A * B - (A + B) * C,
808-
B * A - (A + B) * C,
809-
}
810-
.Match(expr)) {
811-
return std::tuple{Negative(VisitExpr(A.Eval())), Negative(VisitExpr(B.Eval())),
812-
Negative(VisitExpr(C.Eval())), MakeBound(-1, -1)};
813-
} else if (PMatchesOneOf{
814-
(A * B) * D + (A + B) * C,
815-
(B * A) * D + (A + B) * C,
816-
(A + B) * C + (A * B) * D,
817-
(A + B) * C + (B * A) * D,
818-
}
819-
.Match(expr)) {
820-
return std::tuple{Negative(VisitExpr(A.Eval())), Negative(VisitExpr(B.Eval())),
821-
VisitExpr(C.Eval()), Negative(VisitExpr(D.Eval()))};
822-
} else if (PMatchesOneOf{
823-
(A * B) + (A + B) * C,
824-
(B * A) + (A + B) * C,
825-
(A + B) * C + (A * B),
826-
(A + B) * C + (B * A),
827-
}
828-
.Match(expr)) {
829-
return std::tuple{Negative(VisitExpr(A.Eval())), Negative(VisitExpr(B.Eval())),
830-
VisitExpr(C.Eval()), MakeBound(-1, -1)};
831-
} else {
832-
return std::nullopt;
833-
}
834-
}();
835-
836-
if (!opt_special_case.has_value()) {
837-
return std::nullopt;
838-
}
839-
// Unpacking the tuple would be cleaner with a structured binding.
840-
// However, until C++20, structured bindings cannot be captured for
841-
// use in a lambda function.
842-
auto A_bound = std::get<0>(*opt_special_case);
843-
auto B_bound = std::get<1>(*opt_special_case);
844-
auto C_bound = std::get<2>(*opt_special_case);
845-
auto D_bound = std::get<3>(*opt_special_case);
846-
847-
// If C and D have different signs, flip the signs of A/B/C so
848-
// that C will match the sign of D.
849-
if ((D_bound.max_value < 0 && C_bound.min_value > 0) ||
850-
(D_bound.min_value > 0 && C_bound.max_value < 0)) {
851-
A_bound = Negative(A_bound);
852-
B_bound = Negative(B_bound);
853-
C_bound = Negative(C_bound);
854-
}
855-
856-
// If all terms are negative, then we'll be providing an upper bound
857-
// rather than a lower bound. To avoid code duplication, flip all the
858-
// signs here, find a lower bound, then flip the sign to produce the
859-
// upper bound of the original expression.
860-
bool all_terms_negative = (A_bound.max_value < 0 && B_bound.max_value < 0 &&
861-
C_bound.max_value < 0 && D_bound.max_value < 0);
862-
if (all_terms_negative) {
863-
A_bound = Negative(A_bound);
864-
B_bound = Negative(B_bound);
865-
C_bound = Negative(C_bound);
866-
D_bound = Negative(D_bound);
867-
}
868-
869-
bool all_terms_positive = (A_bound.min_value > 0 && B_bound.min_value > 0 &&
870-
C_bound.min_value > 0 && D_bound.min_value > 0);
871-
if (!all_terms_positive) {
872-
return std::nullopt;
873-
}
874-
875-
// (A + B) * C - (A * B) * D
876-
// (A*B*C*D) * ( (A+B)/(A*B*D) - 1/C )
877-
// (A*B*C*D) * ( (1/A + 1/B)/D - 1/C )
878-
// (A*B*C*D) * (1/(A*D) + 1/(B*D) - 1/C)
879-
//
880-
// The constant (A*B*C*D) is positive, and its minimum value is the
881-
// product of the minimum values of A, B, C, and D. If the reciprocal
882-
// term (1/(A*D) + 1/(B*D) - 1/C) is positive, then this constant can
883-
// be used to provide a lower bound on the expression.
884-
885-
bool reciprocal_term_is_positive = [&]() {
886-
if (D_bound.max_value == ConstIntBound::kPosInf) {
887-
// If D can grow without bound, the `1/(A*D)` and `1/(B*D)`
888-
// terms will approach zero, at which point the `-1/C` term
889-
// will determine the sign the sign.
890-
return false;
891-
}
892-
893-
if (std::min(A_bound.max_value, B_bound.max_value) * D_bound.max_value <= C_bound.min_value) {
894-
// 1/(A*D) + 1/(B*D) - 1/C is positive if 1/C < 1/(A*D) + 1/(B*D).
895-
// Since each term is positive, this condition can hold if either
896-
// A*D <= C or B*D <= C.
897-
return true;
898-
}
899-
if (A_bound.max_value != ConstIntBound::kPosInf &&
900-
B_bound.max_value != ConstIntBound::kPosInf) {
901-
// Even if neither term is sufficient on its own, if both A and B
902-
// have known upper bounds, the inequality 1/C < 1/(A*D) + 1/(B*D)
903-
// may still be provable.
904-
//
905-
// The maximum value of the LHS is found when C is minimized. The
906-
// minimum value of the RHS is found when A, B, and D are
907-
// maximized. If the condition holds in this case, then it holds
908-
// in all cases.
909-
//
910-
// 1/C_min < 1/(A_max * D_max) + 1/(B_max*D_max)
911-
// A_max*B_max*D_max < C_min*B_max + C_min*A_max
912-
// A_max*B_max*D_max < C_min*(A_max + B_max)
913-
//
914-
if (A_bound.max_value * B_bound.max_value * D_bound.max_value <
915-
C_bound.min_value * (A_bound.max_value + B_bound.max_value)) {
916-
return true;
917-
}
918-
}
919-
return false;
920-
}();
921-
922-
if (!reciprocal_term_is_positive) {
923-
return std::nullopt;
924-
}
925-
926-
auto ret = Everything(expr->dtype);
927-
ret.min_value = A_bound.min_value * B_bound.min_value * C_bound.min_value * D_bound.min_value;
928-
929-
// If we flipped the sign of the original expression, flip the sign of
930-
// the resulting set of possible values.
931-
if (all_terms_negative) {
932-
ret = Negative(ret);
933-
}
934-
return ret;
935-
}
936768
};
937769

938770
ConstIntBound ConstIntBoundAnalyzer::operator()(const PrimExpr& expr) const {

0 commit comments

Comments
 (0)