@@ -362,8 +362,13 @@ using namespace tir;
362362// We might use better set analysis in the future to replace the intervalset.
363363class IntervalSetEvaluator : public ExprFunctor <IntervalSet(const PrimExpr&)> {
364364 public:
365- IntervalSetEvaluator (Analyzer* analyzer, const Map<Var, IntSet>& dom_map, bool eval_vec = false )
366- : analyzer_(analyzer), dom_map_(dom_map), eval_vec_(eval_vec) {}
365+ IntervalSetEvaluator (Analyzer* analyzer, const Map<Var, IntSet>& dom_map,
366+ const std::vector<std::pair<Var, IntSet>>& dom_constraints = {},
367+ bool eval_vec = false )
368+ : analyzer_(analyzer),
369+ dom_map_ (dom_map),
370+ dom_constraints_(dom_constraints),
371+ eval_vec_(eval_vec) {}
367372
368373 IntervalSet Eval (const PrimExpr& val) { return this ->VisitExpr (val); }
369374 // evaluate and relax the set
@@ -383,18 +388,38 @@ class IntervalSetEvaluator : public ExprFunctor<IntervalSet(const PrimExpr&)> {
383388
384389 IntervalSet VisitExpr_ (const VarNode* op) final {
385390 Var var = GetRef<Var>(op);
391+
392+ Array<IntSet> values;
393+ for (const auto & constraint : dom_constraints_) {
394+ if (var.same_as (constraint.first )) {
395+ values.push_back (constraint.second );
396+ }
397+ }
398+
386399 auto it = dom_map_.find (var);
387400 if (it != dom_map_.end ()) {
388- IntervalSet res = ToIntervalSet ((*it).second );
389- if (res->min_value .same_as (var) && res->max_value .same_as (var)) {
390- return res;
391- }
392- // recursively evaluate mapped result
393- // in case the domain contains variables to be relaxed.
394- return Eval (res);
395- } else {
401+ values.push_back ((*it).second );
402+ }
403+
404+ if (values.empty ()) {
396405 return IntervalSet::SinglePoint (var);
397406 }
407+
408+ IntSet intersection = [&]() {
409+ if (values.size () == 1 ) {
410+ return values.front ();
411+ } else {
412+ return Intersect (values);
413+ }
414+ }();
415+
416+ IntervalSet res = ToIntervalSet (intersection);
417+ if (res->min_value .same_as (var) && res->max_value .same_as (var)) {
418+ return res;
419+ }
420+ // recursively evaluate mapped result
421+ // in case the domain contains variables to be relaxed.
422+ return Eval (res);
398423 }
399424
400425 IntervalSet VisitExpr_ (const AddNode* op) final { return VisitBinaryExpr_<Add>(op); }
@@ -517,6 +542,7 @@ class IntervalSetEvaluator : public ExprFunctor<IntervalSet(const PrimExpr&)> {
517542 // analyzer
518543 Analyzer* analyzer_;
519544 const Map<Var, IntSet>& dom_map_;
545+ const std::vector<std::pair<Var, IntSet>>& dom_constraints_;
520546 bool eval_vec_{false };
521547};
522548
@@ -525,11 +551,11 @@ class IntSetAnalyzer::Impl {
525551 explicit Impl (Analyzer* analyzer) : analyzer_(analyzer) {}
526552
527553 IntSet Eval (const PrimExpr& expr, const Map<Var, IntSet>& dom_map) const {
528- return IntervalSetEvaluator (analyzer_, dom_map).Eval (expr);
554+ return IntervalSetEvaluator (analyzer_, dom_map, {} ).Eval (expr);
529555 }
530556
531557 IntSet Eval (const PrimExpr& expr) const {
532- return IntervalSetEvaluator (analyzer_, GetCurrentBounds () , true ).Eval (expr);
558+ return IntervalSetEvaluator (analyzer_, dom_map_, constraints_ , true ).Eval (expr);
533559 }
534560
535561 void Bind (const Var& var, const Range& range, bool allow_override) {
@@ -543,10 +569,6 @@ class IntSetAnalyzer::Impl {
543569 std::function<void ()> SuppressConstraints ();
544570
545571 private:
546- // Get the current variable bounds, including both global bounds and
547- // scope-dependent bounds.
548- Map<Var, IntSet> GetCurrentBounds () const ;
549-
550572 // Utility function to split a boolean condition into the domain
551573 // bounds implied by that condition.
552574 static std::vector<std::pair<Var, IntSet>> DetectBoundInfo (const PrimExpr& cond);
@@ -558,9 +580,11 @@ class IntSetAnalyzer::Impl {
558580 // ranges)
559581 Map<Var, IntSet> dom_map_;
560582
561- // Map of variables to implicit scope-dependent bounds (e.g. inside
562- // the body of an if-statement)
563- Map<Var, IntSet> constraints_;
583+ // List of implicit scope-dependent bounds (e.g. inside the body of
584+ // an if-statement). Maintained as a list of constraints, rather
585+ // than as a `Map<Var,IntSet>`, to avoid computing an Intersection
586+ // until required.
587+ std::vector<std::pair<Var, IntSet>> constraints_;
564588
565589 // Whether scope-based analysis should be temporarily disabled
566590 bool use_scoped_constraints_{true };
@@ -608,29 +632,6 @@ void IntSetAnalyzer::Impl::Bind(const Var& var, const PrimExpr& expr, bool can_o
608632 Update (var, Eval (expr), can_override);
609633}
610634
611- Map<Var, IntSet> IntSetAnalyzer::Impl::GetCurrentBounds () const {
612- // If either constraints_ or dom_map_ is empty, return the other to
613- // avoid constructing a new map.
614- if (constraints_.empty () || !use_scoped_constraints_) {
615- return dom_map_;
616- } else if (dom_map_.empty ()) {
617- return constraints_;
618- }
619-
620- // If neither is empty, construct a merged domain map with
621- // information from both sources.
622- Map<Var, IntSet> merged = dom_map_;
623- for (const auto & pair : constraints_) {
624- auto it = merged.find (pair.first );
625- if (it == merged.end ()) {
626- merged.Set (pair.first , pair.second );
627- } else {
628- merged.Set (pair.first , Intersect ({pair.second , (*it).second }));
629- }
630- }
631- return merged;
632- }
633-
634635std::vector<std::pair<Var, IntSet>> IntSetAnalyzer::Impl::DetectBoundInfo (
635636 const PrimExpr& constraint) {
636637 PVar<Var> x;
@@ -672,41 +673,16 @@ std::function<void()> IntSetAnalyzer::EnterConstraint(const PrimExpr& constraint
672673std::function<void ()> IntSetAnalyzer::SuppressConstraints () { return impl_->SuppressConstraints (); }
673674
674675std::function<void ()> IntSetAnalyzer::Impl::EnterConstraint (const PrimExpr& constraint) {
675- Map<Var, IntSet> cached_values;
676-
677676 auto bounds = DetectBoundInfo (constraint);
678677
679678 if (bounds.size () == 0 ) return nullptr ;
680679
681- // Collect the current values of each var that is changes by this
682- // constraint.
683- for (const auto & pair : bounds) {
684- auto it = constraints_.find (pair.first );
685- if (it == constraints_.end ()) {
686- cached_values.Set (pair.first , IntSet ());
687- } else {
688- cached_values.Set (pair.first , (*it).second );
689- }
690- }
691-
692- // Update all constraints
693- for (const auto & pair : bounds) {
694- auto it = constraints_.find (pair.first );
695- if (it == constraints_.end ()) {
696- constraints_.Set (pair.first , pair.second );
697- } else {
698- constraints_.Set (pair.first , Intersect ({pair.second , (*it).second }));
699- }
700- }
701-
702- auto frecover = [cached_values, this ]() {
703- for (const auto & it : cached_values) {
704- if (it.second .defined ()) {
705- constraints_.Set (it.first , it.second );
706- } else {
707- constraints_.erase (it.first );
708- }
709- }
680+ size_t old_size = constraints_.size ();
681+ constraints_.insert (constraints_.end (), bounds.begin (), bounds.end ());
682+ size_t new_size = constraints_.size ();
683+ auto frecover = [old_size, new_size, this ]() {
684+ ICHECK_EQ (constraints_.size (), new_size);
685+ constraints_.resize (old_size);
710686 };
711687 return frecover;
712688}
@@ -975,13 +951,13 @@ Map<Var, IntSet> ConvertDomMap(const std::unordered_map<const VarNode*, IntSet>&
975951
976952IntSet EvalSet (PrimExpr e, const Map<Var, IntSet>& dom_map) {
977953 Analyzer ana;
978- return IntervalSetEvaluator (&ana, dom_map, false ).Eval (e);
954+ return IntervalSetEvaluator (&ana, dom_map, {}, false ).Eval (e);
979955}
980956
981957IntSet IntSet::Vector (PrimExpr x) {
982958 Analyzer ana;
983959 Map<Var, IntSet> dmap;
984- return IntervalSetEvaluator (&ana, dmap, true ).Eval (x);
960+ return IntervalSetEvaluator (&ana, dmap, {}, true ).Eval (x);
985961}
986962
987963IntSet EvalSet (PrimExpr e, const Map<IterVar, IntSet>& dom_map) {
0 commit comments