Skip to content

Commit 9a24a83

Browse files
author
Min Chen
committed
Fix lint.
1 parent ac9749d commit 9a24a83

File tree

4 files changed

+46
-52
lines changed

4 files changed

+46
-52
lines changed

include/tvm/node/reflection.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,10 @@
3232
#include <tvm/runtime/object.h>
3333
#include <tvm/runtime/packed_func.h>
3434

35+
#include <limits>
3536
#include <string>
3637
#include <type_traits>
38+
#include <unordered_map>
3739
#include <vector>
3840

3941
namespace tvm {

src/arith/presburger_set.cc

Lines changed: 28 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -21,20 +21,22 @@
2121
* \file presburger_set.cc
2222
* \brief The presburger set functions
2323
*/
24+
#include "presburger_set.h"
25+
2426
#include <tvm/arith/int_set.h>
27+
#include <tvm/arith/int_solver.h>
28+
#include <tvm/arith/pattern.h>
2529
#include <tvm/runtime/registry.h>
2630
#include <tvm/tir/expr.h>
2731
#include <tvm/tir/expr_functor.h>
2832
#include <tvm/tir/stmt_functor.h>
29-
#include <tvm/arith/pattern.h>
30-
#include <tvm/arith/int_solver.h>
3133

3234
#include <algorithm>
3335
#include <unordered_map>
3436
#include <utility>
37+
#include <vector>
3538

3639
#include "constraint_extract.h"
37-
#include "presburger_set.h"
3840
#include "interval_set.h"
3941

4042
namespace tvm {
@@ -43,14 +45,12 @@ namespace arith {
4345
#ifdef TVM_MLIR_VERSION
4446
using namespace tir;
4547

46-
47-
void Update(const PrimExpr& constraint,
48-
PresburgerSetNode& intset) {
49-
auto& space = intset.space;
48+
void Update(const PrimExpr& constraint, PresburgerSetNode* intset) {
49+
auto& space = intset->space;
5050
auto constraints_union = ExtractComponents(constraint);
5151
for (const PrimExpr& subconstraint : constraints_union) {
5252
auto entries = ExtractConstraints(subconstraint, false);
53-
auto vars = intset.GetVars();
53+
auto vars = intset->GetVars();
5454
IntegerRelation disjunct(entries.size(), 0, vars.size() + 1, space);
5555
for (const PrimExpr& entry : entries) {
5656
// The expression is expect to be simplified to only contain ==, <= or <
@@ -83,19 +83,18 @@ void Update(const PrimExpr& constraint,
8383
LOG(FATAL) << "Unsupported constraint expression: " << entry->GetTypeKey();
8484
}
8585
}
86-
intset.unionInPlace(disjunct);
86+
intset->unionInPlace(disjunct);
8787
}
8888
}
8989

9090
PresburgerSet::PresburgerSet(const PrimExpr& constraint) {
9191
Array<Var> vars;
9292
PostOrderVisit(constraint, [&vars](const ObjectRef& obj) {
9393
if (const VarNode* new_var = obj.as<VarNode>()) {
94-
auto var = GetRef<Var>(new_var);
95-
if (!std::any_of(vars.begin(), vars.end(),
96-
[&var](const Var& v) { return v.same_as(var); })) {
97-
vars.push_back(var);
98-
}
94+
auto var = GetRef<Var>(new_var);
95+
if (!std::any_of(vars.begin(), vars.end(), [&var](const Var& v) { return v.same_as(var); })) {
96+
vars.push_back(var);
97+
}
9998
}
10099
});
101100
auto constraints_union = ExtractComponents(constraint);
@@ -104,25 +103,26 @@ PresburgerSet::PresburgerSet(const PrimExpr& constraint) {
104103
auto space = PresburgerSpace::getRelationSpace(vars.size(), 0, 0, 0);
105104
auto node = make_object<PresburgerSetNode>(std::move(space), vars);
106105
node->SetVars(vars);
107-
Update(simplified_constraint, *node);
106+
Update(simplified_constraint, node.get());
108107
data_ = std::move(node);
109108
}
110109

111-
PresburgerSet::PresburgerSet(const std::vector<IntegerRelation>& disjuncts, const Array<Var>& vars) {
110+
PresburgerSet::PresburgerSet(const std::vector<IntegerRelation>& disjuncts,
111+
const Array<Var>& vars) {
112112
auto node = make_object<PresburgerSetNode>(disjuncts, disjuncts[0].getSpace(), vars);
113113
data_ = std::move(node);
114114
}
115115

116116
void PresburgerSetNode::UpdateConstraint(const PrimExpr& constraint, const Array<Var>& vars) {
117117
Analyzer analyzer;
118118
PrimExpr simplified_constraint = analyzer.Simplify(constraint, kSimplifyRewriteCanonicalRewrite);
119-
Update(simplified_constraint, *this);
119+
Update(simplified_constraint, this);
120120
SetVars(vars);
121121
}
122122

123123
PrimExpr PresburgerSetNode::GenerateConstraint() const {
124124
PrimExpr constraint = Bool(0);
125-
for (const IntegerRelation &disjunct : disjuncts) {
125+
for (const IntegerRelation& disjunct : disjuncts) {
126126
PrimExpr union_entry = Bool(1);
127127
for (unsigned i = 0, e = disjunct.getNumEqualities(); i < e; ++i) {
128128
PrimExpr linear_eq = IntImm(DataType::Int(32), 0);
@@ -171,8 +171,9 @@ PresburgerSet Union(Array<PresburgerSet> sets) {
171171
if (sets.size() == 1) return sets[0];
172172
auto relations = sets[0]->disjuncts;
173173
for (size_t i = 1; i < sets.size(); ++i) {
174-
for (const auto rel : sets[i]->disjuncts)
174+
for (const IntegerRelation& rel : sets[i]->disjuncts) {
175175
relations.push_back(rel);
176+
}
176177
}
177178
return PresburgerSet(std::move(relations), sets[0]->GetVars());
178179
}
@@ -185,31 +186,29 @@ PresburgerSet Intersect(const Array<PresburgerSet>& sets) {
185186

186187
for (size_t i = 1; i < sets.size(); ++i) {
187188
ICHECK(space.isCompatible(sets[i]->space)) << "Spaces should match";
188-
for (const IntegerRelation &relA : sets[i]->disjuncts) {
189-
for (const IntegerRelation &relB : relations) {
189+
for (const IntegerRelation& relA : sets[i]->disjuncts) {
190+
for (const IntegerRelation& relB : relations) {
190191
IntegerRelation intersection = relA.intersect(relB);
191-
if (!intersection.isEmpty())
192-
relations.push_back(intersection);
192+
if (!intersection.isEmpty()) relations.push_back(intersection);
193193
}
194194
}
195195
}
196196
return PresburgerSet(std::move(relations), sets[0]->GetVars());
197197
}
198198

199199
IntSet EvalSet(const PrimExpr& e, const PresburgerSet& set) {
200-
auto tvm_coeffs = DetectLinearEquation(e, set->GetVars());
200+
Array<PrimExpr> tvm_coeffs = DetectLinearEquation(e, set->GetVars());
201201
SmallVector<int64_t> coeffs;
202202
coeffs.reserve(tvm_coeffs.size());
203-
for (auto &it : tvm_coeffs) {
203+
for (const PrimExpr& it : tvm_coeffs) {
204204
coeffs.push_back(*as_const_int(it));
205205
}
206206

207207
IntSet result = IntSet().Nothing();
208-
for (auto &it : set->disjuncts) {
208+
for (const IntegerRelation& it : set->disjuncts) {
209209
Simplex simplex(it);
210210
auto range = simplex.computeIntegerBounds(coeffs);
211-
auto maxRoundedDown(
212-
simplex.computeOptimum(Simplex::Direction::Up, coeffs));
211+
auto maxRoundedDown(simplex.computeOptimum(Simplex::Direction::Up, coeffs));
213212
auto opt = range.first.getOptimumIfBounded();
214213
auto min = opt.hasValue() ? IntImm(DataType::Int(64), opt.getValue()) : neg_inf();
215214
opt = range.second.getOptimumIfBounded();
@@ -232,9 +231,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
232231

233232
#endif
234233

235-
PresburgerSet MakePresburgerSet(const PrimExpr& constraint) {
236-
return PresburgerSet(constraint);
237-
}
234+
PresburgerSet MakePresburgerSet(const PrimExpr& constraint) { return PresburgerSet(constraint); }
238235

239236
TVM_REGISTER_GLOBAL("arith.PresburgerSet").set_body_typed(MakePresburgerSet);
240237

src/arith/presburger_set.h

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,16 @@
2525
#define TVM_ARITH_PRESBURGER_SET_H_
2626

2727
#ifdef TVM_MLIR_VERSION
28-
#include <mlir/Analysis/Presburger/PresburgerRelation.h>
2928
#include <mlir/Analysis/Presburger/IntegerRelation.h>
29+
#include <mlir/Analysis/Presburger/PresburgerRelation.h>
3030
#include <mlir/Analysis/Presburger/Simplex.h>
3131
#endif
3232

3333
#include <tvm/arith/analyzer.h>
3434
#include <tvm/tir/op.h>
3535

3636
#include <limits>
37+
#include <vector>
3738

3839
#include "const_fold.h"
3940

@@ -54,13 +55,12 @@ using namespace presburger;
5455
*/
5556
class PresburgerSetNode : public IntSetNode {
5657
public:
57-
explicit PresburgerSetNode(const PresburgerSpace &space, const Array<Var> &vars)
58-
: disjuncts({}), space(space), vars(vars) {};
59-
explicit PresburgerSetNode() : space(PresburgerSpace::getRelationSpace()) {};
60-
explicit PresburgerSetNode(const std::vector<IntegerRelation> &disjuncts,
61-
const PresburgerSpace &space,
62-
const Array<Var> &vars)
63-
: disjuncts(disjuncts), space(space), vars(vars) {}
58+
PresburgerSetNode() : space(PresburgerSpace::getRelationSpace()) {}
59+
explicit PresburgerSetNode(const PresburgerSpace& space, const Array<Var>& vars)
60+
: disjuncts({}), space(space), vars(vars) {}
61+
explicit PresburgerSetNode(const std::vector<IntegerRelation>& disjuncts,
62+
const PresburgerSpace& space, const Array<Var>& vars)
63+
: disjuncts(disjuncts), space(space), vars(vars) {}
6464

6565
/*! \brief Represent the union of multiple IntegerRelation */
6666
std::vector<IntegerRelation> disjuncts;
@@ -83,7 +83,7 @@ class PresburgerSetNode : public IntSetNode {
8383
* \brief Do inplace union with given disjunct
8484
* \param disjunct The given disjunct to be union with
8585
*/
86-
void unionInPlace(const IntegerRelation &disjunct) {
86+
void unionInPlace(const IntegerRelation& disjunct) {
8787
assert(space.isCompatible(disjunct.getSpace()) && "Spaces should match");
8888
disjuncts.push_back(disjunct);
8989
}
@@ -105,7 +105,7 @@ class PresburgerSetNode : public IntSetNode {
105105
* \brief Set domain vars
106106
* \param new_vars Vars that will be taken as the domain vars
107107
*/
108-
void SetVars(const Array<Var> &new_vars) { vars = new_vars; }
108+
void SetVars(const Array<Var>& new_vars) { vars = new_vars; }
109109

110110
/*!
111111
* \brief Get the current domain vars
@@ -115,8 +115,7 @@ class PresburgerSetNode : public IntSetNode {
115115

116116
/*! \return whether integer set is empty */
117117
bool IsEmpty() const {
118-
return std::all_of(disjuncts.begin(),
119-
disjuncts.end(),
118+
return std::all_of(disjuncts.begin(), disjuncts.end(),
120119
std::mem_fn(&IntegerRelation::isIntegerEmpty));
121120
}
122121

@@ -156,24 +155,20 @@ class PresburgerSet : public IntSet {
156155
class PresburgerSetNode : public IntSetNode {
157156
public:
158157
// dummy visitor overload.
159-
void VisitAttrs(tvm::AttrVisitor* v) {
160-
LOG(FATAL) << "MLIR is not enabled!";
161-
}
158+
void VisitAttrs(tvm::AttrVisitor* v) { LOG(FATAL) << "MLIR is not enabled!"; }
162159

163160
static constexpr const char* _type_key = "arith.PresburgerSet";
164161
TVM_DECLARE_FINAL_OBJECT_INFO(PresburgerSetNode, IntSetNode);
165162
};
166163

167164
class PresburgerSet : public IntSet {
168165
public:
169-
/*!
166+
/*!
170167
* \brief Constructor interface to prompt when MLIR is not enabled.
171168
* \param constraint The constraint to construct the set.
172169
* \return The created set.
173170
*/
174-
TVM_DLL PresburgerSet(const PrimExpr& constraint) {
175-
LOG(FATAL) << "MLIR is not enabled!";
176-
}
171+
TVM_DLL PresburgerSet(const PrimExpr& constraint) { LOG(FATAL) << "MLIR is not enabled!"; }
177172
};
178173
#endif
179174
/*!

tests/cpp/arith_integer_set_test.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,11 @@ TEST(PresburgerSet, eval) {
2929
auto x = tvm::tir::Var("x");
3030
auto y = tvm::tir::Var("y");
3131
auto sub_constraint0 = (x + y < 20) && (x - y <= 0);
32-
auto sub_constraint1 = x >= 0 && x < 20 && y >=0 && y < 20;
32+
auto sub_constraint1 = x >= 0 && x < 20 && y >= 0 && y < 20;
3333
auto constraint = sub_constraint0 && sub_constraint1;
3434
auto set = tvm::arith::PresburgerSet(constraint);
3535

36-
auto target = x + 2*y;
36+
auto target = x + 2 * y;
3737
auto result = EvalSet(target, set);
3838
ASSERT_TRUE(tvm::tir::is_zero(result.min()));
3939
ASSERT_TRUE(tvm::tir::is_const_int(result.max(), 38));

0 commit comments

Comments
 (0)