@@ -36,6 +36,10 @@ class TestConformer:
3636 # .to_executorch step, i.e. after Arm partitioner.
3737 aten_ops = ["torch.ops.aten._assert_scalar.default" ]
3838
39+ # TODO(MLETORCH-635): reduce tolerance
40+ atol = 0.4
41+ rtol = 0.4
42+
3943 dim = 16
4044 num_examples = 10
4145 lengths = torch .randint (1 , 100 , (num_examples ,), dtype = torch .int32 )
@@ -65,7 +69,7 @@ def test_conformer_tosa_INT():
6569 pipeline = TosaPipelineINT [input_t ](
6670 TestConformer .conformer ,
6771 TestConformer .model_example_inputs ,
68- aten_op = [], # RemoveGraphAssertsPass is added in transform_for_annotation_pipeline to remove the assert ops
72+ aten_op = [],
6973 exir_op = [],
7074 use_to_edge_transform_and_lower = True ,
7175 )
@@ -75,8 +79,8 @@ def test_conformer_tosa_INT():
7579 get_test_inputs (
7680 TestConformer .dim , TestConformer .lengths , TestConformer .num_examples
7781 ),
78- rtol = 1.0 ,
79- atol = 3.0 ,
82+ rtol = TestConformer . rtol ,
83+ atol = TestConformer . atol ,
8084 )
8185 pipeline .run ()
8286
@@ -130,13 +134,20 @@ def test_conformer_vgf_INT():
130134 pipeline = VgfPipeline [input_t ](
131135 TestConformer .conformer ,
132136 TestConformer .model_example_inputs ,
133- aten_op = [], # RemoveGraphAssertsPass is added in transform_for_annotation_pipeline to remove the assert ops
137+ aten_op = [],
134138 exir_op = [],
135139 tosa_version = "TOSA-1.0+INT" ,
136140 use_to_edge_transform_and_lower = True ,
137- run_on_vulkan_runtime = False , # TODO: run on vulkan runtime
138141 )
139142 pipeline .pop_stage ("check_count.exir" )
143+ pipeline .change_args (
144+ "run_method_and_compare_outputs" ,
145+ get_test_inputs (
146+ TestConformer .dim , TestConformer .lengths , TestConformer .num_examples
147+ ),
148+ rtol = TestConformer .rtol ,
149+ atol = TestConformer .atol ,
150+ )
140151 pipeline .run ()
141152
142153
0 commit comments