Skip to content

Commit

Permalink
loop_revisit method for dealing with recursive loops in the IR (#3106)
Browse files Browse the repository at this point in the history
- by default will BUG("IR loop detected") (the old behavior), but allows
  writing visitor passes that can override and deal with it.
  • Loading branch information
Chris Dodd authored Mar 3, 2022
1 parent 369d527 commit df473f7
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 10 deletions.
8 changes: 7 additions & 1 deletion ir/ir-inline.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,24 @@ limitations under the License.
{ Node::traceVisit("Mod post"); v.postorder(this); } \
TEMPLATE INLINE void IR::CLASS TT::apply_visitor_revisit(Modifier &v, const Node *n) const \
{ Node::traceVisit("Mod revisit"); v.revisit(this, n); } \
TEMPLATE INLINE void IR::CLASS TT::apply_visitor_loop_revisit(Modifier &v) const \
{ Node::traceVisit("Mod loop_revisit"); v.loop_revisit(this); } \
TEMPLATE INLINE bool IR::CLASS TT::apply_visitor_preorder(Inspector &v) const \
{ Node::traceVisit("Insp pre"); return v.preorder(this); } \
TEMPLATE INLINE void IR::CLASS TT::apply_visitor_postorder(Inspector &v) const \
{ Node::traceVisit("Insp post"); v.postorder(this); } \
TEMPLATE INLINE void IR::CLASS TT::apply_visitor_revisit(Inspector &v) const \
{ Node::traceVisit("Insp revisit"); v.revisit(this); } \
TEMPLATE INLINE void IR::CLASS TT::apply_visitor_loop_revisit(Inspector &v) const \
{ Node::traceVisit("Insp loop_revisit"); v.loop_revisit(this); } \
TEMPLATE INLINE const IR::Node *IR::CLASS TT::apply_visitor_preorder(Transform &v) \
{ Node::traceVisit("Trans pre"); return v.preorder(this); } \
TEMPLATE INLINE const IR::Node *IR::CLASS TT::apply_visitor_postorder(Transform &v) \
{ Node::traceVisit("Trans post"); return v.postorder(this); } \
TEMPLATE INLINE void IR::CLASS TT::apply_visitor_revisit(Transform &v, const Node *n) const \
{ Node::traceVisit("Trans revisit"); v.revisit(this, n); }
{ Node::traceVisit("Trans revisit"); v.revisit(this, n); } \
TEMPLATE INLINE void IR::CLASS TT::apply_visitor_loop_revisit(Transform &v) const \
{ Node::traceVisit("Trans loop_revisit"); v.loop_revisit(this); }

IRNODE_ALL_TEMPLATES(DEFINE_APPLY_FUNCTIONS, inline)

Expand Down
6 changes: 6 additions & 0 deletions ir/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,15 @@ class Node : public virtual INode {
virtual bool apply_visitor_preorder(Modifier &v);
virtual void apply_visitor_postorder(Modifier &v);
virtual void apply_visitor_revisit(Modifier &v, const Node *n) const;
virtual void apply_visitor_loop_revisit(Modifier &v) const;
virtual bool apply_visitor_preorder(Inspector &v) const;
virtual void apply_visitor_postorder(Inspector &v) const;
virtual void apply_visitor_revisit(Inspector &v) const;
virtual void apply_visitor_loop_revisit(Inspector &v) const;
virtual const Node *apply_visitor_preorder(Transform &v);
virtual const Node *apply_visitor_postorder(Transform &v);
virtual void apply_visitor_revisit(Transform &v, const Node *n) const;
virtual void apply_visitor_loop_revisit(Transform &v) const;
Node &operator=(const Node &) = default;
Node &operator=(Node &&) = default;

Expand Down Expand Up @@ -168,12 +171,15 @@ inline bool equiv(const INode *a, const INode *b) {
bool apply_visitor_preorder(Modifier &v) override; \
void apply_visitor_postorder(Modifier &v) override; \
void apply_visitor_revisit(Modifier &v, const Node *n) const override; \
void apply_visitor_loop_revisit(Modifier &v) const override; \
bool apply_visitor_preorder(Inspector &v) const override; \
void apply_visitor_postorder(Inspector &v) const override; \
void apply_visitor_revisit(Inspector &v) const override; \
void apply_visitor_loop_revisit(Inspector &v) const override; \
const Node *apply_visitor_preorder(Transform &v) override; \
const Node *apply_visitor_postorder(Transform &v) override; \
void apply_visitor_revisit(Transform &v, const Node *n) const override; \
void apply_visitor_loop_revisit(Transform &v) const override; \

/* only define 'apply' for a limited number of classes (those we want to call
* visitors directly on), as defining it and making it virtual would mean that
Expand Down
50 changes: 44 additions & 6 deletions ir/visitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,15 @@ class Visitor::ChangeTracker {
else
++it; } }

/** Determine whether @n is currently being visited and the visitor has not finished
* That is, `start(@n)` has been invoked, and `finish(@n)` has not,
*
* @return true if @n is being visited and has not finished
*/
bool busy(const IR::Node *n) const {
auto it = visited.find(n);
return it != visited.end() && it->second.visit_in_progress; }

/** Determine whether @n has been visited and the visitor has finished
* and we don't want to visit @n again the next time we see it.
* That is, `start(@n)` has been invoked, followed by `finish(@n)`,
Expand Down Expand Up @@ -272,7 +281,11 @@ const IR::Node *Modifier::apply_visitor(const IR::Node *n, const char *name) {
if (ctxt) ctxt->child_name = name;
if (n) {
PushContext local(ctxt, n);
if (visited->done(n)) {
if (visited->busy(n)) {
n->apply_visitor_loop_revisit(*this);
// FIXME -- should have a way of updating the node? Needs to be decided
// by the visitor somehow, but it is tough
} else if (visited->done(n)) {
n->apply_visitor_revisit(*this, visited->result(n));
n = visited->result(n);
} else {
Expand Down Expand Up @@ -301,9 +314,9 @@ const IR::Node *Inspector::apply_visitor(const IR::Node *n, const char *name) {
if (n && !join_flows(n)) {
PushContext local(ctxt, n);
auto vp = visited->emplace(n, info_t{false, visitDagOnce});
if (!vp.second && !vp.first->second.done)
BUG("IR loop detected");
if (!vp.second && vp.first->second.visitOnce) {
if (!vp.second && !vp.first->second.done) {
n->apply_visitor_loop_revisit(*this);
} else if (!vp.second && vp.first->second.visitOnce) {
n->apply_visitor_revisit(*this);
} else {
vp.first->second.done = false;
Expand All @@ -326,7 +339,11 @@ const IR::Node *Transform::apply_visitor(const IR::Node *n, const char *name) {
if (ctxt) ctxt->child_name = name;
if (n) {
PushContext local(ctxt, n);
if (visited->done(n)) {
if (visited->busy(n)) {
n->apply_visitor_loop_revisit(*this);
// FIXME -- should have a way of updating the node? Needs to be decided
// by the visitor somehow, but it is tough
} else if (visited->done(n)) {
n->apply_visitor_revisit(*this, visited->result(n));
n = visited->result(n);
} else {
Expand Down Expand Up @@ -386,9 +403,16 @@ void Inspector::revisit_visited() {
void Modifier::revisit_visited() {
visited->revisit_visited();
}
bool Modifier::visit_in_progress(const IR::Node *n) const {
return visited->busy(n);
}
void Transform::revisit_visited() {
visited->revisit_visited();
}
bool Transform::visit_in_progress(const IR::Node *n) const {
return visited->busy(n);
}


#define DEFINE_VISIT_FUNCTIONS(CLASS, BASE) \
bool Modifier::preorder(IR::CLASS *n) { \
Expand All @@ -397,18 +421,24 @@ void Modifier::postorder(IR::CLASS *n) {
postorder(static_cast<IR::BASE *>(n)); } \
void Modifier::revisit(const IR::CLASS *o, const IR::CLASS *n) { \
revisit(static_cast<const IR::BASE *>(o), static_cast<const IR::BASE *>(n)); } \
void Modifier::loop_revisit(const IR::CLASS *o) { \
loop_revisit(static_cast<const IR::BASE *>(o)); } \
bool Inspector::preorder(const IR::CLASS *n) { \
return preorder(static_cast<const IR::BASE *>(n)); } \
void Inspector::postorder(const IR::CLASS *n) { \
postorder(static_cast<const IR::BASE *>(n)); } \
void Inspector::revisit(const IR::CLASS *n) { \
revisit(static_cast<const IR::BASE *>(n)); } \
void Inspector::loop_revisit(const IR::CLASS *n) { \
loop_revisit(static_cast<const IR::BASE *>(n)); } \
const IR::Node *Transform::preorder(IR::CLASS *n) { \
return preorder(static_cast<IR::BASE *>(n)); } \
const IR::Node *Transform::postorder(IR::CLASS *n) { \
return postorder(static_cast<IR::BASE *>(n)); } \
void Transform::revisit(const IR::CLASS *o, const IR::Node *n) { \
return revisit(static_cast<const IR::BASE *>(o), n); } \
void Transform::loop_revisit(const IR::CLASS *o) { \
return loop_revisit(static_cast<const IR::BASE *>(o)); } \

IRNODE_ALL_SUBCLASSES(DEFINE_VISIT_FUNCTIONS)
#undef DEFINE_VISIT_FUNCTIONS
Expand Down Expand Up @@ -440,9 +470,17 @@ void ControlFlowVisitor::init_join_flows(const IR::Node *root) {
bool ControlFlowVisitor::join_flows(const IR::Node *n) {
if (flow_join_points && flow_join_points->count(n)) {
auto &status = flow_join_points->at(n);
// BUG_CHECK(status.second > 0, "join point reached too many times");
// FIXME -- this means that we calculated the wrong number of parents for a
// join point, and completed the join sooner than we should have. This can
// happen if there are recursive calls in the visitor and the differing order
// between SetupJoinPoints and the main visitor means that the recursion is
// seen differently. Might be possible to fix this by careful use of
// loop_revisit, but if not, we might as well just merge what we have.

// Decrement the number of upstream edges yet to be traversed. If none
// remain, merge and return false to visit this node.
if (!--status.second) {
if (--status.second <= 0) {
flow_merge(*status.first);
return false;
} else if (status.first) {
Expand Down
17 changes: 14 additions & 3 deletions ir/visitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -299,13 +299,16 @@ class Modifier : public virtual Visitor {
virtual bool preorder(IR::Node *) { return true; }
virtual void postorder(IR::Node *) {}
virtual void revisit(const IR::Node *, const IR::Node *) {}
virtual void loop_revisit(const IR::Node *) { BUG("IR loop detected"); }
#define DECLARE_VISIT_FUNCTIONS(CLASS, BASE) \
virtual bool preorder(IR::CLASS *); \
virtual void postorder(IR::CLASS *); \
virtual void revisit(const IR::CLASS *, const IR::CLASS *);
virtual void revisit(const IR::CLASS *, const IR::CLASS *); \
virtual void loop_revisit(const IR::CLASS *);
IRNODE_ALL_SUBCLASSES(DECLARE_VISIT_FUNCTIONS)
#undef DECLARE_VISIT_FUNCTIONS
void revisit_visited();
bool visit_in_progress(const IR::Node *) const;
};

class Inspector : public virtual Visitor {
Expand All @@ -319,13 +322,18 @@ class Inspector : public virtual Visitor {
virtual bool preorder(const IR::Node *) { return true; } // return 'false' to prune
virtual void postorder(const IR::Node *) {}
virtual void revisit(const IR::Node *) {}
virtual void loop_revisit(const IR::Node *) { BUG("IR loop detected"); }
#define DECLARE_VISIT_FUNCTIONS(CLASS, BASE) \
virtual bool preorder(const IR::CLASS *); \
virtual void postorder(const IR::CLASS *); \
virtual void revisit(const IR::CLASS *);
virtual void revisit(const IR::CLASS *); \
virtual void loop_revisit(const IR::CLASS *);
IRNODE_ALL_SUBCLASSES(DECLARE_VISIT_FUNCTIONS)
#undef DECLARE_VISIT_FUNCTIONS
void revisit_visited();
bool visit_in_progress(const IR::Node *n) const {
if (visited->count(n)) return visited->at(n).done;
return false; }
};

class Transform : public virtual Visitor {
Expand All @@ -340,13 +348,16 @@ class Transform : public virtual Visitor {
virtual const IR::Node *preorder(IR::Node *n) {return n;}
virtual const IR::Node *postorder(IR::Node *n) {return n;}
virtual void revisit(const IR::Node *, const IR::Node *) {}
virtual void loop_revisit(const IR::Node *) { BUG("IR loop detected"); }
#define DECLARE_VISIT_FUNCTIONS(CLASS, BASE) \
virtual const IR::Node *preorder(IR::CLASS *); \
virtual const IR::Node *postorder(IR::CLASS *); \
virtual void revisit(const IR::CLASS *, const IR::Node *);
virtual void revisit(const IR::CLASS *, const IR::Node *); \
virtual void loop_revisit(const IR::CLASS *);
IRNODE_ALL_SUBCLASSES(DECLARE_VISIT_FUNCTIONS)
#undef DECLARE_VISIT_FUNCTIONS
void revisit_visited();
bool visit_in_progress(const IR::Node *) const;
// can only be called usefully from a 'preorder' function (directly or indirectly)
void prune() { prune_flag = true; }

Expand Down

0 comments on commit df473f7

Please sign in to comment.