diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index e6eec61a7e9d0..7ae6a54391c3c 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -32,11 +32,12 @@ #include #include +#include #include #include #include #include - +#include namespace tvm { namespace relay { @@ -276,7 +277,9 @@ class MixedModeVisitor : public ::tvm::relay::ExprVisitor { */ class MixedModeMutator : public ::tvm::relay::ExprMutator { public: + MixedModeMutator(bool pre = false) : pre_{pre} {}; Expr VisitExpr(const Expr& expr) final; + virtual Expr DispatchVisitExpr(const Expr& expr); Expr VisitExpr_(const TupleNode* op) final { return Rewrite(op); }; Expr VisitExpr_(const CallNode* call_node) final { return Rewrite(call_node); }; @@ -294,6 +297,7 @@ class MixedModeMutator : public ::tvm::relay::ExprMutator { virtual Expr Rewrite_(const TupleGetItemNode* pre, const Expr& post) { return post; } protected: + bool pre_; /*! \brief Implement Rewrite API by calling ExprMutator's VisitExpr_(op) to get a `post` node with * changed inputs. */ @@ -410,72 +414,82 @@ Expr PostOrderRewrite(const Expr& expr, ExprRewriter* rewriter); */ void PostOrderVisit(const Expr& node, std::function fvisit); +/*! + * \brief A struct to keep info of traversed expr in ExpandDataflow function + */ +struct v_info { + explicit v_info(Expr node_) : node{node_} {} + v_info(Expr node_, bool children_expanded_) + : node{node_}, children_expanded{children_expanded_} {}; + Expr node{}; + bool children_expanded{false}; +}; + /*! * \brief A function to iteratively traverse dataflow regions of a graph * * ExpandDataflow manually manages a stack and performs DFS to determine the processing * order of nodes in an input graph. * - * If it finds a dataflow node (Call, Tuple, TupleGetItem), it checks if the arguments to that node - * need to be processed via fcheck_visited. If so, the function pushes those arguments to the stack - * and continues iteratively to process the top of the stack. When it finds a node that doesn't - * match the dataflow types, or a node who's inputs have all been processed, it visits the current - * leaf via fvisit_leaf. + * By default fexpand_expr implemented in a way that if it finds a dataflow node (Call, Tuple, + * TupleGetItem), it checks if the arguments to that node need to be processed via fcheck_visited. + * If so, the function pushes those arguments to the stack and continues iteratively to process + * the top of the stack. When it finds a node that doesn't match the dataflow types, or a node who's + * inputs have all been processed, it visits the current leaf via fvisit_leaf. * * This function should be used internally to other classes to implement mixed-mode traversals. The * expectation is that fvisit_leaf will perform recursive analysis within mixed-mode traversal if it * hits a non-dataflow node. * - * fcheck_visited and fvisit_leaf are templated to encourage compiler inlining. + * fcheck_visited, fvisit_leaf and fexpand_expr are templated to encourage reusing. */ -template -void ExpandDataflow(Expr expr, FCheckVisited fcheck_visited, FVisitLeaf fvisit_leaf) { - std::stack> stack; +template +void ExpandDataflow(Expr expr, FCheckVisited fcheck_visited, FVisitLeaf fvisit_leaf, + FExpandExpr fexpand_expr) { + std::deque stack; auto fpush_to_stack = [&fcheck_visited, &stack](const Expr& expr) { - // The second state of the stack indicate whether the child has been - // expanded in the pre-order. - // NOTE: function will be inlined. if (!fcheck_visited(expr)) { - stack.push({expr, false}); + stack.push_front(std::move(v_info(expr))); } }; + fpush_to_stack(expr); while (stack.size() > 0) { - auto node = stack.top().first; - if (fcheck_visited(node)) { - // if this node was visited through another path - // after being added to the stack ignore it. - stack.pop(); - } else if (stack.top().second) { - // all the children have already been expanded. - // we can just run post order visit on it. - fvisit_leaf(node); - stack.pop(); - } else if (const CallNode* op = node.as()) { - // mark expanded = true - stack.top().second = true; - // push the children to the stack in reverse order - // to match recursive processing order + v_info* front = &stack.front(); + if (fcheck_visited(front->node)) { + stack.pop_front(); + } else if (front->children_expanded) { + fvisit_leaf(front->node); + // TODO(d-smirnov): this is for compatibility with current implementation of MixedModeVisitor + stack.pop_front(); + } else { + front->children_expanded = true; + for (auto e : fexpand_expr(front->node)) { + fpush_to_stack(e); + } + } + } +} + +template +void ExpandDataflow(Expr expr, FCheckVisited fcheck_visited, FVisitLeaf fvisit_leaf) { + auto fexpand_expr = [](const Expr& expr) { + std::vector result; + if (const CallNode* op = expr.as()) { for (auto it = op->args.rbegin(); it != op->args.rend(); ++it) { - fpush_to_stack(*it); + result.push_back(*it); } - fpush_to_stack(op->op); - } else if (const TupleNode* op = node.as()) { - stack.top().second = true; - // push the children to the stack in reverse order - // to match recursive processing order + result.push_back(op->op); + } else if (const TupleNode* op = expr.as()) { for (auto it = op->fields.rbegin(); it != op->fields.rend(); ++it) { - fpush_to_stack(*it); + result.push_back(*it); } - } else if (const TupleGetItemNode* op = node.as()) { - stack.top().second = true; - fpush_to_stack(op->tuple); - } else { - // No need to expand the children directly run visit. - fvisit_leaf(node); - stack.pop(); + } else if (const TupleGetItemNode* op = expr.as()) { + result.push_back(op->tuple); } - } + return std::move(result); + }; + ExpandDataflow(expr, fcheck_visited, fvisit_leaf, fexpand_expr); } void ExpandANormalForm(const LetNode* op, std::function pre_visit, diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index 31f98ce4d2705..4fc03039466c7 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -236,6 +236,34 @@ bool RelayTextPrinter::AlwaysInline(const Expr& expr) { expr.as() || expr.as(); } +Doc RelayTextPrinter::VisitLeaf(const Expr& expr) { + if (!CheckVisited(expr)) { + Doc result = ExprFunctor::VisitExpr(expr); + // Add if not added after visiting + if (!CheckVisited(expr)) { + memo_[expr] = result; + } else { + result_memo_[expr] = result; + } + return result; + } + return memo_[expr]; +} + +bool RelayTextPrinter::CheckVisited(const Expr& expr) { return (memo_.count(expr)); } + +Doc RelayTextPrinter::VisitExpr(const Expr& expr) { + auto fcheck_visited = [this](const Expr& expr) { return this->CheckVisited(expr); }; + auto fvisit_leaf = [this](const Expr& expr) { return this->VisitLeaf(expr); }; + + if (fcheck_visited(expr)) { + return memo_[expr]; + } else { + ExpandDataflow(expr, fcheck_visited, fvisit_leaf); + return memo_[expr]; + } +} + //------------------------------------ // Overload of Expr printing functions //------------------------------------ @@ -252,9 +280,6 @@ Doc RelayTextPrinter::PrintExpr(const Expr& expr, bool meta, bool try_inline, bo inline_expr |= IsUnique(expr); } - auto it = memo_.find(expr); - if (it != memo_.end()) return it->second; - Doc printed_expr; if (meta) { @@ -277,13 +302,19 @@ Doc RelayTextPrinter::PrintExpr(const Expr& expr, bool meta, bool try_inline, bo if (expr.as()) { // This is our first time visiting the var and we hit the VarNode case // in the visitor. Thus the variable is free. - doc_stack_.back() << "free_var " << printed_expr << ";" << Doc::NewLine(); + if (var_memo_.insert(expr).second && result_memo_.count(expr)) { + doc_stack_.back() << "free_var " << result_memo_[expr] << ";" << Doc::NewLine(); + } // Memoization is done in AllocVar. return memo_[expr]; } else if (inline_expr) { memo_[expr] = printed_expr; return printed_expr; } else { + // Already exists. Reuse + if (!var_memo_.insert(expr).second) { + return memo_[expr]; + } Doc temp_var = AllocTemp(); memo_[expr] = temp_var; doc_stack_.back() << temp_var << " = " << printed_expr << ";" << Doc::NewLine(); diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index 90e46c5624fad..a4ce2f4994c5a 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -37,6 +37,7 @@ #include #include +#include #include #include "../ir/attr_functor.h" @@ -60,6 +61,9 @@ class RelayTextPrinter : public ExprFunctor, explicit RelayTextPrinter(bool show_meta_data, TextMetaDataContext* meta, runtime::TypedPackedFunc annotate) : show_meta_data_(show_meta_data), annotate_(annotate), meta_(meta) {} + Doc VisitExpr(const Expr& expr) override; + virtual Doc VisitLeaf(const Expr& expr); + virtual bool CheckVisited(const Expr& expr); /*! * \brief Print additional info about expr in comment. @@ -170,6 +174,10 @@ class RelayTextPrinter : public ExprFunctor, runtime::TypedPackedFunc annotate_; /*! \brief Stack of docs to implement scoped GNFing. */ std::vector doc_stack_{}; + /*! \brief Set for introduced vars */ + std::unordered_set var_memo_; + /*! \brief Map for result and memo_ diffs for visited expression */ + std::unordered_map result_memo_; /*! \brief Map from Expr to Doc */ std::unordered_map memo_; /*! \brief Map from Type to Doc */ diff --git a/src/relay/analysis/dependency_graph.cc b/src/relay/analysis/dependency_graph.cc index 3a4fb59475a4c..66ff8e684115c 100644 --- a/src/relay/analysis/dependency_graph.cc +++ b/src/relay/analysis/dependency_graph.cc @@ -32,7 +32,7 @@ namespace tvm { namespace relay { // Creator of DependencyGraph -class DependencyGraph::Creator : private ExprFunctor { +class DependencyGraph::Creator : private MixedModeVisitor { public: explicit Creator(support::Arena* arena) : arena_(arena) {} @@ -73,13 +73,13 @@ class DependencyGraph::Creator : private ExprFunctor { return ret; } - void VisitExpr(const Expr& e) final { + void VisitLeaf(const Expr& e) override { if (visited_.count(e) == 0) { if (graph_.expr_node.count(e) == 0) { graph_.expr_node[e] = NewNode(false); } visited_.insert(e); - ExprFunctor::VisitExpr(e); + MixedModeVisitor::VisitLeaf(e); graph_.post_dfs_order.push_back(graph_.expr_node[e]); } }