Skip to content

Commit 9863cf0

Browse files
author
Christian Convey
authored
[hexagon][testing] Better pytest ID strings (#12154)
- Add utility functions to allow more human-readable pytest test IDs. Helpful when ID strings become too large for humans to easily read. - Update the `test_avg_pool2d_slice.py` unit test to use this mechanism.
1 parent a07e18e commit 9863cf0

File tree

2 files changed

+125
-13
lines changed

2 files changed

+125
-13
lines changed
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
import pytest
19+
import numpy as np
20+
from typing import *
21+
import collections
22+
import tvm.testing
23+
24+
25+
def get_test_id(*test_params, test_param_descs: List[Optional[str]] = None) -> str:
26+
"""
27+
An opinionated alternative to pytest's default algorithm for generating a
28+
test's ID string. Intended to make it easier for human readers to
29+
interpret the test IDs.
30+
31+
'test_params': The sequence of pytest parameter values supplied to some unit
32+
test.
33+
34+
'test_param_descs': An (optional) means to provide additional text for some/all of the
35+
paramuments in 'test_params'.
36+
37+
If provided, then len(test_params) must equal len(test_param_descs).
38+
Each element test_param_descs that is a non-empty string will be used
39+
in some sensible way in this function's returned string.
40+
"""
41+
42+
assert len(test_params) > 0
43+
44+
if test_param_descs is None:
45+
test_param_descs = [None] * len(test_params)
46+
else:
47+
assert len(test_param_descs) == len(test_params)
48+
49+
def get_single_param_chunk(param_val, param_desc: Optional[str]):
50+
if type(param_val) == list:
51+
# Like str(list), but avoid the whitespace padding.
52+
val_str = "[" + ",".join(str(x) for x in param_val) + "]"
53+
need_prefix_separator = False
54+
55+
elif type(param_val) == bool:
56+
if param_val:
57+
val_str = "T"
58+
else:
59+
val_str = "F"
60+
need_prefix_separator = True
61+
62+
else:
63+
val_str = str(param_val)
64+
need_prefix_separator = True
65+
66+
if param_desc and need_prefix_separator:
67+
return f"{param_desc}:{val_str}"
68+
elif param_desc and not need_prefix_separator:
69+
return f"{param_desc}{val_str}"
70+
else:
71+
return val_str
72+
73+
chunks = [
74+
get_single_param_chunk(param_val, param_desc)
75+
for param_val, param_desc in zip(test_params, test_param_descs)
76+
]
77+
return "-".join(chunks)
78+
79+
80+
def get_multitest_ids(
81+
multitest_params_list: List[List], param_descs: Optional[List[Optional[str]]]
82+
) -> List[str]:
83+
"""
84+
A convenience function for classes that use both 'tvm.testing.parameters' and 'get_test_id'.
85+
86+
This function provides a workaround for a specific quirk in Python, where list-comprehension
87+
can't necessarily access the value of another class-variable, discused here:
88+
https://stackoverflow.com/q/13905741
89+
"""
90+
return [
91+
get_test_id(*single_test_param_list, test_param_descs=param_descs)
92+
for single_test_param_list in multitest_params_list
93+
]

tests/python/contrib/test_hexagon/topi/test_avg_pool2d_slice.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
import pytest
1919
import numpy as np
20+
from typing import *
21+
import collections
2022

2123
from tvm import te
2224
import tvm.testing
@@ -25,6 +27,7 @@
2527
from tvm.contrib.hexagon.session import Session
2628
import tvm.topi.hexagon.slice_ops as sl
2729
from ..infrastructure import allocate_hexagon_array, transform_numpy
30+
from ..pytest_util import get_multitest_ids
2831

2932

3033
input_layout = tvm.testing.parameter(
@@ -48,18 +51,19 @@ def transformed_input_np_padded(input_np_padded, input_layout):
4851

4952

5053
class TestAvgPool2dSlice:
51-
# NOTE: input_layout is always assumed to be "nhwc-8h2w32c2w-2d"
52-
(
53-
output_shape,
54-
kernel,
55-
stride,
56-
dilation,
57-
padding,
58-
ceil_mode,
59-
count_include_pad,
60-
output_layout,
61-
dtype,
62-
) = tvm.testing.parameters(
54+
_param_descs = [
55+
"out_shape", # output_shape
56+
"kernel", # kernel
57+
"stride", # stride
58+
"dil", # dilation
59+
"pad", # padding
60+
"ceil", # ceil_mode
61+
"cnt_padded", # count_include_pad
62+
"out_layout", # output_layout
63+
None, # dtype
64+
]
65+
66+
_multitest_params = [
6367
(
6468
[1, 8, 8, 32],
6569
[3, 3],
@@ -217,7 +221,22 @@ class TestAvgPool2dSlice:
217221
"n11c-1024c-2d",
218222
"float16",
219223
),
220-
)
224+
]
225+
226+
_param_ids = get_multitest_ids(_multitest_params, _param_descs)
227+
228+
# NOTE: input_layout is always assumed to be "nhwc-8h2w32c2w-2d"
229+
(
230+
output_shape,
231+
kernel,
232+
stride,
233+
dilation,
234+
padding,
235+
ceil_mode,
236+
count_include_pad,
237+
output_layout,
238+
dtype,
239+
) = tvm.testing.parameters(*_multitest_params, ids=_param_ids)
221240

222241
@tvm.testing.fixture
223242
def expected_output_np(

0 commit comments

Comments
 (0)