Skip to content

Commit a2c8a29

Browse files
authored
[SCHEDULE] Improve bound inference, support reduce codegen. (#30)
1 parent d4af7ad commit a2c8a29

32 files changed

+1247
-646
lines changed

include/tvm/expr.h

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ using Halide::Internal::IRPrinter;
3232
using Halide::Internal::Variable;
3333

3434
using Halide::Internal::make_const;
35+
using Halide::Internal::make_zero;
36+
using Halide::Internal::as_const_int;
37+
using Halide::Internal::as_const_uint;
3538

3639

3740
inline Type TVMType2Type(TVMType t) {
@@ -126,25 +129,25 @@ using Halide::abs;
126129
using Halide::select;
127130

128131
/*!
129-
* \brief sum of of source expression over rdom
132+
* \brief sum of source expression over axis
130133
* \param source The source expression.
131-
* \param rdom List of iteration variables that will be used for reduction.
134+
* \param axis List of iteration variables that will be used for reduction.
132135
*/
133-
Expr sum(Expr source, Array<IterVar> rdom);
136+
Expr sum(Expr source, Array<IterVar> axis);
134137

135138
/*!
136-
* \brief max of of source expression over rdom
139+
* \brief max of source expression over axis
137140
* \param source The source expression.
138-
* \param rdom List of iteration variables that will be used for reduction.
141+
* \param axis List of iteration variables that will be used for reduction.
139142
*/
140-
Expr max(Expr source, Array<IterVar> rdom);
143+
Expr max(Expr source, Array<IterVar> axis);
141144

142145
/*!
143-
* \brief max of of source expression over rdom
146+
* \brief min of source expression over axis
144147
* \param source The source expression.
145-
* \param rdom List of iteration variables that will be used for reduction.
148+
* \param axis List of iteration variables that will be used for reduction.
146149
*/
147-
Expr min(Expr source, Array<IterVar> rdom);
150+
Expr min(Expr source, Array<IterVar> axis);
148151

149152

150153
// print functions for expr

include/tvm/ir.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ struct Reduce : public ExprNode<Reduce> {
3030
std::string op;
3131
/*! \brief The source operand */
3232
Expr source;
33-
/*! \brief The reduction domains */
34-
Array<IterVar> rdom;
33+
/*! \brief The reduction axis */
34+
Array<IterVar> axis;
3535

3636
/*! \brief construct expr from op and rdom */
3737
static Expr make(std::string op, Expr src, Array<IterVar> rdom);
@@ -40,7 +40,7 @@ struct Reduce : public ExprNode<Reduce> {
4040
v->Visit("dtype", &type);
4141
v->Visit("op", &op);
4242
v->Visit("source", &source);
43-
v->Visit("rdom", &rdom);
43+
v->Visit("axis", &axis);
4444
}
4545
static const IRNodeType _type_info = IRNodeType::ExtensionExpr;
4646
static constexpr const char* _type_key = "Reduce";

include/tvm/ir_pass.h

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
* \file ir_pass.h
44
* \brief Collection of IR pass functions
55
*
6-
* All the pass functions in this file are for Stmt,
7-
* We can use PassFunction(Evaluate(expr)) to apply it to Expr
6+
* When the pass functions in this file are for Stmt,
7+
* we can use PassFunction(Evaluate(expr)) to apply it to Expr
88
*/
99
#ifndef TVM_IR_PASS_H_
1010
#define TVM_IR_PASS_H_
@@ -37,15 +37,6 @@ inline Stmt Simplify(Stmt a) {
3737
return Halide::Internal::simplify(a);
3838
}
3939

40-
/*!
41-
* \brief Schedule s' dependent operations.
42-
*
43-
* \param s The schedule to be realized
44-
* \param dom_map The domain of each iter vars.
45-
* \return the result Stmt
46-
*/
47-
Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map);
48-
4940
/*!
5041
* \brief verifies whether the IR stmt or Expr is in SSA form.
5142
* That is: each VarExpr is defined and assigned once(in Let/For)
@@ -69,6 +60,14 @@ bool HasSideEffect(const Expr& e);
6960
*/
7061
Stmt ConvertSSA(Stmt stmt);
7162

63+
/*!
64+
* \brief Substitute the var of each IterVar key in value_map with its mapped value.
65+
* \param stmt The source statement to be substituted
66+
* \param value_map The map of new values.
67+
* \return The converted form.
68+
*/
69+
Stmt Substitute(Stmt stmt, const Map<IterVar, Expr>& value_map);
70+
7271
/*!
7372
* \brief inline all calls of f in stmt.
7473
*

include/tvm/operation.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ class ComputeOpNode : public OperationNode {
4949
public:
5050
/*! \brief IterVar on each axis */
5151
Array<IterVar> axis;
52+
/*! \brief IterVar on each reduction axis, if the body is a Reduce */
53+
Array<IterVar> reduce_axis;
5254
/*! \brief the compute expression */
5355
Expr body;
5456
/*! \brief constructor */
@@ -64,6 +66,7 @@ class ComputeOpNode : public OperationNode {
6466
void VisitAttrs(AttrVisitor* v) final {
6567
v->Visit("name", &name);
6668
v->Visit("axis", &axis);
69+
v->Visit("reduce_axis", &reduce_axis);
6770
v->Visit("body", &body);
6871
}
6972
static Operation make(std::string name,

include/tvm/schedule.h

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,8 @@ class Stage : public NodeRef {
123123
IterVar* p_x_outer, IterVar* p_y_outer,
124124
IterVar* p_x_inner, IterVar* p_y_inner,
125125
Expr x_factor, Expr y_factor);
126+
// declare container type
127+
using ContainerType = StageNode;
126128
};
127129

128130
/*!
@@ -152,11 +154,22 @@ class Schedule : public NodeRef {
152154
Stage operator[](const Tensor& tensor) {
153155
return this->operator[](tensor->op);
154156
}
157+
/*!
158+
* \brief Normalize the schedule.
159+
* This is needed before bound inference.
160+
* Insert necessary RebaseNode to make sure all leaf_iter_vars
161+
* are in form [0, extent)
162+
*
163+
* \note The declaration returns void; the schedule is presumably normalized in place rather than returned.
164+
*/
165+
void normalize();
155166
/*!
156167
* \brief access the internal node container
157168
* \return the pointer to the internal node container
158169
*/
159170
inline const ScheduleNode* operator->() const;
171+
// declare container type
172+
using ContainerType = ScheduleNode;
160173
};
161174

162175
/*!
@@ -308,6 +321,30 @@ class FuseNode : public IterVarRelationNode {
308321
TVM_DECLARE_NODE_TYPE_INFO(FuseNode);
309322
};
310323

324+
/*!
325+
* \brief Rebase the iteration to make min to be 0.
326+
* This is useful to normalize the Schedule
327+
* to make every leaf variable's min to be 0.
328+
*/
329+
class RebaseNode : public IterVarRelationNode {
330+
public:
331+
/*! \brief The parent domain */
332+
IterVar parent;
333+
/*! \brief The rebased domain */
334+
IterVar rebased;
335+
336+
void VisitAttrs(AttrVisitor* v) final {
337+
v->Visit("parent", &parent);
338+
v->Visit("rebased", &rebased);
339+
}
340+
341+
static IterVarRelation make(IterVar parent, IterVar rebased);
342+
343+
static constexpr const char* _type_key = "Rebase";
344+
TVM_DECLARE_NODE_TYPE_INFO(RebaseNode);
345+
};
346+
347+
311348
// implementations
312349
inline const StageNode* Stage::operator->() const {
313350
return static_cast<const StageNode*>(node_.get());

include/tvm/schedule_pass.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,15 @@ namespace schedule {
2424
*/
2525
Map<IterVar, Range> InferBound(Schedule sch);
2626

27+
/*!
28+
* \brief Schedule s' dependent operations.
29+
*
30+
* \param s The schedule to be realized
31+
* \param dom_map The domain of each iter vars.
32+
* \return the result Stmt
33+
*/
34+
Stmt ScheduleOps(Schedule s, Map<IterVar, Range> dom_map);
35+
2736
} // namespace schedule
2837
} // namespace tvm
2938
#endif // TVM_SCHEDULE_PASS_H_

python/tvm/api.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -212,51 +212,51 @@ def IterVar(dom=None, name=None, thread_tag=''):
212212
return _api_internal._IterVar(dom, name, thread_tag)
213213

214214

215-
def sum(expr, rdom):
216-
"""Create a sum expression over rdom
215+
def sum(expr, axis):
216+
"""Create a sum expression over axis
217217
218218
Parameters
219219
----------
220220
expr : Expr
221221
The source expression.
222222
223-
rdom : RDomain
224-
The reduction domainx
223+
axis : IterVar
224+
The reduction IterVar axis
225225
"""
226-
rdom = rdom if isinstance(rdom, list) else [rdom]
227-
x = _make.Reduce("Add", expr, rdom)
226+
axis = axis if isinstance(axis, list) else [axis]
227+
x = _make.Reduce("Add", expr, axis)
228228
return x
229229

230230

231-
def min(expr, rdom):
232-
"""Create a min expression over rdom
231+
def min(expr, axis):
232+
"""Create a min expression over axis
233233
234234
Parameters
235235
----------
236236
expr : Expr
237237
The source expression.
238238
239-
rdom : RDomain
240-
The reduction domainx
239+
axis : IterVar
240+
The reduction IterVar axis
241241
"""
242-
rdom = rdom if isinstance(rdom, list) else [rdom]
243-
x = _make.Reduce("Min", expr, rdom)
242+
axis = axis if isinstance(axis, list) else [axis]
243+
x = _make.Reduce("Min", expr, axis)
244244
return x
245245

246246

247-
def max(expr, rdom):
248-
"""Create a min expression over rdom
247+
def max(expr, axis):
248+
"""Create a max expression over axis
249249
250250
Parameters
251251
----------
252252
expr : Expr
253253
The source expression.
254254
255-
rdom : RDomain
256-
The reduction domainx
255+
axis : IterVar
256+
The reduction IterVar axis
257257
"""
258-
rdom = rdom if isinstance(rdom, list) else [rdom]
259-
x = _make.Reduce("Max", expr, rdom)
258+
axis = axis if isinstance(axis, list) else [axis]
259+
x = _make.Reduce("Max", expr, axis)
260260
return x
261261

262262

python/tvm/build.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,10 @@ def build(sch,
6262

6363
# lowering
6464
bounds = schedule.InferBound(sch)
65-
stmt = ir_pass.ScheduleOps(sch, bounds)
65+
stmt = schedule.ScheduleOps(sch, bounds)
6666
stmt = ir_pass.StorageFlatten(stmt, binds)
6767
stmt = ir_pass.Simplify(stmt)
68+
print(stmt)
6869
fapi = codegen.MakeAPI(stmt, name, arg_list, len(arg_list))
6970
fsplits = codegen.SplitHostDevice(fapi)
7071

@@ -73,7 +74,8 @@ def build(sch,
7374
for i, f in enumerate(fsplits):
7475
t = target if i >= 1 else "c"
7576
record_codes.append(codegen.CompileToC(f, output_ssa, t))
76-
77+
for c in record_codes:
78+
print(c)
7779
if target == "cuda":
7880
ret = codegen.BuildNVRTC(fsplits, "stackvm")
7981
elif target == "opencl":

python/tvm/schedule.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,14 @@ def __getitem__(self, k):
3333
raise ValueError("Cannot find the operation %s in schedule" % (str(k)))
3434
return self.stage_map[k]
3535

36+
def normalize(self):
37+
"""Build a normalized schedule.
38+
39+
Insert the necessary rebases so that certain iter vars start from 0.
40+
This is needed before bound inference and the follow-up steps.
41+
"""
42+
_api_internal._ScheduleNormalize(self)
43+
3644
@register_node
3745
class Stage(NodeBase):
3846
"""A Stage represents schedule for one operation."""

src/api/api_lang.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,4 +253,10 @@ TVM_REGISTER_API(_StageTile)
253253
*ret = Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
254254
});
255255

256+
TVM_REGISTER_API(_ScheduleNormalize)
257+
.set_body([](TVMArgs args, TVMRetValue* ret) {
258+
args[0].operator Schedule()
259+
.normalize();
260+
});
261+
256262
} // namespace tvm

0 commit comments

Comments
 (0)