1111import pytest
1212
1313import torch
14- from executorch .backends .arm .test import common
14+ from executorch .backends .arm .test import common , conftest
1515from executorch .backends .arm .test .tester .arm_tester import ArmTester
1616from executorch .exir .backend .compile_spec_schema import CompileSpec
1717from parameterized import parameterized
2828 lambda : ("randn" , torch .randn (10 , 10 , 10 , 10 ), 3 ),
2929 lambda : ("randn_neg_dim" , torch .randn (10 , 5 , 8 , 7 ), - 3 ),
3030]
31- test_data_generators_u55 = [
31+
32+ test_data_generators_FVP = [
3233 # (test_name, test_data, dim)
3334 lambda : ("ones" , torch .ones (10 , 10 ), 1 ),
3435 lambda : ("ones_neg_dim" , torch .ones (10 , 3 , 4 ), - 1 ),
35- lambda : ("randn_neg_dim" , torch .randn (10 , 5 , 8 , 7 ), - 3 ),
36- lambda : ("zeros" , torch .zeros (10 , 8 , 5 , 2 ), 0 ),
37- lambda : ("zeros_neg_dim" , torch .zeros (10 , 7 , 8 , 9 ), - 4 ),
36+ lambda : ("randn_neg_dim" , torch .randn (1 , 5 , 8 , 7 ), - 3 ),
37+ lambda : ("zeros" , torch .zeros (1 , 8 , 5 , 2 ), 0 ),
38+ lambda : ("zeros_neg_dim" , torch .zeros (1 , 7 , 8 , 9 ), - 4 ),
3839 lambda : ("rand" , torch .rand (1 , 2 , 5 , 8 ), 2 ),
39- lambda : ("rand_neg_dim" , torch .rand (2 , 10 , 8 , 10 ), - 2 ),
40- lambda : ("randn" , torch .randn (10 , 10 , 10 , 10 ), 3 ),
40+ lambda : ("rand_neg_dim" , torch .rand (1 , 10 , 8 , 10 ), - 2 ),
41+ lambda : ("randn" , torch .randn (1 , 10 , 10 , 10 ), 3 ),
4142]
4243
4344
@@ -99,7 +100,7 @@ def _test_logsoftmax_tosa_ethos_BI_pipeline(
99100 module : torch .nn .Module ,
100101 test_data : Tuple [torch .tensor ],
101102 ):
102- (
103+ tester = (
103104 ArmTester (
104105 module ,
105106 example_inputs = test_data ,
@@ -114,21 +115,10 @@ def _test_logsoftmax_tosa_ethos_BI_pipeline(
114115 .check_not (["executorch_exir_dialects_edge__ops_aten__logsoftmax_default" ])
115116 .check_count ({"torch.ops.higher_order.executorch_call_delegate" : 1 })
116117 .to_executorch ()
118+ .serialize ()
117119 )
118-
119- def _test_logsoftmax_tosa_u55_BI_pipeline (
120- self , module : torch .nn .Module , test_data : Tuple [torch .tensor ]
121- ):
122- self ._test_logsoftmax_tosa_ethos_BI_pipeline (
123- common .get_u55_compile_spec (), module , test_data
124- )
125-
126- def _test_logsoftmax_tosa_u85_BI_pipeline (
127- self , module : torch .nn .Module , test_data : Tuple [torch .tensor ]
128- ):
129- self ._test_logsoftmax_tosa_ethos_BI_pipeline (
130- common .get_u85_compile_spec (), module , test_data
131- )
120+ if conftest .is_option_enabled ("corstone_fvp" ):
121+ tester .run_method_and_compare_outputs (inputs = test_data , qtol = 1 )
132122
133123 @parameterized .expand (test_data_generators )
134124 def test_logsoftmax_tosa_MI (self , test_data_generator : Callable [[], Tuple ]):
@@ -141,18 +131,18 @@ def test_logsoftmax_tosa_BI(self, test_data_generator: Callable[[], Tuple]):
141131 test_name , test_data , dim = test_data_generator ()
142132 self ._test_logsoftmax_tosa_BI_pipeline (self .LogSoftmax (dim = dim ), (test_data ,))
143133
144- @parameterized .expand (test_data_generators_u55 )
134+ @parameterized .expand (test_data_generators_FVP )
145135 @pytest .mark .flaky # TODO: MLETORCH-460 - Numerically stabler (log)softmax implementation
146136 def test_logsoftmax_tosa_u55_BI (self , test_data_generator : Callable [[], Tuple ]):
147137 test_name , test_data , dim = test_data_generator ()
148- self ._test_logsoftmax_tosa_u55_BI_pipeline (
149- self .LogSoftmax (dim = dim ), (test_data ,)
138+ self ._test_logsoftmax_tosa_ethos_BI_pipeline (
139+ common . get_u55_compile_spec (), self .LogSoftmax (dim = dim ), (test_data ,)
150140 )
151141
152- @parameterized .expand (test_data_generators )
142+ @parameterized .expand (test_data_generators_FVP )
153143 @pytest .mark .flaky # TODO: MLETORCH-460 - Numerically stabler (log)softmax implementation
154144 def test_logsoftmax_tosa_u85_BI (self , test_data_generator : Callable [[], Tuple ]):
155145 test_name , test_data , dim = test_data_generator ()
156- self ._test_logsoftmax_tosa_u85_BI_pipeline (
157- self .LogSoftmax (dim = dim ), (test_data ,)
146+ self ._test_logsoftmax_tosa_ethos_BI_pipeline (
147+ common . get_u85_compile_spec (), self .LogSoftmax (dim = dim ), (test_data ,)
158148 )
0 commit comments