Skip to content

Commit 3931401

Browse files
tqchenicemelon
authored andcommitted
[LANG] Change Schedule->Stage, Use Schedule for global schedule (#8)
* [LANG] Change Schedule->Stage, Use Schedule for global schedule * add numpy as dep * add numpy installation, temporary disable osx
1 parent e953e2e commit 3931401

File tree

16 files changed

+383
-268
lines changed

16 files changed

+383
-268
lines changed

.travis.yml

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,11 @@ language: cpp
44

55
os:
66
- linux
7-
- osx
7+
# - osx
88

99
env:
1010
# code analysis
11-
- TASK=lint
12-
- TASK=cpp_test
13-
- TASK=python_test
11+
- TASK=all_test
1412

1513
branches:
1614
only:
@@ -35,6 +33,7 @@ addons:
3533
- g++-4.8
3634
- python-numpy
3735
- python-nose
36+
- python3-numpy
3837
- python3-dev
3938
- python3-nose
4039
- graphviz

HalideIR

Submodule HalideIR updated from 1ec478b to 98e8df5

include/tvm/schedule.h

Lines changed: 112 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
namespace tvm {
1414

15+
// Node container for Stage
16+
class StageNode;
1517
// Node container for Schedule
1618
class ScheduleNode;
1719
// Node container for IterVarRelation
@@ -25,46 +27,48 @@ enum AttachType : int {
2527
kScope = 3
2628
};
2729

28-
/*! \brief schedule container */
29-
class Schedule : public NodeRef {
30+
/*! \brief Stage, contains scheduling for a stage of computation. */
31+
class Stage : public NodeRef {
3032
public:
31-
Schedule() {}
32-
explicit Schedule(std::shared_ptr<Node> n) : NodeRef(n) {}
33+
Stage() {}
34+
explicit Stage(std::shared_ptr<Node> n) : NodeRef(n) {}
3335
/*!
3436
* \brief create a new schedule for op.
3537
* \param op The operator in the schedule
36-
* \param scope The scope of the schedule
3738
*/
38-
Schedule(Operation op, std::string scope);
39+
explicit Stage(Operation op);
3940
/*!
4041
* \brief access the internal node container
4142
* \return the pointer to the internal node container
4243
*/
43-
inline const ScheduleNode* operator->() const;
44+
inline const StageNode* operator->() const;
4445
/*!
4546
* \brief access the internal node container
4647
* \return the pointer to the internal node container
4748
*/
48-
inline ScheduleNode* operator->();
49+
inline StageNode* operator->();
50+
/*!
51+
* \brief set the memory scope of the stage
52+
* \param scope The memory scope.
53+
*/
54+
Stage& set_scope(std::string scope); // NOLINT(*)
4955
/*!
5056
* \brief specify the schedule to be computed at the parent schedule's scope.
5157
* \param parent The parent schedule.
5258
* \param scope The iteration point to carry the schedule.
5359
* \return reference to self.
5460
*/
55-
Schedule& compute_at(Schedule parent, IterVar scope); // NOLINT(*)
61+
Stage& compute_at(Stage parent, IterVar scope); // NOLINT(*)
5662
/*!
5763
* \brief Compute the function inline, attach it at parent.
58-
* \param parent The parent schedule to be attached to.
5964
* \return reference to self.
6065
*/
61-
Schedule& compute_inline(Schedule parent); // NOLINT(*)
66+
Stage& compute_inline(); // NOLINT(*)
6267
/*!
6368
* \brief Compute the function at root, attach it to its parent.
64-
* \param parent The parent schedule to be attached to.
6569
* \return reference to self.
6670
*/
67-
Schedule& compute_root(Schedule parent); // NOLINT(*)
71+
Stage& compute_root(); // NOLINT(*)
6872
/*!
6973
* \brief Split the parent by factor, generate
7074
* \param parent The parent iteration domain.
@@ -73,7 +77,7 @@ class Schedule : public NodeRef {
7377
* \param factor The split factor of the loop.
7478
* \return reference to self.
7579
*/
76-
Schedule& split(IterVar parent, IterVar* p_outer, IterVar* p_inner, Expr factor); // NOLINT(*)
80+
Stage& split(IterVar parent, IterVar* p_outer, IterVar* p_inner, Expr factor); // NOLINT(*)
7781
/*!
7882
* \brief Split the iteration with a given outer domain,
7983
* the outer domain must have a thread-tag.
@@ -85,24 +89,74 @@ class Schedule : public NodeRef {
8589
* factor must be provided such that factor * outer.extent >= parent.extent.
8690
* \return reference to self.
8791
*/
88-
Schedule& split(IterVar parent, IterVar outer, IterVar* p_inner, Expr factor = Expr()); // NOLINT(*)
92+
Stage& split(IterVar parent, IterVar outer, IterVar* p_inner, Expr factor = Expr()); // NOLINT(*)
8993
/*!
9094
* \brief Fuse the inner outer domain to the target
9195
* \param inner The inner domain to be fused
9296
* \param outer The outer domain to be fused.
9397
* \param p_target The result target domain.
9498
* \return reference to self.
9599
*/
96-
Schedule& fuse(IterVar inner, IterVar outer, IterVar* p_target); // NOLINT(*)
100+
Stage& fuse(IterVar inner, IterVar outer, IterVar* p_target); // NOLINT(*)
97101
/*!
98102
* \brief Reorder the iteration
99103
* \param order The order of iteration variable.
100104
* \return reference to self.
101105
*/
102-
Schedule& reorder(const Array<IterVar>& order); // NOLINT(*)
103-
Schedule& tile(IterVar x_parent, IterVar y_parent, IterVar* p_x_outer,
104-
IterVar* p_y_outer, IterVar* p_x_inner, IterVar* p_y_inner,
105-
Expr x_factor, Expr y_factor); // NOLINT(*)
106+
Stage& reorder(const Array<IterVar>& order); // NOLINT(*)
107+
/*!
108+
* \brief Perform tiling on two dimensions
109+
* The final loop order from outmost to inner most are
110+
* [x_outer, y_outer, x_inner, y_inner]
111+
*
112+
* \param x_parent The original x dimension
113+
* \param y_parent The original y dimension
114+
* \param p_x_outer Outer axis of x dimension
115+
* \param p_y_outer Outer axis of y dimension
116+
* \param p_x_inner Inner axis of x dimension
117+
* \param p_y_inner Inner axis of y dimension
118+
* \param x_factor The stride factor on x axis
119+
* \param y_factor The stride factor on y axis
120+
* \return reference to self.
121+
*/
122+
Stage& tile(IterVar x_parent, IterVar y_parent, // NOLINT(*)
123+
IterVar* p_x_outer, IterVar* p_y_outer,
124+
IterVar* p_x_inner, IterVar* p_y_inner,
125+
Expr x_factor, Expr y_factor);
126+
};
127+
128+
/*!
129+
* \brief Global schedule container
130+
* For operations and all the operations they depend on.
131+
* The schedule per Operation is named as stage.
132+
*/
133+
class Schedule : public NodeRef {
134+
public:
135+
Schedule() {}
136+
explicit Schedule(std::shared_ptr<Node> n) : NodeRef(n) {}
137+
/*!
138+
* \brief construct schedule for array of ops(and their dependencies).
139+
* \param ops The ops to be scheduled.
140+
*/
141+
explicit Schedule(Array<Operation> ops);
142+
/*!
143+
* \brief Get the stage corresponds to the op
144+
* \param op The operation.
145+
*/
146+
Stage operator[](const Operation& op);
147+
/*!
148+
* \brief Short hand for getting the stage of tensor's operation.
149+
* \param tensor The tensor
150+
* \return The stage corresponding to the tensor's op
151+
*/
152+
Stage operator[](const Tensor& tensor) {
153+
return this->operator[](tensor->op);
154+
}
155+
/*!
156+
* \brief access the internal node container
157+
* \return the pointer to the internal node container
158+
*/
159+
inline const ScheduleNode* operator->() const;
106160
};
107161

108162
/*!
@@ -135,11 +189,11 @@ class IterVarRelation : public NodeRef {
135189
*
136190
* The relations connects the IterVars in the graph.
137191
*/
138-
class ScheduleNode : public Node {
192+
class StageNode : public Node {
139193
public:
140194
/*! \brief The operation to be scheduled */
141195
Operation op;
142-
/*! \brief The thread scope level of the schedule */
196+
/*! \brief The thread scope level of the stage */
143197
std::string scope;
144198
/*! \brief All the nodes in the iter var */
145199
Array<IterVar> all_iter_vars;
@@ -152,12 +206,10 @@ class ScheduleNode : public Node {
152206
Array<IterVarRelation> relations;
153207
/*! \brief The attachment type of the schedule */
154208
AttachType attach_type{kNone};
155-
/*!
156-
* \brief The attach point of this schedule.
157-
*/
158-
IterVar attach_parent;
159-
/*! \brief the schedules that this schedule depend on */
160-
Array<Schedule> children;
209+
/*! \brief The attach point of this schedule. */
210+
IterVar attach_ivar;
211+
/*! \brief The stage this node attaches to */
212+
Stage attach_stage;
161213

162214
void VisitAttrs(AttrVisitor* v) final {
163215
v->Visit("scope", &scope);
@@ -166,8 +218,31 @@ class ScheduleNode : public Node {
166218
v->Visit("leaf_iter_vars", &leaf_iter_vars);
167219
v->Visit("relations", &relations);
168220
v->Visit("attach_type", &attach_type);
169-
v->Visit("attach_parent", &attach_parent);
170-
v->Visit("children", &children);
221+
v->Visit("attach_ivar", &attach_ivar);
222+
v->Visit("attach_stage", &attach_stage);
223+
}
224+
225+
static constexpr const char* _type_key = "Stage";
226+
TVM_DECLARE_NODE_TYPE_INFO(StageNode);
227+
};
228+
229+
/*! \brief node container for schedule */
230+
class ScheduleNode : public Node {
231+
public:
232+
/*! \brief The root operations */
233+
Array<Operation> roots;
234+
/*!
235+
* \brief list of all stages for non-placeholder ops
236+
* The stage are ordered in PostDFS order of their op.
237+
*/
238+
Array<Stage> stages;
239+
/*! \brief map of operation to the stages */
240+
Map<Operation, Stage> stage_map;
241+
242+
void VisitAttrs(AttrVisitor* v) final {
243+
v->Visit("roots", &roots);
244+
v->Visit("stages", &stages);
245+
v->Visit("stage_map", &stage_map);
171246
}
172247

173248
static constexpr const char* _type_key = "Schedule";
@@ -234,12 +309,16 @@ class FuseNode : public IterVarRelationNode {
234309
};
235310

236311
// implementations
312+
inline const StageNode* Stage::operator->() const {
313+
return static_cast<const StageNode*>(node_.get());
314+
}
315+
inline StageNode* Stage::operator->() {
316+
return static_cast<StageNode*>(node_.get());
317+
}
318+
237319
inline const ScheduleNode* Schedule::operator->() const {
238320
return static_cast<const ScheduleNode*>(node_.get());
239321
}
240-
inline ScheduleNode* Schedule::operator->() {
241-
return static_cast<ScheduleNode*>(node_.get());
242-
}
243322

244323
inline const IterVarRelationNode* IterVarRelation::operator->() const {
245324
return static_cast<const IterVarRelationNode*>(node_.get());

python/tvm/function.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,8 +174,17 @@ def max(expr, rdom):
174174
return x
175175

176176

177-
def Schedule(tensor, scope="global"):
178-
return _function_internal._Schedule(tensor, scope)
177+
def Schedule(ops):
178+
"""Create a schedule for list of ops
179+
180+
Parameters
181+
----------
182+
ops : list of Operations
183+
The source expression.
184+
"""
185+
if not isinstance(ops, (list, _collections.Array)):
186+
ops = [ops]
187+
return _function_internal._Schedule(ops)
179188

180189

181190
_init_function_module("tvm")

0 commit comments

Comments
 (0)