Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/relay/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -300,11 +300,17 @@ bool DFPatternMatcher::VisitDFPattern_(const CallPatternNode* op, const Expr& ex

// Recursively find the Dominator parent along all inputs paths.
bool DFPatternMatcher::MatchesPath(const DominatorPatternNode* op, const Expr& expr) {
// utilities
auto is_leaf_node = [](const Expr& expr) {
return expr.as<ConstantNode>() || expr.as<VarNode>();
};

// logic
auto call_node = expr.as<CallNode>();
auto index_node = expr_to_node(expr);
size_t arg_counter{0};
for (auto node : index_node->inputs_) {
if (!(call_node && node->ref() == call_node->op)) {
if (!(call_node && (node->ref() == call_node->op || is_leaf_node(node->ref())))) {
arg_counter += 1;
memoize_ = true;
if (!VisitDFPattern(op->parent, node->ref())) {
Expand Down
24 changes: 23 additions & 1 deletion tests/python/relay/test_dataflow_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
# convention.
K_ELEMWISE = 0
K_BROADCAST = 1

K_INJECTIVE = 2

## NODE TESTS
def test_expr_pattern():
Expand Down Expand Up @@ -696,6 +696,28 @@ def test_match_dominator():
assert diamond.match(out)


def test_match_dominator2():
# Pattern
conv2d_pat = is_op("nn.conv2d")(wildcard(), wildcard())
eltwise_pat = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(None)
broadcast_pat = (wildcard().has_attr({"TOpPattern": K_BROADCAST}))(None)
path_pat = eltwise_pat | broadcast_pat
injective_pat = (wildcard().has_attr({"TOpPattern": K_INJECTIVE}))(wildcard())
pattern = injective_pat.dominates(conv2d_pat, path_pat)

# Graph
inp = relay.var("input")
weight = relay.var("weight")
bias = relay.var("bias")
conv2d = relay.op.nn.conv2d(inp, weight)
bias_add = relay.op.nn.bias_add(conv2d, bias)
relu = relay.op.nn.relu(bias_add)
reshape = relay.op.reshape(relu, newshape=[-1, 2, 8])

# Check
assert pattern.match(reshape)


def test_not_match_dominator():
is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard())
is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard())
Expand Down