2121 * \file presburger_set.cc
2222 * \brief The presburger set functions
2323 */
24+ #include " presburger_set.h"
25+
2426#include < tvm/arith/int_set.h>
27+ #include < tvm/arith/int_solver.h>
28+ #include < tvm/arith/pattern.h>
2529#include < tvm/runtime/registry.h>
2630#include < tvm/tir/expr.h>
2731#include < tvm/tir/expr_functor.h>
2832#include < tvm/tir/stmt_functor.h>
29- #include < tvm/arith/pattern.h>
30- #include < tvm/arith/int_solver.h>
3133
3234#include < algorithm>
3335#include < unordered_map>
3436#include < utility>
37+ #include < vector>
3538
3639#include " constraint_extract.h"
37- #include " presburger_set.h"
3840#include " interval_set.h"
3941
4042namespace tvm {
@@ -43,14 +45,12 @@ namespace arith {
4345#ifdef TVM_MLIR_VERSION
4446using namespace tir ;
4547
46-
47- void Update (const PrimExpr& constraint,
48- PresburgerSetNode& intset) {
49- auto & space = intset.space ;
48+ void Update (const PrimExpr& constraint, PresburgerSetNode* intset) {
49+ auto & space = intset->space ;
5050 auto constraints_union = ExtractComponents (constraint);
5151 for (const PrimExpr& subconstraint : constraints_union) {
5252 auto entries = ExtractConstraints (subconstraint, false );
53- auto vars = intset. GetVars ();
53+ auto vars = intset-> GetVars ();
5454 IntegerRelation disjunct (entries.size (), 0 , vars.size () + 1 , space);
5555 for (const PrimExpr& entry : entries) {
5656 // The expression is expect to be simplified to only contain ==, <= or <
@@ -83,19 +83,18 @@ void Update(const PrimExpr& constraint,
8383 LOG (FATAL) << " Unsupported constraint expression: " << entry->GetTypeKey ();
8484 }
8585 }
86- intset. unionInPlace (disjunct);
86+ intset-> unionInPlace (disjunct);
8787 }
8888}
8989
9090PresburgerSet::PresburgerSet (const PrimExpr& constraint) {
9191 Array<Var> vars;
9292 PostOrderVisit (constraint, [&vars](const ObjectRef& obj) {
9393 if (const VarNode* new_var = obj.as <VarNode>()) {
94- auto var = GetRef<Var>(new_var);
95- if (!std::any_of (vars.begin (), vars.end (),
96- [&var](const Var& v) { return v.same_as (var); })) {
97- vars.push_back (var);
98- }
94+ auto var = GetRef<Var>(new_var);
95+ if (!std::any_of (vars.begin (), vars.end (), [&var](const Var& v) { return v.same_as (var); })) {
96+ vars.push_back (var);
97+ }
9998 }
10099 });
101100 auto constraints_union = ExtractComponents (constraint);
@@ -104,25 +103,26 @@ PresburgerSet::PresburgerSet(const PrimExpr& constraint) {
104103 auto space = PresburgerSpace::getRelationSpace (vars.size (), 0 , 0 , 0 );
105104 auto node = make_object<PresburgerSetNode>(std::move (space), vars);
106105 node->SetVars (vars);
107- Update (simplified_constraint, * node);
106+ Update (simplified_constraint, node. get () );
108107 data_ = std::move (node);
109108}
110109
111- PresburgerSet::PresburgerSet (const std::vector<IntegerRelation>& disjuncts, const Array<Var>& vars) {
110+ PresburgerSet::PresburgerSet (const std::vector<IntegerRelation>& disjuncts,
111+ const Array<Var>& vars) {
112112 auto node = make_object<PresburgerSetNode>(disjuncts, disjuncts[0 ].getSpace (), vars);
113113 data_ = std::move (node);
114114}
115115
116116void PresburgerSetNode::UpdateConstraint (const PrimExpr& constraint, const Array<Var>& vars) {
117117 Analyzer analyzer;
118118 PrimExpr simplified_constraint = analyzer.Simplify (constraint, kSimplifyRewriteCanonicalRewrite );
119- Update (simplified_constraint, * this );
119+ Update (simplified_constraint, this );
120120 SetVars (vars);
121121}
122122
123123PrimExpr PresburgerSetNode::GenerateConstraint () const {
124124 PrimExpr constraint = Bool (0 );
125- for (const IntegerRelation & disjunct : disjuncts) {
125+ for (const IntegerRelation& disjunct : disjuncts) {
126126 PrimExpr union_entry = Bool (1 );
127127 for (unsigned i = 0 , e = disjunct.getNumEqualities (); i < e; ++i) {
128128 PrimExpr linear_eq = IntImm (DataType::Int (32 ), 0 );
@@ -171,8 +171,9 @@ PresburgerSet Union(Array<PresburgerSet> sets) {
171171 if (sets.size () == 1 ) return sets[0 ];
172172 auto relations = sets[0 ]->disjuncts ;
173173 for (size_t i = 1 ; i < sets.size (); ++i) {
174- for (const auto rel : sets[i]->disjuncts )
174+ for (const IntegerRelation& rel : sets[i]->disjuncts ) {
175175 relations.push_back (rel);
176+ }
176177 }
177178 return PresburgerSet (std::move (relations), sets[0 ]->GetVars ());
178179}
@@ -185,31 +186,29 @@ PresburgerSet Intersect(const Array<PresburgerSet>& sets) {
185186
186187 for (size_t i = 1 ; i < sets.size (); ++i) {
187188 ICHECK (space.isCompatible (sets[i]->space )) << " Spaces should match" ;
188- for (const IntegerRelation & relA : sets[i]->disjuncts ) {
189- for (const IntegerRelation & relB : relations) {
189+ for (const IntegerRelation& relA : sets[i]->disjuncts ) {
190+ for (const IntegerRelation& relB : relations) {
190191 IntegerRelation intersection = relA.intersect (relB);
191- if (!intersection.isEmpty ())
192- relations.push_back (intersection);
192+ if (!intersection.isEmpty ()) relations.push_back (intersection);
193193 }
194194 }
195195 }
196196 return PresburgerSet (std::move (relations), sets[0 ]->GetVars ());
197197}
198198
199199IntSet EvalSet (const PrimExpr& e, const PresburgerSet& set) {
200- auto tvm_coeffs = DetectLinearEquation (e, set->GetVars ());
200+ Array<PrimExpr> tvm_coeffs = DetectLinearEquation (e, set->GetVars ());
201201 SmallVector<int64_t > coeffs;
202202 coeffs.reserve (tvm_coeffs.size ());
203- for (auto & it : tvm_coeffs) {
203+ for (const PrimExpr& it : tvm_coeffs) {
204204 coeffs.push_back (*as_const_int (it));
205205 }
206206
207207 IntSet result = IntSet ().Nothing ();
208- for (auto & it : set->disjuncts ) {
208+ for (const IntegerRelation& it : set->disjuncts ) {
209209 Simplex simplex (it);
210210 auto range = simplex.computeIntegerBounds (coeffs);
211- auto maxRoundedDown (
212- simplex.computeOptimum (Simplex::Direction::Up, coeffs));
211+ auto maxRoundedDown (simplex.computeOptimum (Simplex::Direction::Up, coeffs));
213212 auto opt = range.first .getOptimumIfBounded ();
214213 auto min = opt.hasValue () ? IntImm (DataType::Int (64 ), opt.getValue ()) : neg_inf ();
215214 opt = range.second .getOptimumIfBounded ();
@@ -232,9 +231,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
232231
233232#endif
234233
235- PresburgerSet MakePresburgerSet (const PrimExpr& constraint) {
236- return PresburgerSet (constraint);
237- }
234+ PresburgerSet MakePresburgerSet (const PrimExpr& constraint) { return PresburgerSet (constraint); }
238235
239236TVM_REGISTER_GLOBAL (" arith.PresburgerSet" ).set_body_typed(MakePresburgerSet);
240237
0 commit comments