@@ -42,6 +42,50 @@ def setUp(self):
4242 "executorch_exir_dialects_edge__ops_quantized_decomposed_choose_qparams_tensor"
4343 )
4444 dynamic_quant_name = "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_tensor"
45+ def run_tester (self , module , inputs ):
46+ tester = Tester (
47+ module .eval (),
48+ inputs ,
49+ )
50+ tester .export ().to_edge_transform_and_lower ().to_executorch ().serialize ().run_method_and_compare_outputs ()
51+
52+ class LinearConv (torch .nn .Module ):
53+ def __init__ (self ):
54+ super ().__init__ ()
55+ self .conv1 = torch .nn .Conv2d (3 , 3 , 3 )
56+ self .linear1 = torch .nn .Linear (4 , 3 )
57+ def forward (self , x ):
58+ y = self .linear1 (x )
59+ return self .conv1 (y )
60+ class ConvLinearConv (torch .nn .Module ):
61+ def __init__ (self ):
62+ super ().__init__ ()
63+ self .conv1 = torch .nn .Conv2d (3 , 3 , 3 )
64+ self .linear1 = torch .nn .Linear (4 , 4 )
65+ def forward (self , x ):
66+ y = self .conv1 (x )
67+ return self .linear1 (y )
68+ class Bilinear (torch .nn .Module ):
69+ def __init__ (self ):
70+ super ().__init__ ()
71+ def forward (self , x ):
72+ return torch .nn .functional .interpolate (
73+ x , scale_factor = 2 , mode = "bilinear" , align_corners = True
74+ )
75+
76+ def test_conv_linear_dim_order_swaps (self ):
77+ self .run_tester (self .LinearConv (), (torch .randn (1 , 3 , 6 , 4 ),))
78+ self .run_tester (self .LinearConv (), (torch .randn (1 , 3 , 6 , 4 ).to (memory_format = torch .channels_last ),))
79+
80+ def test_linear_conv_dim_order_swaps (self ):
81+ self .run_tester (self .ConvLinearConv (), (torch .randn (1 , 3 , 6 , 6 ),))
82+ self .run_tester (self .ConvLinearConv (), (torch .randn (1 , 3 , 6 , 6 ).to (memory_format = torch .channels_last ),))
83+
84+ def test_nhwc_input_on_nhwc_op (self ):
85+ self .run_tester (self .Bilinear (), (torch .arange (8 ).reshape (1 , 2 , 2 , 2 ).to (torch .float32 ).to (memory_format = torch .channels_last ),))
86+
87+ def test_nchw_input_on_nhwc_op (self ):
88+ self .run_tester (self .Bilinear (), (torch .arange (8 ).reshape (1 , 2 , 2 , 2 ).to (torch .float32 ),))
4589
4690 def test_fp32_channels_last_tagged_reshape_pass (self ):
4791 for module , num_reshape in self .modules .items ():
@@ -58,6 +102,88 @@ def test_fp32_channels_last_tagged_reshape_pass(self):
58102 .run_method_and_compare_outputs ()
59103 )
60104
105+ class LinearConv (torch .nn .Module ):
106+ def __init__ (self ):
107+ super ().__init__ ()
108+ self .conv1 = torch .nn .Conv2d (3 , 3 , 3 )
109+ self .linear1 = torch .nn .Linear (4 , 3 )
110+
111+ def forward (self , x ):
112+ y = self .linear1 (x )
113+ return self .conv1 (y )
114+
115+ def test_conv_linear_dim_order_swaps_on_nhwc_input (self ):
116+ tester = Tester (
117+ self .LinearConv ().eval (),
118+ (torch .randn (1 , 3 , 6 , 4 ).to (memory_format = torch .channels_last ),),
119+ )
120+
121+ tester .export ().to_edge_transform_and_lower ().to_executorch ().serialize ().run_method_and_compare_outputs ()
122+
123+ def test_conv_linear_dim_order_swaps_on_nchw_input (self ):
124+ tester = Tester (
125+ self .LinearConv ().eval (),
126+ (torch .randn (1 , 3 , 6 , 4 ),),
127+ )
128+
129+ tester .export ().to_edge_transform_and_lower ().to_executorch ().serialize ().run_method_and_compare_outputs ()
130+
131+ class ConvLinearConv (torch .nn .Module ):
132+ def __init__ (self ):
133+ super ().__init__ ()
134+ self .conv1 = torch .nn .Conv2d (3 , 3 , 3 )
135+ self .linear1 = torch .nn .Linear (4 , 4 )
136+
137+ def forward (self , x ):
138+ y = self .conv1 (x )
139+ return self .linear1 (y )
140+
141+ def test_linear_conv_dim_order_swaps_on_nhwc_input (self ):
142+ tester = Tester (
143+ self .ConvLinearConv ().eval (),
144+ (torch .randn (1 , 3 , 6 , 6 ).to (memory_format = torch .channels_last ),),
145+ )
146+
147+ tester .export ().to_edge_transform_and_lower ().to_executorch ().serialize ().run_method_and_compare_outputs ()
148+
149+ def test_linear_conv_dim_order_swaps_on_nchw_input (self ):
150+ tester = Tester (
151+ self .ConvLinearConv ().eval (),
152+ (torch .randn (1 , 3 , 6 , 6 ),),
153+ )
154+
155+ tester .export ().to_edge_transform_and_lower ().to_executorch ().serialize ().run_method_and_compare_outputs ()
156+
157+ class Bilinear (torch .nn .Module ):
158+ def __init__ (self ):
159+ super ().__init__ ()
160+
161+ def forward (self , x ):
162+ return torch .nn .functional .interpolate (
163+ x , scale_factor = 2 , mode = "bilinear" , align_corners = True
164+ )
165+
166+ def test_nhwc_input_on_nhwc_op (self ):
167+ tester = Tester (
168+ self .Bilinear ().eval (),
169+ (
170+ torch .arange (8 )
171+ .reshape (1 , 2 , 2 , 2 )
172+ .to (torch .float32 )
173+ .to (memory_format = torch .channels_last ),
174+ ),
175+ )
176+
177+ tester .export ().to_edge_transform_and_lower ().to_executorch ().serialize ().run_method_and_compare_outputs ()
178+
179+ def test_nchw_input_on_nhwc_op (self ):
180+ tester = Tester (
181+ self .Bilinear ().eval (),
182+ (torch .arange (8 ).reshape (1 , 2 , 2 , 2 ).to (torch .float32 ),),
183+ )
184+
185+ tester .export ().to_edge_transform_and_lower ().to_executorch ().serialize ().run_method_and_compare_outputs ()
186+
61187 def test_qs8_channels_last_tagged_reshape_pass (self ):
62188 for module , num_reshape in self .modules .items ():
63189 (
0 commit comments