diff --git a/backends/arm/test/models/test_conformer.py b/backends/arm/test/models/test_conformer.py index 3119145aef1..bda52fc078a 100644 --- a/backends/arm/test/models/test_conformer.py +++ b/backends/arm/test/models/test_conformer.py @@ -36,6 +36,10 @@ class TestConformer: # .to_executorch step, i.e. after Arm partitioner. aten_ops = ["torch.ops.aten._assert_scalar.default"] + # TODO(MLETORCH-635): reduce tolerance + atol = 0.4 + rtol = 0.4 + dim = 16 num_examples = 10 lengths = torch.randint(1, 100, (num_examples,), dtype=torch.int32) @@ -65,7 +69,7 @@ def test_conformer_tosa_INT(): pipeline = TosaPipelineINT[input_t]( TestConformer.conformer, TestConformer.model_example_inputs, - aten_op=[], # RemoveGraphAssertsPass is added in transform_for_annotation_pipeline to remove the assert ops + aten_op=[], exir_op=[], use_to_edge_transform_and_lower=True, ) @@ -75,8 +79,8 @@ def test_conformer_tosa_INT(): get_test_inputs( TestConformer.dim, TestConformer.lengths, TestConformer.num_examples ), - rtol=1.0, - atol=3.0, + rtol=TestConformer.rtol, + atol=TestConformer.atol, ) pipeline.run() @@ -132,22 +136,20 @@ def test_conformer_vgf_INT(): pipeline = VgfPipeline[input_t]( TestConformer.conformer, TestConformer.model_example_inputs, - aten_op=[], # RemoveGraphAssertsPass is added in transform_for_annotation_pipeline to remove the assert ops + aten_op=[], exir_op=[], tosa_version="TOSA-1.0+INT", use_to_edge_transform_and_lower=True, ) pipeline.pop_stage("check_count.exir") - - # TODO: MLETORCH-1167 Create Vulkan backend e2e tests - # pipeline.change_args( - # "run_method_and_compare_outputs", - # get_test_inputs( - # TestConformer.dim, TestConformer.lengths, TestConformer.num_examples - # ), - # rtol=1.0, - # atol=3.0, - # ) + pipeline.change_args( + "run_method_and_compare_outputs", + get_test_inputs( + TestConformer.dim, TestConformer.lengths, TestConformer.num_examples + ), + rtol=TestConformer.rtol, + atol=TestConformer.atol, + ) pipeline.run() diff --git a/backends/arm/test/models/test_dl3_arm.py b/backends/arm/test/models/test_dl3_arm.py index 2000ac34794..8c25230f1a7 100644 --- a/backends/arm/test/models/test_dl3_arm.py +++ b/backends/arm/test/models/test_dl3_arm.py @@ -99,11 +99,11 @@ def test_dl3_vgf_INT(): exir_op=[], tosa_version="TOSA-1.0+INT", use_to_edge_transform_and_lower=True, + run_on_vulkan_runtime=True, # TODO: run on vulkan runtime ) - # TODO: MLETORCH-1167 Create Vulkan backend e2e tests - # pipeline.change_args( - # "run_method_and_compare_outputs", rtol=1.0, atol=1.0 - # ) + pipeline.change_args( + "run_method_and_compare_outputs", rtol=0.1, atol=0.1 + ) # TODO: MLETORCH-1036 decrease tolerance pipeline.run()