Skip to content

Commit cbc1c77

Browse files
committed
Extended testing for block config
1 parent a946662 commit cbc1c77

File tree

1 file changed

+41
-15
lines changed

1 file changed

+41
-15
lines changed

tests/python/contrib/test_ethosu/cascader/test_ethosu_block_config.py

Lines changed: 41 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828

2929
@pytest.mark.parametrize(
30-
"id, op_type, activation, kernel, stride, dilation, padding, in_shape, out_shape",
30+
"test_id, op_type, activation, kernel, stride, dilation, padding, in_shape, out_shape",
3131
[
3232
# Conv2D
3333
(
@@ -98,6 +98,7 @@
9898
),
9999
# Depthwise Conv2D
100100
(
101+
6,
101102
"ethosu_depthwise_conv2d",
102103
"NONE",
103104
(3, 5),
@@ -106,9 +107,9 @@
106107
(0, 0, 0, 0),
107108
(1, 77, 23, 18),
108109
(1, 75, 19, 18),
109-
(1, 7, 10, 16),
110110
),
111111
(
112+
7,
112113
"ethosu_depthwise_conv2d",
113114
"NONE",
114115
(3, 3),
@@ -117,10 +118,10 @@
117118
(1, 1, 1, 1),
118119
(1, 25, 10, 276),
119120
(1, 13, 5, 276),
120-
(1, 7, 6, 16),
121121
),
122122
# Pooling
123123
(
124+
8,
124125
"ethosu_pooling",
125126
"NONE",
126127
(13, 5),
@@ -129,9 +130,9 @@
129130
(0, 0, 0, 0),
130131
(1, 13, 5, 276),
131132
(1, 1, 1, 276),
132-
(1, 1, 2, 80),
133133
),
134134
(
135+
9,
135136
"ethosu_pooling",
136137
"NONE",
137138
(7, 3),
@@ -140,7 +141,6 @@
140141
(0, 0, 0, 0),
141142
(1, 317, 14, 21),
142143
(1, 156, 12, 21),
143-
(1, 10, 6, 16),
144144
),
145145
],
146146
)
@@ -159,51 +159,79 @@
159159
(
160160
"ethos-u55-32",
161161
[
162+
# Conv2D
162163
((1, 8, 4, 16), (1, 8, 1, 4, 16)),
163164
((1, 6, 5, 16), (1, 6, 1, 5, 16)),
164165
((1, 4, 4, 16), (1, 4, 1, 4, 16)),
165166
((1, 8, 4, 16), (1, 8, 1, 4, 16)),
166-
((1, 10, 6, 4), (1, 16, 1, 4, 4)),
167-
((1, 10, 3, 16), (1, 10, 1, 3, 16)),
167+
((1, 10, 6, 4), (1, 5, 1, 12, 4), (1, 16, 1, 4, 4)),
168+
((1, 6, 5, 16), (1, 6, 1, 5, 16)),
169+
# Depthwise Conv2D
170+
((1, 6, 10, 16), (1, 6, 1, 10, 16)),
171+
((1, 7, 5, 16), (1, 7, 1, 5, 16)),
172+
# Pooling
173+
((1, 1, 1, 16), (1, 1, 1, 1, 16)),
174+
((1, 9, 6, 16), (1, 9, 1, 6, 16)),
168175
],
169176
),
170177
(
171178
"ethos-u55-64",
172179
[
180+
# Conv2D
173181
((1, 8, 4, 16), (1, 8, 1, 4, 16)),
174182
((1, 6, 5, 16), (1, 6, 1, 5, 16)),
175183
((1, 4, 4, 16), (1, 4, 1, 4, 16)),
176184
((1, 8, 4, 16), (1, 8, 1, 4, 16)),
177185
((1, 10, 6, 8), (1, 16, 1, 4, 8)),
178-
((1, 10, 3, 16), (1, 10, 1, 3, 16)),
186+
((1, 6, 5, 16), (1, 6, 1, 5, 16)),
187+
# Depthwise Conv2D
188+
((1, 6, 10, 16), (1, 6, 1, 10, 16)),
189+
((1, 7, 5, 16), (1, 7, 1, 5, 16)),
190+
# Pooling
191+
((1, 1, 1, 16), (1, 1, 1, 1, 16)),
192+
((1, 9, 6, 16), (1, 9, 1, 6, 16)),
179193
],
180194
),
181195
(
182196
"ethos-u55-128",
183197
[
198+
# Conv2D
184199
((1, 7, 6, 16), (1, 7, 1, 6, 16)),
185200
((1, 5, 8, 16), (1, 5, 1, 8, 16)),
186201
((1, 4, 4, 16), (1, 4, 1, 4, 16)),
187202
((1, 16, 4, 16), (1, 16, 1, 4, 16)),
188203
((1, 8, 12, 8), (1, 8, 1, 12, 8)),
189204
((1, 10, 6, 16), (1, 10, 1, 6, 16)),
205+
# Depthwise Conv2D
206+
((1, 7, 10, 16), (1, 7, 1, 10, 16)),
207+
((1, 7, 6, 16), (1, 7, 1, 6, 16)),
208+
# Pooling
209+
((1, 1, 2, 80), (1, 1, 5, 2, 16)),
210+
((1, 10, 6, 16), (1, 10, 1, 6, 16)),
190211
],
191212
),
192213
(
193214
"ethos-u55-256",
194215
[
216+
# Conv2D
195217
((1, 14, 8, 16), (1, 14, 1, 8, 16)),
196218
((1, 16, 8, 16), (1, 16, 1, 8, 16)),
197219
((1, 4, 4, 16), (1, 4, 1, 4, 16)),
198-
((1, 32, 4, 16), (1, 32, 1, 4, 16)),
220+
((1, 32, 4, 16), (1, 10, 12, 16), (1, 32, 1, 4, 16), (1, 10, 1, 12, 16)),
199221
((1, 20, 12, 8), (1, 20, 1, 12, 8)),
200-
((1, 20, 6, 16), (1, 20, 1, 6, 16)),
222+
((1, 12, 10, 16), (1, 12, 1, 10, 16)),
223+
# Depthwise Conv2D
224+
((1, 8, 20, 16), (1, 8, 1, 20, 16)),
225+
((1, 14, 6, 16), (1, 14, 1, 6, 16)),
226+
# Pooling
227+
((1, 2, 2, 48), (1, 2, 3, 2, 16)),
228+
((1, 10, 12, 16), (1, 10, 1, 12, 16)),
201229
],
202230
),
203231
],
204232
)
205233
def test_best_block_config(
206-
id,
234+
test_id,
207235
op_type,
208236
activation,
209237
kernel,
@@ -299,10 +327,8 @@ def test_best_block_config(
299327

300328
block = part.get_block_config(stripe_config)
301329
block_shape = tuple(int(a) for a in block.output_shape)
302-
if layouts[1] == "NHCWB16":
303-
assert block_shape == expected_block_configs[id][1]
304-
else:
305-
assert block_shape == expected_block_configs[id][0]
330+
331+
assert block_shape in expected_block_configs[test_id]
306332

307333

308334
if __name__ == "__main__":

0 commit comments

Comments
 (0)