Skip to content

Commit

Permalink
RelayTextPrinter is now non-recursive. ExpandDataflow refactored
Browse files Browse the repository at this point in the history
RelayTextPrinter is now non-recursive to allow printing larger
graphs. ExpandDataflow is generalised to have separate node expander.

Change-Id: Id5a3a470fbc8b90822502fbc8d24d534df1ea355
  • Loading branch information
d-smirnov committed Apr 11, 2021
1 parent 461d06e commit b3ef0a9
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 51 deletions.
102 changes: 58 additions & 44 deletions include/tvm/relay/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,12 @@
#include <tvm/relay/function.h>
#include <tvm/relay/op.h>

#include <deque>
#include <stack>
#include <string>
#include <unordered_map>
#include <utility>

#include <vector>
namespace tvm {
namespace relay {

Expand Down Expand Up @@ -276,7 +277,9 @@ class MixedModeVisitor : public ::tvm::relay::ExprVisitor {
*/
class MixedModeMutator : public ::tvm::relay::ExprMutator {
public:
MixedModeMutator(bool pre = false) : pre_{pre} {};
Expr VisitExpr(const Expr& expr) final;

virtual Expr DispatchVisitExpr(const Expr& expr);
Expr VisitExpr_(const TupleNode* op) final { return Rewrite(op); };
Expr VisitExpr_(const CallNode* call_node) final { return Rewrite(call_node); };
Expand All @@ -294,6 +297,7 @@ class MixedModeMutator : public ::tvm::relay::ExprMutator {
virtual Expr Rewrite_(const TupleGetItemNode* pre, const Expr& post) { return post; }

protected:
bool pre_;
/*! \brief Implement Rewrite API by calling ExprMutator's VisitExpr_(op) to get a `post` node with
* changed inputs.
*/
Expand Down Expand Up @@ -410,72 +414,82 @@ Expr PostOrderRewrite(const Expr& expr, ExprRewriter* rewriter);
*/
void PostOrderVisit(const Expr& node, std::function<void(const Expr&)> fvisit);

/*!
* \brief A struct to keep info of traversed expr in ExpandDataflow function
*/
struct v_info {
explicit v_info(Expr node_) : node{node_} {}
v_info(Expr node_, bool children_expanded_)
: node{node_}, children_expanded{children_expanded_} {};
Expr node{};
bool children_expanded{false};
};

/*!
* \brief A function to iteratively traverse dataflow regions of a graph
*
* ExpandDataflow manually manages a stack and performs DFS to determine the processing
* order of nodes in an input graph.
*
* If it finds a dataflow node (Call, Tuple, TupleGetItem), it checks if the arguments to that node
* need to be processed via fcheck_visited. If so, the function pushes those arguments to the stack
* and continues iteratively to process the top of the stack. When it finds a node that doesn't
* match the dataflow types, or a node who's inputs have all been processed, it visits the current
* leaf via fvisit_leaf.
* By default fexpand_expr implemented in a way that if it finds a dataflow node (Call, Tuple,
* TupleGetItem), it checks if the arguments to that node need to be processed via fcheck_visited.
* If so, the function pushes those arguments to the stack and continues iteratively to process
* the top of the stack. When it finds a node that doesn't match the dataflow types, or a node who's
* inputs have all been processed, it visits the current leaf via fvisit_leaf.
*
* This function should be used internally to other classes to implement mixed-mode traversals. The
* expectation is that fvisit_leaf will perform recursive analysis within mixed-mode traversal if it
* hits a non-dataflow node.
*
* fcheck_visited and fvisit_leaf are templated to encourage compiler inlining.
* fcheck_visited, fvisit_leaf and fexpand_expr are templated to encourage reusing.
*/
template <typename FCheckVisited, typename FVisitLeaf>
void ExpandDataflow(Expr expr, FCheckVisited fcheck_visited, FVisitLeaf fvisit_leaf) {
std::stack<std::pair<Expr, bool>> stack;
template <typename FCheckVisited, typename FVisitLeaf, typename FExpandExpr>
void ExpandDataflow(Expr expr, FCheckVisited fcheck_visited, FVisitLeaf fvisit_leaf,
FExpandExpr fexpand_expr) {
std::deque<v_info> stack;
auto fpush_to_stack = [&fcheck_visited, &stack](const Expr& expr) {
// The second state of the stack indicate whether the child has been
// expanded in the pre-order.
// NOTE: function will be inlined.
if (!fcheck_visited(expr)) {
stack.push({expr, false});
stack.push_front(std::move(v_info(expr)));
}
};

fpush_to_stack(expr);
while (stack.size() > 0) {
auto node = stack.top().first;
if (fcheck_visited(node)) {
// if this node was visited through another path
// after being added to the stack ignore it.
stack.pop();
} else if (stack.top().second) {
// all the children have already been expanded.
// we can just run post order visit on it.
fvisit_leaf(node);
stack.pop();
} else if (const CallNode* op = node.as<CallNode>()) {
// mark expanded = true
stack.top().second = true;
// push the children to the stack in reverse order
// to match recursive processing order
v_info* front = &stack.front();
if (fcheck_visited(front->node)) {
stack.pop_front();
} else if (front->children_expanded) {
fvisit_leaf(front->node);
// TODO(d-smirnov): this is for compatibility with current implementation of MixedModeVisitor
stack.pop_front();
} else {
front->children_expanded = true;
for (auto e : fexpand_expr(front->node)) {
fpush_to_stack(e);
}
}
}
}

template <typename FCheckVisited, typename FVisitLeaf>
void ExpandDataflow(Expr expr, FCheckVisited fcheck_visited, FVisitLeaf fvisit_leaf) {
auto fexpand_expr = [](const Expr& expr) {
std::vector<Expr> result;
if (const CallNode* op = expr.as<CallNode>()) {
for (auto it = op->args.rbegin(); it != op->args.rend(); ++it) {
fpush_to_stack(*it);
result.push_back(*it);
}
fpush_to_stack(op->op);
} else if (const TupleNode* op = node.as<TupleNode>()) {
stack.top().second = true;
// push the children to the stack in reverse order
// to match recursive processing order
result.push_back(op->op);
} else if (const TupleNode* op = expr.as<TupleNode>()) {
for (auto it = op->fields.rbegin(); it != op->fields.rend(); ++it) {
fpush_to_stack(*it);
result.push_back(*it);
}
} else if (const TupleGetItemNode* op = node.as<TupleGetItemNode>()) {
stack.top().second = true;
fpush_to_stack(op->tuple);
} else {
// No need to expand the children directly run visit.
fvisit_leaf(node);
stack.pop();
} else if (const TupleGetItemNode* op = expr.as<TupleGetItemNode>()) {
result.push_back(op->tuple);
}
}
return std::move(result);
};
ExpandDataflow(expr, fcheck_visited, fvisit_leaf, fexpand_expr);
}

void ExpandANormalForm(const LetNode* op, std::function<void(const LetNode*)> pre_visit,
Expand Down
39 changes: 35 additions & 4 deletions src/printer/relay_text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,34 @@ bool RelayTextPrinter::AlwaysInline(const Expr& expr) {
expr.as<VarNode>() || expr.as<ConstructorNode>();
}

Doc RelayTextPrinter::VisitLeaf(const Expr& expr) {
if (!CheckVisited(expr)) {
Doc result = ExprFunctor<Doc(const Expr&)>::VisitExpr(expr);
// Add if not added after visiting
if (!CheckVisited(expr)) {
memo_[expr] = result;
} else {
result_memo_[expr] = result;
}
return result;
}
return memo_[expr];
}

bool RelayTextPrinter::CheckVisited(const Expr& expr) { return (memo_.count(expr)); }

Doc RelayTextPrinter::VisitExpr(const Expr& expr) {
auto fcheck_visited = [this](const Expr& expr) { return this->CheckVisited(expr); };
auto fvisit_leaf = [this](const Expr& expr) { return this->VisitLeaf(expr); };

if (fcheck_visited(expr)) {
return memo_[expr];
} else {
ExpandDataflow(expr, fcheck_visited, fvisit_leaf);
return memo_[expr];
}
}

//------------------------------------
// Overload of Expr printing functions
//------------------------------------
Expand All @@ -252,9 +280,6 @@ Doc RelayTextPrinter::PrintExpr(const Expr& expr, bool meta, bool try_inline, bo
inline_expr |= IsUnique(expr);
}

auto it = memo_.find(expr);
if (it != memo_.end()) return it->second;

Doc printed_expr;

if (meta) {
Expand All @@ -277,13 +302,19 @@ Doc RelayTextPrinter::PrintExpr(const Expr& expr, bool meta, bool try_inline, bo
if (expr.as<VarNode>()) {
// This is our first time visiting the var and we hit the VarNode case
// in the visitor. Thus the variable is free.
doc_stack_.back() << "free_var " << printed_expr << ";" << Doc::NewLine();
if (var_memo_.insert(expr).second && result_memo_.count(expr)) {
doc_stack_.back() << "free_var " << result_memo_[expr] << ";" << Doc::NewLine();
}
// Memoization is done in AllocVar.
return memo_[expr];
} else if (inline_expr) {
memo_[expr] = printed_expr;
return printed_expr;
} else {
// Already exists. Reuse
if (!var_memo_.insert(expr).second) {
return memo_[expr];
}
Doc temp_var = AllocTemp();
memo_[expr] = temp_var;
doc_stack_.back() << temp_var << " = " << printed_expr << ";" << Doc::NewLine();
Expand Down
8 changes: 8 additions & 0 deletions src/printer/text_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "../ir/attr_functor.h"
Expand All @@ -60,6 +61,9 @@ class RelayTextPrinter : public ExprFunctor<Doc(const Expr&)>,
explicit RelayTextPrinter(bool show_meta_data, TextMetaDataContext* meta,
runtime::TypedPackedFunc<std::string(ObjectRef)> annotate)
: show_meta_data_(show_meta_data), annotate_(annotate), meta_(meta) {}
Doc VisitExpr(const Expr& expr) override;
virtual Doc VisitLeaf(const Expr& expr);
virtual bool CheckVisited(const Expr& expr);

/*!
* \brief Print additional info about expr in comment.
Expand Down Expand Up @@ -170,6 +174,10 @@ class RelayTextPrinter : public ExprFunctor<Doc(const Expr&)>,
runtime::TypedPackedFunc<std::string(ObjectRef)> annotate_;
/*! \brief Stack of docs to implement scoped GNFing. */
std::vector<Doc> doc_stack_{};
/*! \brief Set for introduced vars */
std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> var_memo_;
/*! \brief Map for result and memo_ diffs for visited expression */
std::unordered_map<Expr, Doc, ObjectPtrHash, ObjectPtrEqual> result_memo_;
/*! \brief Map from Expr to Doc */
std::unordered_map<Expr, Doc, ObjectPtrHash, ObjectPtrEqual> memo_;
/*! \brief Map from Type to Doc */
Expand Down
6 changes: 3 additions & 3 deletions src/relay/analysis/dependency_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ namespace tvm {
namespace relay {

// Creator of DependencyGraph
class DependencyGraph::Creator : private ExprFunctor<void(const Expr& e)> {
class DependencyGraph::Creator : private MixedModeVisitor {
public:
explicit Creator(support::Arena* arena) : arena_(arena) {}

Expand Down Expand Up @@ -73,13 +73,13 @@ class DependencyGraph::Creator : private ExprFunctor<void(const Expr& e)> {
return ret;
}

void VisitExpr(const Expr& e) final {
void VisitLeaf(const Expr& e) override {
if (visited_.count(e) == 0) {
if (graph_.expr_node.count(e) == 0) {
graph_.expr_node[e] = NewNode(false);
}
visited_.insert(e);
ExprFunctor<void(const Expr&)>::VisitExpr(e);
MixedModeVisitor::VisitLeaf(e);
graph_.post_dfs_order.push_back(graph_.expr_node[e]);
}
}
Expand Down

0 comments on commit b3ef0a9

Please sign in to comment.