|
27 | 27 |
|
28 | 28 |
|
29 | 29 | @pytest.mark.parametrize( |
30 | | - "id, op_type, activation, kernel, stride, dilation, padding, in_shape, out_shape", |
| 30 | + "test_id, op_type, activation, kernel, stride, dilation, padding, in_shape, out_shape", |
31 | 31 | [ |
32 | 32 | # Conv2D |
33 | 33 | ( |
|
98 | 98 | ), |
99 | 99 | # Depthwise Conv2D |
100 | 100 | ( |
| 101 | + 6, |
101 | 102 | "ethosu_depthwise_conv2d", |
102 | 103 | "NONE", |
103 | 104 | (3, 5), |
|
106 | 107 | (0, 0, 0, 0), |
107 | 108 | (1, 77, 23, 18), |
108 | 109 | (1, 75, 19, 18), |
109 | | - (1, 7, 10, 16), |
110 | 110 | ), |
111 | 111 | ( |
| 112 | + 7, |
112 | 113 | "ethosu_depthwise_conv2d", |
113 | 114 | "NONE", |
114 | 115 | (3, 3), |
|
117 | 118 | (1, 1, 1, 1), |
118 | 119 | (1, 25, 10, 276), |
119 | 120 | (1, 13, 5, 276), |
120 | | - (1, 7, 6, 16), |
121 | 121 | ), |
122 | 122 | # Pooling |
123 | 123 | ( |
| 124 | + 8, |
124 | 125 | "ethosu_pooling", |
125 | 126 | "NONE", |
126 | 127 | (13, 5), |
|
129 | 130 | (0, 0, 0, 0), |
130 | 131 | (1, 13, 5, 276), |
131 | 132 | (1, 1, 1, 276), |
132 | | - (1, 1, 2, 80), |
133 | 133 | ), |
134 | 134 | ( |
| 135 | + 9, |
135 | 136 | "ethosu_pooling", |
136 | 137 | "NONE", |
137 | 138 | (7, 3), |
|
140 | 141 | (0, 0, 0, 0), |
141 | 142 | (1, 317, 14, 21), |
142 | 143 | (1, 156, 12, 21), |
143 | | - (1, 10, 6, 16), |
144 | 144 | ), |
145 | 145 | ], |
146 | 146 | ) |
|
159 | 159 | ( |
160 | 160 | "ethos-u55-32", |
161 | 161 | [ |
| 162 | + # Conv2D |
162 | 163 | ((1, 8, 4, 16), (1, 8, 1, 4, 16)), |
163 | 164 | ((1, 6, 5, 16), (1, 6, 1, 5, 16)), |
164 | 165 | ((1, 4, 4, 16), (1, 4, 1, 4, 16)), |
165 | 166 | ((1, 8, 4, 16), (1, 8, 1, 4, 16)), |
166 | | - ((1, 10, 6, 4), (1, 16, 1, 4, 4)), |
167 | | - ((1, 10, 3, 16), (1, 10, 1, 3, 16)), |
| 167 | + ((1, 10, 6, 4), (1, 5, 1, 12, 4), (1, 16, 1, 4, 4)), |
| 168 | + ((1, 6, 5, 16), (1, 6, 1, 5, 16)), |
| 169 | + # Depthwise Conv2D |
| 170 | + ((1, 6, 10, 16), (1, 6, 1, 10, 16)), |
| 171 | + ((1, 7, 5, 16), (1, 7, 1, 5, 16)), |
| 172 | + # Pooling |
| 173 | + ((1, 1, 1, 16), (1, 1, 1, 1, 16)), |
| 174 | + ((1, 9, 6, 16), (1, 9, 1, 6, 16)), |
168 | 175 | ], |
169 | 176 | ), |
170 | 177 | ( |
171 | 178 | "ethos-u55-64", |
172 | 179 | [ |
| 180 | + # Conv2D |
173 | 181 | ((1, 8, 4, 16), (1, 8, 1, 4, 16)), |
174 | 182 | ((1, 6, 5, 16), (1, 6, 1, 5, 16)), |
175 | 183 | ((1, 4, 4, 16), (1, 4, 1, 4, 16)), |
176 | 184 | ((1, 8, 4, 16), (1, 8, 1, 4, 16)), |
177 | 185 | ((1, 10, 6, 8), (1, 16, 1, 4, 8)), |
178 | | - ((1, 10, 3, 16), (1, 10, 1, 3, 16)), |
| 186 | + ((1, 6, 5, 16), (1, 6, 1, 5, 16)), |
| 187 | + # Depthwise Conv2D |
| 188 | + ((1, 6, 10, 16), (1, 6, 1, 10, 16)), |
| 189 | + ((1, 7, 5, 16), (1, 7, 1, 5, 16)), |
| 190 | + # Pooling |
| 191 | + ((1, 1, 1, 16), (1, 1, 1, 1, 16)), |
| 192 | + ((1, 9, 6, 16), (1, 9, 1, 6, 16)), |
179 | 193 | ], |
180 | 194 | ), |
181 | 195 | ( |
182 | 196 | "ethos-u55-128", |
183 | 197 | [ |
| 198 | + # Conv2D |
184 | 199 | ((1, 7, 6, 16), (1, 7, 1, 6, 16)), |
185 | 200 | ((1, 5, 8, 16), (1, 5, 1, 8, 16)), |
186 | 201 | ((1, 4, 4, 16), (1, 4, 1, 4, 16)), |
187 | 202 | ((1, 16, 4, 16), (1, 16, 1, 4, 16)), |
188 | 203 | ((1, 8, 12, 8), (1, 8, 1, 12, 8)), |
189 | 204 | ((1, 10, 6, 16), (1, 10, 1, 6, 16)), |
| 205 | + # Depthwise Conv2D |
| 206 | + ((1, 7, 10, 16), (1, 7, 1, 10, 16)), |
| 207 | + ((1, 7, 6, 16), (1, 7, 1, 6, 16)), |
| 208 | + # Pooling |
| 209 | + ((1, 1, 2, 80), (1, 1, 5, 2, 16)), |
| 210 | + ((1, 10, 6, 16), (1, 10, 1, 6, 16)), |
190 | 211 | ], |
191 | 212 | ), |
192 | 213 | ( |
193 | 214 | "ethos-u55-256", |
194 | 215 | [ |
| 216 | + # Conv2D |
195 | 217 | ((1, 14, 8, 16), (1, 14, 1, 8, 16)), |
196 | 218 | ((1, 16, 8, 16), (1, 16, 1, 8, 16)), |
197 | 219 | ((1, 4, 4, 16), (1, 4, 1, 4, 16)), |
198 | | - ((1, 32, 4, 16), (1, 32, 1, 4, 16)), |
| 220 | + ((1, 32, 4, 16), (1, 10, 12, 16), (1, 32, 1, 4, 16), (1, 10, 1, 12, 16)), |
199 | 221 | ((1, 20, 12, 8), (1, 20, 1, 12, 8)), |
200 | | - ((1, 20, 6, 16), (1, 20, 1, 6, 16)), |
| 222 | + ((1, 12, 10, 16), (1, 12, 1, 10, 16)), |
| 223 | + # Depthwise Conv2D |
| 224 | + ((1, 8, 20, 16), (1, 8, 1, 20, 16)), |
| 225 | + ((1, 14, 6, 16), (1, 14, 1, 6, 16)), |
| 226 | + # Pooling |
| 227 | + ((1, 2, 2, 48), (1, 2, 3, 2, 16)), |
| 228 | + ((1, 10, 12, 16), (1, 10, 1, 12, 16)), |
201 | 229 | ], |
202 | 230 | ), |
203 | 231 | ], |
204 | 232 | ) |
205 | 233 | def test_best_block_config( |
206 | | - id, |
| 234 | + test_id, |
207 | 235 | op_type, |
208 | 236 | activation, |
209 | 237 | kernel, |
@@ -299,10 +327,8 @@ def test_best_block_config( |
299 | 327 |
|
300 | 328 | block = part.get_block_config(stripe_config) |
301 | 329 | block_shape = tuple(int(a) for a in block.output_shape) |
302 | | - if layouts[1] == "NHCWB16": |
303 | | - assert block_shape == expected_block_configs[id][1] |
304 | | - else: |
305 | | - assert block_shape == expected_block_configs[id][0] |
| 330 | + |
| 331 | + assert block_shape in expected_block_configs[test_id] |
306 | 332 |
|
307 | 333 |
|
308 | 334 | if __name__ == "__main__": |
|
0 commit comments