Skip to content

Commit a00cf3d

Browse files
committed
Arm backend: Enable running MV2/DL3/Conformer on MLSDK runtime
Change-Id: I448186989172ed40527082a3d17ea663bc01e437
1 parent 8fbc42c commit a00cf3d

File tree

3 files changed

+20
-8
lines changed

3 files changed

+20
-8
lines changed

backends/arm/test/models/test_conformer.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ class TestConformer:
3636
# .to_executorch step, i.e. after Arm partitioner.
3737
aten_ops = ["torch.ops.aten._assert_scalar.default"]
3838

39+
# TODO(MLETORCH-635): reduce tolerance
40+
atol = 0.4
41+
rtol = 0.4
42+
3943
dim = 16
4044
num_examples = 10
4145
lengths = torch.randint(1, 100, (num_examples,), dtype=torch.int32)
@@ -65,7 +69,7 @@ def test_conformer_tosa_INT():
6569
pipeline = TosaPipelineINT[input_t](
6670
TestConformer.conformer,
6771
TestConformer.model_example_inputs,
68-
aten_op=[], # RemoveGraphAssertsPass is added in transform_for_annotation_pipeline to remove the assert ops
72+
aten_op=[],
6973
exir_op=[],
7074
use_to_edge_transform_and_lower=True,
7175
)
@@ -75,8 +79,8 @@ def test_conformer_tosa_INT():
7579
get_test_inputs(
7680
TestConformer.dim, TestConformer.lengths, TestConformer.num_examples
7781
),
78-
rtol=1.0,
79-
atol=3.0,
82+
rtol=TestConformer.rtol,
83+
atol=TestConformer.atol,
8084
)
8185
pipeline.run()
8286

@@ -130,13 +134,20 @@ def test_conformer_vgf_INT():
130134
pipeline = VgfPipeline[input_t](
131135
TestConformer.conformer,
132136
TestConformer.model_example_inputs,
133-
aten_op=[], # RemoveGraphAssertsPass is added in transform_for_annotation_pipeline to remove the assert ops
137+
aten_op=[],
134138
exir_op=[],
135139
tosa_version="TOSA-1.0+INT",
136140
use_to_edge_transform_and_lower=True,
137-
run_on_vulkan_runtime=False, # TODO: run on vulkan runtime
138141
)
139142
pipeline.pop_stage("check_count.exir")
143+
pipeline.change_args(
144+
"run_method_and_compare_outputs",
145+
get_test_inputs(
146+
TestConformer.dim, TestConformer.lengths, TestConformer.num_examples
147+
),
148+
rtol=TestConformer.rtol,
149+
atol=TestConformer.atol,
150+
)
140151
pipeline.run()
141152

142153

backends/arm/test/models/test_dl3_arm.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,11 @@ def test_dl3_vgf_INT():
9797
exir_op=[],
9898
tosa_version="TOSA-1.0+INT",
9999
use_to_edge_transform_and_lower=True,
100-
run_on_vulkan_runtime=False, # TODO: run on vulkan runtime
100+
run_on_vulkan_runtime=True, # TODO: run on vulkan runtime
101101
)
102+
pipeline.change_args(
103+
"run_method_and_compare_outputs", rtol=0.1, atol=0.1
104+
) # TODO: MLETORCH-1036 decrease tolerance
102105
pipeline.run()
103106

104107

backends/arm/test/models/test_mobilenet_v2_arm.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,6 @@ def test_mv2_vgf_INT(per_channel_quantization):
125125
per_channel_quantization=per_channel_quantization,
126126
atol=0.25,
127127
qtol=1,
128-
run_on_vulkan_runtime=False,
129128
)
130129
pipeline.run()
131130

@@ -139,6 +138,5 @@ def test_mv2_vgf_FP():
139138
exir_op=[],
140139
tosa_version="TOSA-1.0+FP",
141140
use_to_edge_transform_and_lower=True,
142-
run_on_vulkan_runtime=False,
143141
)
144142
pipeline.run()

0 commit comments

Comments
 (0)