Skip to content

Commit 3a873d9

Browse files
author
Joey Tsai
committed
[SpanFillingCommonAPI]
- Change the test cases to pytest style - Group the set_span test cases to a class
1 parent 556ec6b commit 3a873d9

File tree

2 files changed

+33
-48
lines changed

2 files changed

+33
-48
lines changed

tests/python/frontend/test_common.py

Lines changed: 25 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ def test_key_is_not_present():
3232
assert not attrs.has_attr("b")
3333

3434

35-
def test_set_span():
36-
def _verify_env_var_switch():
35+
class TestSetSpan:
36+
def test_env_var_switch(self):
3737
def _res(should_fill):
3838
if should_fill:
3939
with testing.enable_span_filling():
@@ -45,11 +45,11 @@ def _res(should_fill):
4545
disable = relay.var("x", shape=(1, 64, 56, 56))
4646
enable = relay.var("x", shape=(1, 64, 56, 56), span=_create_span("x_var"))
4747

48-
assert _verify_structural_equal_with_span(_res(False), disable)
49-
assert _verify_structural_equal_with_span(_res(True), enable)
48+
_verify_structural_equal_with_span(_res(False), disable)
49+
_verify_structural_equal_with_span(_res(True), enable)
5050

5151
# Should tag all exprs without span, and stop when expr is span-tagged
52-
def _verify_builtin_tuple():
52+
def test_builtin_tuple(self):
5353
def _res():
5454
a = relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span("a"))
5555
b = relay.const(np.zeros([1, 1, 1]), dtype="int64")
@@ -63,9 +63,9 @@ def _golden():
6363
res_tuple, golden_tuple = _res(), _golden()
6464
assert len(res_tuple) == len(golden_tuple)
6565
for i in range(len(res_tuple)):
66-
assert _verify_structural_equal_with_span(res_tuple[i], golden_tuple[i])
66+
_verify_structural_equal_with_span(res_tuple[i], golden_tuple[i])
6767

68-
def _verify_builtin_list():
68+
def test_builtin_list(self):
6969
def _res():
7070
a = relay.const(np.ones([1, 1, 1]), dtype="int64", span=_create_span("a"))
7171
b = relay.const(np.zeros([1, 1, 1]), dtype="int64")
@@ -85,21 +85,21 @@ def _golden():
8585
res_list, golden_list = _res(), _golden()
8686
assert len(res_list) == len(golden_list)
8787
for i in range(len(res_list)):
88-
assert _verify_structural_equal_with_span(res_list[i], golden_list[i])
88+
_verify_structural_equal_with_span(res_list[i], golden_list[i])
8989

90-
def _verify_var():
90+
def test_var(self):
9191
x = set_span(relay.var("x", shape=(1, 64, 56, 56)), "x_var")
9292
x_expected = relay.var("x", shape=(1, 64, 56, 56), span=_create_span("x_var"))
93-
assert _verify_structural_equal_with_span(x, x_expected)
93+
_verify_structural_equal_with_span(x, x_expected)
9494

95-
def _verify_constant():
95+
def test_constant(self):
9696
c = set_span(relay.const(np.ones([64, 64, 3, 3]), dtype="int64"), "const_c")
9797
c_expected = relay.const(
9898
np.ones([64, 64, 3, 3]), dtype="int64", span=_create_span("const_c")
9999
)
100-
assert _verify_structural_equal_with_span(c, c_expected)
100+
_verify_structural_equal_with_span(c, c_expected)
101101

102-
def _verify_call():
102+
def test_call(self):
103103
def _res():
104104
x = set_span(relay.var("x", shape=(1, 64, 56, 56)), "x_var")
105105
w = relay.const(np.ones([64, 64, 3, 3]), dtype="int64")
@@ -116,9 +116,9 @@ def _golden():
116116
)
117117
return relay.Function([x], y)
118118

119-
assert _verify_structural_equal_with_span(_res(), _golden())
119+
_verify_structural_equal_with_span(_res(), _golden())
120120

121-
def _verify_tuple():
121+
def test_tuple(self):
122122
def _res():
123123
a = set_span(relay.const(np.ones([1, 1, 1]), dtype="int64"), "a")
124124
b = relay.const(np.ones([1, 1, 1]), dtype="int64")
@@ -131,9 +131,9 @@ def _golden():
131131
t = relay.Tuple([a, b], span=_create_span("t"))
132132
return relay.Function([], t)
133133

134-
assert _verify_structural_equal_with_span(_res(), _golden())
134+
_verify_structural_equal_with_span(_res(), _golden())
135135

136-
def _verify_tuple_getitem():
136+
def test_tuple_getitem(self):
137137
def _res():
138138
a = set_span(relay.const(np.ones([1, 1, 1]), dtype="int64"), "a")
139139
b = relay.const(np.ones([1, 1, 1]), dtype="int64")
@@ -148,9 +148,9 @@ def _golden():
148148
i = relay.TupleGetItem(t, 0, span=_create_span("i"))
149149
return relay.Function([], i)
150150

151-
assert _verify_structural_equal_with_span(_res(), _golden())
151+
_verify_structural_equal_with_span(_res(), _golden())
152152

153-
def _verify_let():
153+
def test_let(self):
154154
def _res():
155155
x = set_span(relay.Var("x"), "x_var")
156156
c_1 = relay.const(np.ones(10))
@@ -171,9 +171,9 @@ def _golden():
171171
y = _set_span(relay.add(body, c_2), "add_2")
172172
return relay.Function([x], y)
173173

174-
assert _verify_structural_equal_with_span(_res(), _golden())
174+
_verify_structural_equal_with_span(_res(), _golden())
175175

176-
def _verify_if():
176+
def test_if(self):
177177
def _res():
178178
x = set_span(relay.var("x", shape=[], dtype="float32"), "x_var")
179179
y = set_span(relay.var("y", shape=[], dtype="float32"), "y_var")
@@ -194,9 +194,9 @@ def _golden():
194194
ife = relay.If(eq, true_branch, false_branch, span=_create_span("if"))
195195
return relay.Function([x, y], ife)
196196

197-
assert _verify_structural_equal_with_span(_res(), _golden())
197+
_verify_structural_equal_with_span(_res(), _golden())
198198

199-
def _verify_fn():
199+
def test_fn(self):
200200
def _res():
201201
x = set_span(relay.var("x", shape=(1, 64, 56, 56)), "x_var")
202202
w = relay.const(np.ones([64, 64, 3, 3]), dtype="int64")
@@ -213,22 +213,8 @@ def _golden():
213213
f = relay.Function([x], y, span=_create_span("func"))
214214
return f
215215

216-
assert _verify_structural_equal_with_span(_res(), _golden())
217-
218-
_verify_env_var_switch()
219-
_verify_builtin_tuple()
220-
_verify_builtin_list()
221-
_verify_var()
222-
_verify_constant()
223-
_verify_call()
224-
_verify_tuple()
225-
_verify_tuple_getitem()
226-
_verify_let()
227-
_verify_if()
228-
_verify_fn()
216+
_verify_structural_equal_with_span(_res(), _golden())
229217

230218

231219
if __name__ == "__main__":
232-
test_key_is_present()
233-
test_key_is_present()
234-
test_set_span()
220+
testing.main()

tests/python/relay/utils/tag_span.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -85,22 +85,21 @@ def get_spans(self):
8585
def _verify_span(lhs, rhs):
8686
lhs_spans, rhs_spans = _collect_spans(lhs), _collect_spans(rhs)
8787

88-
if len(lhs_spans) != len(rhs_spans):
89-
return False
88+
assert len(lhs_spans) == len(rhs_spans)
9089

9190
for i in range(len(lhs_spans)):
92-
if not tvm.ir.structural_equal(lhs_spans[i], rhs_spans[i]):
93-
return False
94-
return True
91+
assert tvm.ir.structural_equal(lhs_spans[i], rhs_spans[i])
9592

9693

9794
def _verify_structural_equal_with_span(lhs, rhs, assert_mode=False, map_free_vars=False):
9895
if isinstance(lhs, relay.Var) and isinstance(rhs, relay.Var):
99-
return _verify_span(lhs, rhs)
96+
# SEqualReduce compares the vid of Var type. Threrfore we only compare span here.
97+
_verify_span(lhs, rhs)
98+
return
10099

101100
if assert_mode:
102101
tvm.ir.assert_structural_equal(lhs, rhs, map_free_vars)
103-
elif not tvm.ir.structural_equal(lhs, rhs, map_free_vars):
104-
return False
102+
else:
103+
assert tvm.ir.structural_equal(lhs, rhs, map_free_vars)
105104

106-
return _verify_span(lhs, rhs)
105+
_verify_span(lhs, rhs)

0 commit comments

Comments
 (0)