Skip to content

Commit 8fd3348

Browse files
committed
add testcase
1 parent 08f91b9 commit 8fd3348

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

tests/python/relay/test_dataflow_pattern.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -696,6 +696,28 @@ def test_match_dominator():
696696
assert diamond.match(out)
697697

698698

699+
def test_match_dominator2():
700+
# Pattern
701+
conv2d_pat = is_op("nn.conv2d")(wildcard(), wildcard())
702+
eltwise_pat = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(None)
703+
broadcast_pat = (wildcard().has_attr({"TOpPattern": K_BROADCAST}))(None)
704+
path_pat = eltwise_pat | broadcast_pat
705+
injective_pat = (wildcard().has_attr({"TOpPattern": K_INJECTIVE}))(wildcard())
706+
pattern = injective_pat.dominates(conv2d_pat, path_pat)
707+
708+
# Graph
709+
inp = relay.var("input")
710+
weight = relay.var("weight")
711+
bias = relay.var("bias")
712+
conv2d = relay.op.nn.conv2d(inp, weight)
713+
bias_add = relay.op.nn.bias_add(conv2d, bias)
714+
relu = relay.op.nn.relu(bias_add)
715+
reshape = relay.op.reshape(relu, newshape=[-1, 2, 8])
716+
717+
# Check
718+
assert pattern.match(reshape)
719+
720+
699721
def test_not_match_dominator():
700722
is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard())
701723
is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard())

0 commit comments

Comments
 (0)