1212
1313namespace tvm {
1414
15+ // Node container for Stage
16+ class StageNode ;
1517// Node container for Schedule
1618class ScheduleNode ;
1719// Node container for IterVarRelation
@@ -25,46 +27,48 @@ enum AttachType : int {
2527 kScope = 3
2628};
2729
28- /* ! \brief schedule container */
29- class Schedule : public NodeRef {
30+ /* ! \brief Stage, contains scheduling for a stage of computation. */
31+ class Stage : public NodeRef {
3032 public:
31- Schedule () {}
32- explicit Schedule (std::shared_ptr<Node> n) : NodeRef(n) {}
33+ Stage () {}
34+ explicit Stage (std::shared_ptr<Node> n) : NodeRef(n) {}
3335 /* !
3436 * \brief create a new schedule for op.
3537 * \param op The operator in the schedule
36- * \param scope The scope of the schedule
3738 */
38- Schedule (Operation op, std::string scope );
39+ explicit Stage (Operation op);
3940 /* !
4041 * \brief access the internal node container
4142 * \return the pointer to the internal node container
4243 */
43- inline const ScheduleNode * operator ->() const ;
44+ inline const StageNode * operator ->() const ;
4445 /* !
4546 * \brief access the internal node container
4647 * \return the pointer to the internal node container
4748 */
48- inline ScheduleNode* operator ->();
49+ inline StageNode* operator ->();
50+ /* !
51+ * \brief set the memory scope of the stage
52+ * \param scope The memory scope.
53+ */
54+ Stage& set_scope (std::string scope); // NOLINT(*)
4955 /* !
5056 * \brief specify the schedule to be computed at the parent schedule's scope.
5157 * \param parent The parent schedule.
5258 * \param scope The iteration point to carry the schedule.
5359 * \return reference to self.
5460 */
55- Schedule & compute_at (Schedule parent, IterVar scope); // NOLINT(*)
61+ Stage & compute_at (Stage parent, IterVar scope); // NOLINT(*)
5662 /* !
5763 * \brief Compute the function inline, attach it at parent.
58- * \param parent The parent schedule to be attached to.
5964 * \return reference to self.
6065 */
61- Schedule & compute_inline (Schedule parent ); // NOLINT(*)
66+ Stage & compute_inline (); // NOLINT(*)
6267 /* !
6368 * \brief Compute the function at root, attach it to its parent.
64- * \param parent The parent schedule to be attached to.
6569 * \return reference to self.
6670 */
67- Schedule & compute_root (Schedule parent ); // NOLINT(*)
71+ Stage & compute_root (); // NOLINT(*)
6872 /* !
6973 * \brief Split the parent by factor, generate
7074 * \param parent The parent iteration domain.
@@ -73,7 +77,7 @@ class Schedule : public NodeRef {
7377 * \param factor The split factor of the loop.
7478 * \return reference to self.
7579 */
76- Schedule & split (IterVar parent, IterVar* p_outer, IterVar* p_inner, Expr factor); // NOLINT(*)
80+ Stage & split (IterVar parent, IterVar* p_outer, IterVar* p_inner, Expr factor); // NOLINT(*)
7781 /* !
7882 * \brief Split the iteration with a given outer domain,
7983 * the outer domain must have a thread-tag.
@@ -85,24 +89,74 @@ class Schedule : public NodeRef {
8589 * factor must be provided such that factor * outer.extent >= parent.extent.
8690 * \return reference to self.
8791 */
88- Schedule & split (IterVar parent, IterVar outer, IterVar* p_inner, Expr factor = Expr()); // NOLINT(*)
92+ Stage & split (IterVar parent, IterVar outer, IterVar* p_inner, Expr factor = Expr()); // NOLINT(*)
8993 /* !
9094 * \brief Fuse the inner outer domain to the target
9195 * \param inner The inner domain to be fused
9296 * \param outer The outer domain to be fused.
9397 * \param p_target The result target domain.
9498 * \return reference to self.
9599 */
96- Schedule & fuse (IterVar inner, IterVar outer, IterVar* p_target); // NOLINT(*)
100+ Stage & fuse (IterVar inner, IterVar outer, IterVar* p_target); // NOLINT(*)
97101 /* !
98102 * \brief Reorder the iteration
99103 * \param order The order of iteration variable.
100104 * \return reference to self.
101105 */
102- Schedule& reorder (const Array<IterVar>& order); // NOLINT(*)
103- Schedule& tile (IterVar x_parent, IterVar y_parent, IterVar* p_x_outer,
104- IterVar* p_y_outer, IterVar* p_x_inner, IterVar* p_y_inner,
105- Expr x_factor, Expr y_factor); // NOLINT(*)
106+ Stage& reorder (const Array<IterVar>& order); // NOLINT(*)
107+ /* !
108+ * \brief Perform tiling on two dimensions
109+ * The final loop order from outmost to inner most are
110+ * [x_outer, y_outer, x_inner, y_inner]
111+ *
112+ * \param x_parent The original x dimension
113+ * \param y_parent The original y dimension
114+ * \param p_x_outer Outer axis of x dimension
115+ * \param p_y_outer Outer axis of y dimension
116+ * \param p_x_inner Inner axis of x dimension
117+ * \param p_y_inner Inner axis of y dimension
118+ * \param x_factor The stride factor on x axis
119+ * \param y_factor The stride factor on y axis
120+ * \return reference to self.
121+ */
122+ Stage& tile (IterVar x_parent, IterVar y_parent, // NOLINT(*)
123+ IterVar* p_x_outer, IterVar* p_y_outer,
124+ IterVar* p_x_inner, IterVar* p_y_inner,
125+ Expr x_factor, Expr y_factor);
126+ };
127+
128+ /* !
129+ * \brief Global schedule container
130+ * For operations and all the operations they depend on.
131+ * The schedule per Operation is named as stage.
132+ */
133+ class Schedule : public NodeRef {
134+ public:
135+ Schedule () {}
136+ explicit Schedule (std::shared_ptr<Node> n) : NodeRef(n) {}
137+ /* !
138+ * \brief construct schedule for array of ops(and their dependencies).
139+ * \param ops The ops to be scheduled.
140+ */
141+ explicit Schedule (Array<Operation> ops);
142+ /* !
143+ * \brief Get the stage corresponds to the op
144+ * \param op The operation.
145+ */
146+ Stage operator [](const Operation& op);
147+ /* !
148+ * \brief Short hand for getting the stage of tensor's operation.
149+ * \param tensor The tensor
150+ * \return The stage corresponding to the tensor's op
151+ */
152+ Stage operator [](const Tensor& tensor) {
153+ return this ->operator [](tensor->op );
154+ }
155+ /* !
156+ * \brief access the internal node container
157+ * \return the pointer to the internal node container
158+ */
159+ inline const ScheduleNode* operator ->() const ;
106160};
107161
108162/* !
@@ -135,11 +189,11 @@ class IterVarRelation : public NodeRef {
135189 *
136190 * The relations connects the IterVars in the graph.
137191 */
138- class ScheduleNode : public Node {
192+ class StageNode : public Node {
139193 public:
140194 /* ! \brief The operation to be scheduled */
141195 Operation op;
142- /* ! \brief The thread scope level of the schedule */
196+ /* ! \brief The thread scope level of the stage */
143197 std::string scope;
144198 /* ! \brief All the nodes in the iter var */
145199 Array<IterVar> all_iter_vars;
@@ -152,12 +206,10 @@ class ScheduleNode : public Node {
152206 Array<IterVarRelation> relations;
153207 /* ! \brief The attachment type of the schedule */
154208 AttachType attach_type{kNone };
155- /* !
156- * \brief The attach point of this schedule.
157- */
158- IterVar attach_parent;
159- /* ! \brief the schedules that this schedule depend on */
160- Array<Schedule> children;
209+ /* ! \brief The attach point of this schedule. */
210+ IterVar attach_ivar;
211+ /* ! \brief The stage this node attaches to */
212+ Stage attach_stage;
161213
162214 void VisitAttrs (AttrVisitor* v) final {
163215 v->Visit (" scope" , &scope);
@@ -166,8 +218,31 @@ class ScheduleNode : public Node {
166218 v->Visit (" leaf_iter_vars" , &leaf_iter_vars);
167219 v->Visit (" relations" , &relations);
168220 v->Visit (" attach_type" , &attach_type);
169- v->Visit (" attach_parent" , &attach_parent);
170- v->Visit (" children" , &children);
221+ v->Visit (" attach_ivar" , &attach_ivar);
222+ v->Visit (" attach_stage" , &attach_stage);
223+ }
224+
225+ static constexpr const char * _type_key = " Stage" ;
226+ TVM_DECLARE_NODE_TYPE_INFO (StageNode);
227+ };
228+
229+ /* ! \brief node container for schedule */
230+ class ScheduleNode : public Node {
231+ public:
232+ /* ! \brief The root operations */
233+ Array<Operation> roots;
234+ /* !
235+ * \brief list of all stages for non-placeholder ops
236+ * The stage are ordered in PostDFS order of their op.
237+ */
238+ Array<Stage> stages;
239+ /* ! \brief map of operation to the stages */
240+ Map<Operation, Stage> stage_map;
241+
242+ void VisitAttrs (AttrVisitor* v) final {
243+ v->Visit (" roots" , &roots);
244+ v->Visit (" stages" , &stages);
245+ v->Visit (" stage_map" , &stage_map);
171246 }
172247
173248 static constexpr const char * _type_key = " Schedule" ;
@@ -234,12 +309,16 @@ class FuseNode : public IterVarRelationNode {
234309};
235310
236311// implementations
312+ inline const StageNode* Stage::operator ->() const {
313+ return static_cast <const StageNode*>(node_.get ());
314+ }
315+ inline StageNode* Stage::operator ->() {
316+ return static_cast <StageNode*>(node_.get ());
317+ }
318+
237319inline const ScheduleNode* Schedule::operator ->() const {
238320 return static_cast <const ScheduleNode*>(node_.get ());
239321}
240- inline ScheduleNode* Schedule::operator ->() {
241- return static_cast <ScheduleNode*>(node_.get ());
242- }
243322
244323inline const IterVarRelationNode* IterVarRelation::operator ->() const {
245324 return static_cast <const IterVarRelationNode*>(node_.get ());
0 commit comments