2727 match_ethosu_binary_elementwise ,
2828 binary_elementwise_compute ,
2929)
30+ from tvm .relay .backend .contrib .ethosu .te .common import get_layout_transform_matrices
3031
3132
32- def _make_matrices (broadcast , ifm_layout , ifm2_layout , ofm_layout ):
33+ def _make_matrices (broadcast , ifm_layout , ifm2_layout , ofm_layout , ofm_channels ):
3334 broadcast_h , broadcast_w , broadcast_c = broadcast
34- nhwc_to_nhcwb16 = [
35- [1 , 0 , 0 , 0 , 0 ],
36- [0 , 1 , 0 , 0 , 0 ],
37- [0 , 0 , 0 , 1 / 16 , 0 ],
38- [0 , 0 , 1 , 0 , 0 ],
39- [0 , 0 , 0 , 0 , 16 ],
40- [0 , 0 , 0 , 0 , 1 ],
41- ]
42- nhcwb16_to_nhwc = [
43- [1 , 0 , 0 , 0 , 0 , 0 ],
44- [0 , 1 , 0 , 0 , 0 , 0 ],
45- [0 , 0 , 0 , 1 , 0 , 0 ],
46- [0 , 0 , 16 , 0 , 1 , - 16 ],
47- [0 , 0 , 0 , 0 , 0 , 1 ],
48- ]
35+ nhwc_to_nhcwb16 , nhcwb16_to_nhwc = get_layout_transform_matrices (ofm_channels )
4936 ifm_matrix = [
5037 [1 , 0 , 0 , 0 , 0 ],
5138 [0 , 1 , 0 , 0 , 0 ],
@@ -93,14 +80,8 @@ def test_ethosu_binary_elementwise_matcher(
9380 ifm2_shape = [1 ] + [1 if (b == 1 ) else a for a , b in zip (ofm_shape [1 :], ifm2_broadcast )]
9481 ifm_channels = ifm_shape [3 ]
9582 ifm2_channels = ifm2_shape [3 ]
96- nhwc_to_nhcwb16 = [
97- [1 , 0 , 0 , 0 , 0 ],
98- [0 , 1 , 0 , 0 , 0 ],
99- [0 , 0 , 0 , 1 / 16 , 0 ],
100- [0 , 0 , 1 , 0 , 0 ],
101- [0 , 0 , 0 , 0 , 16 ],
102- [0 , 0 , 0 , 0 , 1 ],
103- ]
83+ ofm_channels = ofm_shape [3 ]
84+ nhwc_to_nhcwb16 , _ = get_layout_transform_matrices (ofm_channels )
10485 broadcast = [1 if a == 1 else 0 for a in ifm2_shape [1 :]]
10586 if ifm_layout == "NHCWB16" :
10687 ifm_shape = [
@@ -173,10 +154,7 @@ def test_ethosu_binary_elementwise_matcher(
173154 output_stripe_config = cs .StripeConfig (ofm_shape , ofm_shape , ofm_shape , order , stripes , offset )
174155
175156 (ifm_transform , ifm2_transform ) = _make_matrices (
176- broadcast ,
177- ifm_layout ,
178- ifm2_layout ,
179- ofm_layout ,
157+ broadcast , ifm_layout , ifm2_layout , ofm_layout , ofm_channels
180158 )
181159
182160 device_config = cs .EthosuDeviceConfig ("ethos-u55-256" )
@@ -190,19 +168,10 @@ def test_ethosu_binary_elementwise_matcher(
190168 propagated_ifm = ifm_propagator .propagate (output_stripe_config ).shape
191169 propagated_ifm2 = ifm2_propagator .propagate (output_stripe_config ).shape
192170
193- # Layout conversions will align the propagated IFMs to the brick, i.e. 16
194- # so the expected ifm(2)_shape needs to be rounded up to 16
195- if ifm_layout != ofm_layout :
196- assert ifm_shape [:- 1 ] == propagated_ifm [:- 1 ]
197- assert ((ifm_shape [- 1 ] + 16 - 1 ) // 16 ) * 16 == propagated_ifm [- 1 ]
198- else :
199- assert ifm_shape == propagated_ifm
200-
201- if ifm2_layout != ofm_layout :
202- assert ifm2_shape [:- 1 ] == propagated_ifm2 [:- 1 ]
203- assert ((ifm2_shape [- 1 ] + 16 - 1 ) // 16 ) * 16 == propagated_ifm2 [- 1 ]
204- else :
205- assert ifm2_shape == propagated_ifm2
171+ # The layout transforms that have the exact number of output channels in them
172+ # will lose no information about the number of channels
173+ assert ifm_shape == propagated_ifm
174+ assert ifm2_shape == propagated_ifm2
206175
207176
208177if __name__ == "__main__" :
0 commit comments