Skip to content

Commit 9bcf0bc

Browse files
authored
[Relay] add redirecting operation to dataflow pattern graph (#15392)
* Add redirecting operation to dataflow pattern graph * Lint
1 parent ac99367 commit 9bcf0bc

File tree

7 files changed

+101
-3
lines changed

7 files changed

+101
-3
lines changed

include/tvm/relay/dataflow_pattern.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -362,6 +362,10 @@ class WildcardPatternNode : public DFPatternNode {
362362
public:
363363
void VisitAttrs(tvm::AttrVisitor* v) {}
364364

365+
/*! \brief If the wildcard is redirected, then pattern is not nullptr, and the wildcard
366+
* redirects to the pattern. */
367+
Optional<DFPattern> pattern{nullptr};
368+
365369
static constexpr const char* _type_key = "relay.dataflow_pattern.WildcardPattern";
366370
TVM_DECLARE_FINAL_OBJECT_INFO(WildcardPatternNode, DFPatternNode);
367371
};
@@ -372,6 +376,8 @@ class WildcardPatternNode : public DFPatternNode {
372376
class WildcardPattern : public DFPattern {
373377
public:
374378
TVM_DEFINE_OBJECT_REF_METHODS(WildcardPattern, DFPattern, WildcardPatternNode);
379+
380+
void redirect_to(DFPattern pat) const;
375381
};
376382

377383
class TypePattern;

python/tvm/relay/dataflow_pattern/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -722,6 +722,19 @@ class WildcardPattern(DFPattern):
722722
def __init__(self):
723723
self.__init_handle_by_constructor__(ffi.WildcardPattern)
724724

725+
def redirect_to(
726+
self,
727+
pat: "DFPattern",
728+
):
729+
"""Redirect the WildcardPattern to another pattern
730+
731+
Parameters
732+
----------
733+
pat: relay.dataflow_pattern.DFPattern
734+
The pattern that wildcard is redirected to.
735+
"""
736+
ffi.WildcardPattern_redirect_to(self, pat)
737+
725738

726739
@register_df_node
727740
class TypePattern(DFPattern):

src/relay/ir/dataflow_matcher.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -488,7 +488,11 @@ bool DFPatternMatcher::VisitDFPattern_(const ConstantPatternNode* op, const Expr
488488
}
489489

490490
bool DFPatternMatcher::VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) {
491-
return true;
491+
if (op->pattern) {
492+
return VisitDFPattern(op->pattern.value(), expr);
493+
} else {
494+
return true;
495+
}
492496
}
493497

494498
bool MatchPattern(DFPattern pattern, Expr expr) {

src/relay/ir/dataflow_pattern.cc

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,8 +344,18 @@ TVM_STATIC_IR_FUNCTOR(DFPatternPrinter, vtable)
344344
<< ")";
345345
});
346346

347+
void WildcardPattern::redirect_to(DFPattern pat) const {
348+
WildcardPatternNode* ptr = static_cast<WildcardPatternNode*>(get_mutable());
349+
ptr->pattern = pat;
350+
}
351+
347352
TVM_REGISTER_NODE_TYPE(WildcardPatternNode);
348353

354+
TVM_REGISTER_GLOBAL("relay.dataflow_pattern.WildcardPattern_redirect_to")
355+
.set_body_typed([](WildcardPattern wildcard, DFPattern pat) {
356+
return wildcard.redirect_to(pat);
357+
});
358+
349359
TVM_REGISTER_GLOBAL("relay.dataflow_pattern.WildcardPattern").set_body_typed([]() {
350360
auto w = WildcardPattern(make_object<WildcardPatternNode>());
351361
return w;

src/relay/ir/dataflow_pattern_functor.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,11 @@ void DFPatternVisitor::VisitDFPattern_(const VarPatternNode* op) {}
105105

106106
void DFPatternVisitor::VisitDFPattern_(const ConstantPatternNode* op) {}
107107

108-
void DFPatternVisitor::VisitDFPattern_(const WildcardPatternNode* op) {}
108+
void DFPatternVisitor::VisitDFPattern_(const WildcardPatternNode* op) {
109+
if (op->pattern) {
110+
VisitDFPattern(op->pattern.value());
111+
}
112+
}
109113

110114
} // namespace relay
111115
} // namespace tvm

src/relay/ir/indexed_graph.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -537,7 +537,12 @@ std::unique_ptr<IndexedGraph<DFPattern>> CreateIndexedGraph(const DFPattern& pat
537537

538538
void VisitDFPattern_(const VarPatternNode* op) override {}
539539

540-
void VisitDFPattern_(const WildcardPatternNode* op) override {}
540+
void VisitDFPattern_(const WildcardPatternNode* op) override {
541+
if (op->pattern) {
542+
auto node = graph_->item_to_node(GetRef<WildcardPattern>(op));
543+
AddOutput(op->pattern.value(), node);
544+
}
545+
}
541546

542547
std::unique_ptr<IndexedGraph<DFPattern>> graph_;
543548
};

tests/python/relay/test_dataflow_pattern.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1995,5 +1995,61 @@ def test_partition_parallel_branch_with_same_input():
19951995
assert tvm.ir.structural_equal(partitioned, reference)
19961996

19971997

1998+
def test_rewrite_with_pattern_recursion():
1999+
data = relay.var("data", relay.TensorType((2, 8), "float32"))
2000+
dense_weight = relay.const(np.zeros((4, 8)))
2001+
feat = relay.nn.dense(data, dense_weight)
2002+
feat = relay.cast(feat, "float32")
2003+
feat = relay.cast(feat, "float32")
2004+
feat = relay.cast(feat, "float32")
2005+
feat = relay.cast(feat, "float32")
2006+
feat = relay.cast(feat, "float32")
2007+
oup = relay.cast(feat, "float32")
2008+
2009+
expected = relay.nn.relu(oup)
2010+
2011+
class TheRewrite(DFPatternCallback):
2012+
def __init__(self, pattern):
2013+
super(TheRewrite, self).__init__(rewrite_once=True)
2014+
self.pattern = pattern
2015+
2016+
def callback(self, pre, post, node_map):
2017+
return relay.nn.relu(post)
2018+
2019+
def test_reset_call_args():
2020+
dense_pattern = is_op("nn.dense")(wildcard(), wildcard())
2021+
wildcard_redirect = wildcard()
2022+
the_pattern = is_op("cast")(wildcard_redirect)
2023+
the_pattern2 = the_pattern | dense_pattern
2024+
wildcard_redirect.redirect_to(the_pattern2)
2025+
2026+
actual = rewrite(TheRewrite(the_pattern), oup)
2027+
tvm.ir.assert_structural_equal(actual, expected)
2028+
2029+
def test_reset_alt_left():
2030+
dense_pattern = is_op("nn.dense")(wildcard(), wildcard())
2031+
wildcard_redirect = wildcard()
2032+
or_pattern = wildcard_redirect | dense_pattern
2033+
the_pattern = is_op("cast")(or_pattern)
2034+
wildcard_redirect.redirect_to(the_pattern)
2035+
2036+
actual = rewrite(TheRewrite(the_pattern), oup)
2037+
tvm.ir.assert_structural_equal(actual, expected)
2038+
2039+
def test_reset_alt_right():
2040+
dense_pattern = is_op("nn.dense")(wildcard(), wildcard())
2041+
wildcard_redirect = wildcard()
2042+
or_pattern = dense_pattern | wildcard_redirect
2043+
the_pattern = is_op("cast")(or_pattern)
2044+
wildcard_redirect.redirect_to(the_pattern)
2045+
2046+
actual = rewrite(TheRewrite(the_pattern), oup)
2047+
tvm.ir.assert_structural_equal(actual, expected)
2048+
2049+
test_reset_call_args()
2050+
test_reset_alt_left()
2051+
test_reset_alt_right()
2052+
2053+
19982054
if __name__ == "__main__":
19992055
tvm.testing.main()

0 commit comments

Comments
 (0)