Skip to content

Commit f5bf323

Browse files
committed
[Arith][TIR] IntSetAnalyzer, delay intersection of IntSet until use
Follow-up from apache#11970, to improve performance. In the initial implementation, the `analyzer->int_set` would compute the intersection of all scope-based constraints when entering the scope, even if they weren't actually used. This commit delays the call to `Intersect` until required, following the same behavior as `ConstIntBound`.
1 parent 832f7fa commit f5bf323

File tree

1 file changed

+51
-75
lines changed

1 file changed

+51
-75
lines changed

src/arith/int_set.cc

Lines changed: 51 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -362,8 +362,13 @@ using namespace tir;
362362
// We might use better set analysis in the future to replace the intervalset.
363363
class IntervalSetEvaluator : public ExprFunctor<IntervalSet(const PrimExpr&)> {
364364
public:
365-
IntervalSetEvaluator(Analyzer* analyzer, const Map<Var, IntSet>& dom_map, bool eval_vec = false)
366-
: analyzer_(analyzer), dom_map_(dom_map), eval_vec_(eval_vec) {}
365+
IntervalSetEvaluator(Analyzer* analyzer, const Map<Var, IntSet>& dom_map,
366+
const std::vector<std::pair<Var, IntSet>>& dom_constraints = {},
367+
bool eval_vec = false)
368+
: analyzer_(analyzer),
369+
dom_map_(dom_map),
370+
dom_constraints_(dom_constraints),
371+
eval_vec_(eval_vec) {}
367372

368373
IntervalSet Eval(const PrimExpr& val) { return this->VisitExpr(val); }
369374
// evaluate and relax the set
@@ -383,18 +388,38 @@ class IntervalSetEvaluator : public ExprFunctor<IntervalSet(const PrimExpr&)> {
383388

384389
IntervalSet VisitExpr_(const VarNode* op) final {
385390
Var var = GetRef<Var>(op);
391+
392+
Array<IntSet> values;
393+
for (const auto& constraint : dom_constraints_) {
394+
if (var.same_as(constraint.first)) {
395+
values.push_back(constraint.second);
396+
}
397+
}
398+
386399
auto it = dom_map_.find(var);
387400
if (it != dom_map_.end()) {
388-
IntervalSet res = ToIntervalSet((*it).second);
389-
if (res->min_value.same_as(var) && res->max_value.same_as(var)) {
390-
return res;
391-
}
392-
// recursively evaluate mapped result
393-
// in case the domain contains variables to be relaxed.
394-
return Eval(res);
395-
} else {
401+
values.push_back((*it).second);
402+
}
403+
404+
if (values.empty()) {
396405
return IntervalSet::SinglePoint(var);
397406
}
407+
408+
IntSet intersection = [&]() {
409+
if (values.size() == 1) {
410+
return values.front();
411+
} else {
412+
return Intersect(values);
413+
}
414+
}();
415+
416+
IntervalSet res = ToIntervalSet(intersection);
417+
if (res->min_value.same_as(var) && res->max_value.same_as(var)) {
418+
return res;
419+
}
420+
// recursively evaluate mapped result
421+
// in case the domain contains variables to be relaxed.
422+
return Eval(res);
398423
}
399424

400425
IntervalSet VisitExpr_(const AddNode* op) final { return VisitBinaryExpr_<Add>(op); }
@@ -517,6 +542,7 @@ class IntervalSetEvaluator : public ExprFunctor<IntervalSet(const PrimExpr&)> {
517542
// analyzer
518543
Analyzer* analyzer_;
519544
const Map<Var, IntSet>& dom_map_;
545+
const std::vector<std::pair<Var, IntSet>>& dom_constraints_;
520546
bool eval_vec_{false};
521547
};
522548

@@ -525,11 +551,11 @@ class IntSetAnalyzer::Impl {
525551
explicit Impl(Analyzer* analyzer) : analyzer_(analyzer) {}
526552

527553
IntSet Eval(const PrimExpr& expr, const Map<Var, IntSet>& dom_map) const {
528-
return IntervalSetEvaluator(analyzer_, dom_map).Eval(expr);
554+
return IntervalSetEvaluator(analyzer_, dom_map, {}).Eval(expr);
529555
}
530556

531557
IntSet Eval(const PrimExpr& expr) const {
532-
return IntervalSetEvaluator(analyzer_, GetCurrentBounds(), true).Eval(expr);
558+
return IntervalSetEvaluator(analyzer_, dom_map_, constraints_, true).Eval(expr);
533559
}
534560

535561
void Bind(const Var& var, const Range& range, bool allow_override) {
@@ -543,10 +569,6 @@ class IntSetAnalyzer::Impl {
543569
std::function<void()> SuppressConstraints();
544570

545571
private:
546-
// Get the current variable bounds, including both global bounds and
547-
// scope-dependent bounds.
548-
Map<Var, IntSet> GetCurrentBounds() const;
549-
550572
// Utility function to split a boolean condition into the domain
551573
// bounds implied by that condition.
552574
static std::vector<std::pair<Var, IntSet>> DetectBoundInfo(const PrimExpr& cond);
@@ -558,9 +580,11 @@ class IntSetAnalyzer::Impl {
558580
// ranges)
559581
Map<Var, IntSet> dom_map_;
560582

561-
// Map of variables to implicit scope-dependent bounds (e.g. inside
562-
// the body of an if-statement)
563-
Map<Var, IntSet> constraints_;
583+
// List of implicit scope-dependent bounds (e.g. inside the body of
584+
// an if-statement). Maintained as a list of constraints, rather
585+
// than as a `Map<Var,IntSet>`, to avoid computing an Intersection
586+
// until required.
587+
std::vector<std::pair<Var, IntSet>> constraints_;
564588

565589
// Whether scope-based analysis should be temporarily disabled
566590
bool use_scoped_constraints_{true};
@@ -608,29 +632,6 @@ void IntSetAnalyzer::Impl::Bind(const Var& var, const PrimExpr& expr, bool can_o
608632
Update(var, Eval(expr), can_override);
609633
}
610634

611-
Map<Var, IntSet> IntSetAnalyzer::Impl::GetCurrentBounds() const {
612-
// If either constraints_ or dom_map_ is empty, return the other to
613-
// avoid constructing a new map.
614-
if (constraints_.empty() || !use_scoped_constraints_) {
615-
return dom_map_;
616-
} else if (dom_map_.empty()) {
617-
return constraints_;
618-
}
619-
620-
// If neither is empty, construct a merged domain map with
621-
// information from both sources.
622-
Map<Var, IntSet> merged = dom_map_;
623-
for (const auto& pair : constraints_) {
624-
auto it = merged.find(pair.first);
625-
if (it == merged.end()) {
626-
merged.Set(pair.first, pair.second);
627-
} else {
628-
merged.Set(pair.first, Intersect({pair.second, (*it).second}));
629-
}
630-
}
631-
return merged;
632-
}
633-
634635
std::vector<std::pair<Var, IntSet>> IntSetAnalyzer::Impl::DetectBoundInfo(
635636
const PrimExpr& constraint) {
636637
PVar<Var> x;
@@ -672,41 +673,16 @@ std::function<void()> IntSetAnalyzer::EnterConstraint(const PrimExpr& constraint
672673
std::function<void()> IntSetAnalyzer::SuppressConstraints() { return impl_->SuppressConstraints(); }
673674

674675
std::function<void()> IntSetAnalyzer::Impl::EnterConstraint(const PrimExpr& constraint) {
675-
Map<Var, IntSet> cached_values;
676-
677676
auto bounds = DetectBoundInfo(constraint);
678677

679678
if (bounds.size() == 0) return nullptr;
680679

681-
// Collect the current values of each var that is changes by this
682-
// constraint.
683-
for (const auto& pair : bounds) {
684-
auto it = constraints_.find(pair.first);
685-
if (it == constraints_.end()) {
686-
cached_values.Set(pair.first, IntSet());
687-
} else {
688-
cached_values.Set(pair.first, (*it).second);
689-
}
690-
}
691-
692-
// Update all constraints
693-
for (const auto& pair : bounds) {
694-
auto it = constraints_.find(pair.first);
695-
if (it == constraints_.end()) {
696-
constraints_.Set(pair.first, pair.second);
697-
} else {
698-
constraints_.Set(pair.first, Intersect({pair.second, (*it).second}));
699-
}
700-
}
701-
702-
auto frecover = [cached_values, this]() {
703-
for (const auto& it : cached_values) {
704-
if (it.second.defined()) {
705-
constraints_.Set(it.first, it.second);
706-
} else {
707-
constraints_.erase(it.first);
708-
}
709-
}
680+
size_t old_size = constraints_.size();
681+
constraints_.insert(constraints_.end(), bounds.begin(), bounds.end());
682+
size_t new_size = constraints_.size();
683+
auto frecover = [old_size, new_size, this]() {
684+
ICHECK_EQ(constraints_.size(), new_size);
685+
constraints_.resize(old_size);
710686
};
711687
return frecover;
712688
}
@@ -975,13 +951,13 @@ Map<Var, IntSet> ConvertDomMap(const std::unordered_map<const VarNode*, IntSet>&
975951

976952
IntSet EvalSet(PrimExpr e, const Map<Var, IntSet>& dom_map) {
977953
Analyzer ana;
978-
return IntervalSetEvaluator(&ana, dom_map, false).Eval(e);
954+
return IntervalSetEvaluator(&ana, dom_map, {}, false).Eval(e);
979955
}
980956

981957
IntSet IntSet::Vector(PrimExpr x) {
982958
Analyzer ana;
983959
Map<Var, IntSet> dmap;
984-
return IntervalSetEvaluator(&ana, dmap, true).Eval(x);
960+
return IntervalSetEvaluator(&ana, dmap, {}, true).Eval(x);
985961
}
986962

987963
IntSet EvalSet(PrimExpr e, const Map<IterVar, IntSet>& dom_map) {

0 commit comments

Comments
 (0)