Skip to content

Commit 9700cf0

Browse files
authored
Merge branch 'main' into vgf_ethosu_docs
2 parents 64b25c8 + 6e9fb80 commit 9700cf0

File tree

56 files changed

+1229
-328
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

56 files changed

+1229
-328
lines changed

backends/arm/_passes/__init__.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,13 @@
1818
from .convert_full_like_to_full_pass import ConvertFullLikeToFullPass # noqa
1919
from .convert_int64_const_ops_to_int32 import ConvertInt64ConstOpsToInt32Pass # noqa
2020
from .convert_int64_output_ops_to_int32 import ConvertInt64OutputOpsToInt32Pass # noqa
21-
from .convert_int_pow_to_mul import ConvertIntPowToMuls # noqa
2221
from .convert_minmax_pass import ConvertMinMaxPass # noqa
2322
from .convert_permute_singleton_to_view_pass import ( # noqa
2423
ConvertPermuteSingletonToViewPass,
2524
)
2625
from .convert_split_to_slice import ConvertSplitToSlicePass # noqa
2726
from .convert_squeezes_to_view import ConvertSqueezesToViewPass # noqa
28-
from .convert_to_clamp import ConvertToClampPass # noqa
27+
from .convert_to_clamp_pass import ConvertToClampPass # noqa
2928
from .decompose_acosh_pass import DecomposeAcoshPass # noqa
3029
from .decompose_adaptive_avg_pool2d_pass import DecomposeAdaptiveAvgPool2dPass # noqa
3130
from .decompose_add_sub_alpha_pass import DecomposeAddSubAlphaPass # noqa
@@ -35,7 +34,7 @@
3534
from .decompose_asinh_pass import DecomposeAsinhPass # noqa
3635
from .decompose_atan_pass import DecomposeAtanPass # noqa
3736
from .decompose_atanh_pass import DecomposeAtanhPass # noqa
38-
from .decompose_avg_pool2d import DecomposeAvgPool2d # noqa
37+
from .decompose_avg_pool2d_pass import DecomposeAvgPool2dPass # noqa
3938
from .decompose_batch_norm_no_stats import DecomposeBatchNormNoStatsPass # noqa
4039
from .decompose_cosh_pass import DecomposeCoshPass # noqa
4140
from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass # noqa
@@ -48,23 +47,24 @@
4847
from .decompose_floor_divide_pass import DecomposeFloorDividePass # noqa
4948
from .decompose_gelu_pass import DecomposeGeluPass # noqa
5049
from .decompose_glu_pass import DecomposeGluPass # noqa
51-
from .decompose_grouped_conv import DecomposeGroupedConv # noqa
50+
from .decompose_grouped_conv_pass import DecomposeGroupedConvPass # noqa
5251
from .decompose_groupnorm_pass import DecomposeGroupNormPass # noqa
5352
from .decompose_int16_activation_conv2d_pass import ( # noqa
5453
DecomposeConv2dWithInt16ActivationPass,
5554
)
55+
from .decompose_int_pow_pass import DecomposeIntPowPass # noqa
5656
from .decompose_layernorm_pass import DecomposeLayerNormPass # noqa
5757
from .decompose_leaky_relu_pass import DecomposeLeakyReLUPass # noqa
58-
from .decompose_linalg_vector_norm_pass import DecomposeLinearVectorNormPass # noqa
58+
from .decompose_linalg_vector_norm_pass import DecomposeLinalgVectorNormPass # noqa
5959
from .decompose_linear_pass import DecomposeLinearPass # noqa
6060
from .decompose_logit_pass import DecomposeLogitPass # noqa
61-
from .decompose_masked_fill import DecomposeMaskedFill # noqa
62-
from .decompose_maxpool2d_with_dilation import DecomposeMaxPool2DPass # noqa
61+
from .decompose_masked_fill_pass import DecomposeMaskedFillPass # noqa
62+
from .decompose_maxpool2d_with_dilation_pass import DecomposeMaxPool2dPass # noqa
6363
from .decompose_meandim_pass import DecomposeMeanDimPass # noqa
6464
from .decompose_ne_pass import DecomposeNotEqualPass # noqa
6565
from .decompose_remainder_pass import DecomposeRemainderPass # noqa
6666
from .decompose_round_pass import DecomposeRoundPass # noqa
67-
from .decompose_sdpa_pass import DecomposeScaledDotProductAttention # noqa
67+
from .decompose_sdpa_pass import DecomposeScaledDotProductAttentionPass # noqa
6868
from .decompose_select import DecomposeSelectPass # noqa
6969
from .decompose_sign_pass import DecomposeSignPass # noqa
7070
from .decompose_silu_pass import DecomposeSiluPass # noqa
@@ -77,18 +77,25 @@
7777
from .decorate_fp32_to_int32_casting_pass import DecorateFp32toInt32CastingPass # noqa
7878
from .fold_qdq_with_annotated_qparams_pass import ( # noqa
7979
FoldAndAnnotateQParamsPass,
80-
QuantizeOperatorArguments,
80+
QuantizeClampArgumentsPass,
81+
)
82+
from .fuse_batch_norm2d_pass import FuseBatchNorm2dPass # noqa
83+
from .fuse_constant_ops_pass import ( # noqa
84+
ComputeConstantOpsAOTPass,
85+
FuseConstantArgsPass,
8186
)
82-
from .fuse_batchnorm2d_pass import FuseBatchnorm2DPass # noqa
83-
from .fuse_constant_ops_pass import ComputeConstantOpsAOT, FuseConstantArgsPass # noqa
8487
from .fuse_duplicate_users_pass import FuseDuplicateUsersPass # noqa
8588
from .fuse_equal_placeholders_pass import FuseEqualPlaceholdersPass # noqa
8689
from .fuse_quantized_activation_pass import FuseQuantizedActivationPass # noqa
8790
from .fuse_view_copy_transform_pass import FuseViewCopyTransformPass # noqa
8891
from .insert_int32_casts_after_int64_placeholders import ( # noqa
8992
InsertInt32CastsAfterInt64PlaceholdersPass,
9093
)
91-
from .insert_rescales_pass import InsertRescaleInt32Pass, InsertRescalePass # noqa
94+
from .insert_rescales_pass import ( # noqa
95+
InsertControlFlowRescalesPass,
96+
InsertRescaleInt32Pass,
97+
InsertRescalePass,
98+
)
9299
from .insert_table_ops import InsertTableOpsPass # noqa
93100
from .match_arg_dtype_pass import MatchArgDtypePass # noqa
94101
from .match_arg_ranks_pass import MatchArgRanksPass # noqa
@@ -107,5 +114,5 @@
107114
from .to_tosa_memory_format_pass import ToTosaMemoryFormatPass # noqa
108115
from .unsqueeze_before_repeat_pass import UnsqueezeBeforeRepeatPass # noqa
109116
from .unsqueeze_scalar_placeholders_pass import UnsqueezeScalarPlaceholdersPass # noqa
110-
from .replace_inf_values_pass import ReplaceInfValues # noqa # usort: skip
117+
from .replace_inf_values_pass import ReplaceInfValuesPass # noqa # usort: skip
111118
from .arm_pass_manager import ArmPassManager # noqa # usort: skip

backends/arm/_passes/arm_pass_manager.py

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,13 @@
1717
CastBoolToInt8Pass,
1818
CastInt64BuffersToInt32Pass,
1919
CastToInt32Pass,
20-
ComputeConstantOpsAOT,
20+
ComputeConstantOpsAOTPass,
2121
Conv1dUnsqueezePass,
2222
ConvertELUParamsPass,
2323
ConvertExpandCopyToRepeatPass,
2424
ConvertFullLikeToFullPass,
2525
ConvertInt64ConstOpsToInt32Pass,
2626
ConvertInt64OutputOpsToInt32Pass,
27-
ConvertIntPowToMuls,
2827
ConvertMinMaxPass,
2928
ConvertMmToBmmPass,
3029
ConvertPermuteSingletonToViewPass,
@@ -40,7 +39,7 @@
4039
DecomposeAsinhPass,
4140
DecomposeAtanhPass,
4241
DecomposeAtanPass,
43-
DecomposeAvgPool2d,
42+
DecomposeAvgPool2dPass,
4443
DecomposeBatchNormNoStatsPass,
4544
DecomposeConv2dWithInt16ActivationPass,
4645
DecomposeCoshPass,
@@ -54,20 +53,21 @@
5453
DecomposeFloorDividePass,
5554
DecomposeGeluPass,
5655
DecomposeGluPass,
57-
DecomposeGroupedConv,
56+
DecomposeGroupedConvPass,
5857
DecomposeGroupNormPass,
58+
DecomposeIntPowPass,
5959
DecomposeLayerNormPass,
6060
DecomposeLeakyReLUPass,
61+
DecomposeLinalgVectorNormPass,
6162
DecomposeLinearPass,
62-
DecomposeLinearVectorNormPass,
6363
DecomposeLogitPass,
64-
DecomposeMaskedFill,
65-
DecomposeMaxPool2DPass,
64+
DecomposeMaskedFillPass,
65+
DecomposeMaxPool2dPass,
6666
DecomposeMeanDimPass,
6767
DecomposeNotEqualPass,
6868
DecomposeRemainderPass,
6969
DecomposeRoundPass,
70-
DecomposeScaledDotProductAttention,
70+
DecomposeScaledDotProductAttentionPass,
7171
DecomposeSelectPass,
7272
DecomposeSignPass,
7373
DecomposeSiluPass,
@@ -79,23 +79,24 @@
7979
DecomposeVarPass,
8080
DecorateFp32toInt32CastingPass,
8181
FoldAndAnnotateQParamsPass,
82-
FuseBatchnorm2DPass,
82+
FuseBatchNorm2dPass,
8383
FuseConstantArgsPass,
8484
FuseDuplicateUsersPass,
8585
FuseEqualPlaceholdersPass,
8686
FuseQuantizedActivationPass,
8787
FuseViewCopyTransformPass,
88+
InsertControlFlowRescalesPass,
8889
InsertInt32CastsAfterInt64PlaceholdersPass,
8990
InsertRescaleInt32Pass,
9091
InsertRescalePass,
9192
InsertTableOpsPass,
9293
MatchArgDtypePass,
9394
MatchArgRanksPass,
94-
QuantizeOperatorArguments,
95+
QuantizeClampArgumentsPass,
9596
RemoveGetItemPass,
9697
RemoveGraphAssertsPass,
9798
RemoveNoopPass,
98-
ReplaceInfValues,
99+
ReplaceInfValuesPass,
99100
ReplaceScalarWithTensorByProfilePass,
100101
RewriteConv2dPass,
101102
RewriteMatmulPass,
@@ -181,7 +182,7 @@ def _tosa_pipeline(
181182
AnnotateDecomposedMatmulPass(),
182183
ConvertELUParamsPass(),
183184
ConvertSplitToSlicePass(),
184-
QuantizeOperatorArguments(),
185+
QuantizeClampArgumentsPass(),
185186
]
186187
)
187188

@@ -195,14 +196,15 @@ def _tosa_pipeline(
195196
# Ticket: MLETORCH-1539
196197
DecomposeLinearPass(),
197198
InsertRescaleInt32Pass(),
199+
InsertControlFlowRescalesPass(),
198200
]
199201
)
200202

201203
# Node transformation passes (post q/dq folding)
202204
self.add_passes(
203205
[
204206
DecomposeLogitPass(),
205-
DecomposeMaskedFill(),
207+
DecomposeMaskedFillPass(),
206208
DecomposeRoundPass(),
207209
DecomposeAcoshPass(),
208210
DecomposeAsinhPass(),
@@ -214,14 +216,14 @@ def _tosa_pipeline(
214216
DecomposeAddmmPass(),
215217
DecomposeEluPass(),
216218
DecomposeExpm1Pass(),
217-
ConvertIntPowToMuls(),
219+
DecomposeIntPowPass(),
218220
CastBoolToInt8Pass(),
219221
DecomposeSinhPass(),
220222
DecomposeSignPass(),
221223
DecomposeFloorDividePass(),
222224
DecomposeGeluPass(),
223225
DecomposeAddSubAlphaPass(),
224-
DecomposeGroupedConv(),
226+
DecomposeGroupedConvPass(),
225227
Conv1dUnsqueezePass(),
226228
]
227229
)
@@ -247,7 +249,7 @@ def _tosa_pipeline(
247249
DecomposeRemainderPass(),
248250
DecomposeDivTensorModePass(),
249251
DecomposeEmbeddingPass(),
250-
FuseBatchnorm2DPass(exported_program),
252+
FuseBatchNorm2dPass(exported_program),
251253
ConvertMmToBmmPass(),
252254
DecomposeGluPass(),
253255
DecomposeLeakyReLUPass(),
@@ -256,13 +258,13 @@ def _tosa_pipeline(
256258
ConvertMinMaxPass(),
257259
DecomposeAnyPass(),
258260
DecomposeAdaptiveAvgPool2dPass(),
259-
DecomposeAvgPool2d(),
261+
DecomposeAvgPool2dPass(),
260262
DecorateFp32toInt32CastingPass(),
261-
ComputeConstantOpsAOT(exported_program),
263+
ComputeConstantOpsAOTPass(exported_program),
262264
ConvertExpandCopyToRepeatPass(),
263265
UnsqueezeBeforeRepeatPass(),
264266
DecomposeCumsumPass(exported_program),
265-
DecomposeMaxPool2DPass(),
267+
DecomposeMaxPool2dPass(),
266268
SizeAdjustInputPass(),
267269
DecomposeSelectPass(),
268270
ConvertSqueezesToViewPass(),
@@ -324,7 +326,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
324326
ConvertInt64OutputOpsToInt32Pass(),
325327
InsertInt32CastsAfterInt64PlaceholdersPass(),
326328
DecomposeEmbeddingPass(),
327-
DecomposeScaledDotProductAttention(),
329+
DecomposeScaledDotProductAttentionPass(),
328330
DecomposeRoundPass(),
329331
DecomposeLogitPass(),
330332
CastBoolToInt8Pass(),
@@ -357,10 +359,10 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
357359
DecomposeGluPass(),
358360
DecomposeDivPass(),
359361
DecomposeLeakyReLUPass(),
360-
DecomposeLinearVectorNormPass(),
362+
DecomposeLinalgVectorNormPass(),
361363
DecomposeSqrtPass(),
362364
DecomposeSiluPass(),
363-
DecomposeAvgPool2d(),
365+
DecomposeAvgPool2dPass(),
364366
(
365367
DecomposeSoftmaxUnstablePass()
366368
if self.tosa_spec.is_U55_subset
@@ -373,8 +375,8 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
373375
# Postprocessing passes
374376
self.add_passes(
375377
[
376-
ReplaceInfValues(),
377-
DecomposeMaskedFill() if not self.tosa_spec.is_U55_subset else None,
378+
ReplaceInfValuesPass(),
379+
DecomposeMaskedFillPass() if not self.tosa_spec.is_U55_subset else None,
378380
]
379381
)
380382

backends/arm/_passes/convert_full_like_to_full_pass.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
from typing import Set, Type
77

88
from executorch.backends.arm._passes.arm_pass import ArmPass
9-
from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT
9+
from executorch.backends.arm._passes.fuse_constant_ops_pass import (
10+
ComputeConstantOpsAOTPass,
11+
)
1012

1113
from executorch.exir.dialects._ops import ops as exir_ops
1214
from executorch.exir.pass_base import ExportPass
@@ -24,7 +26,7 @@ class ConvertFullLikeToFullPass(ArmPass):
2426
Skip layout and device since it's not relevant for our backend.
2527
"""
2628

27-
_passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOT}
29+
_passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOTPass}
2830

2931
def call_operator(self, op, args, kwargs, meta):
3032
if op not in [

backends/arm/_passes/convert_int64_const_ops_to_int32.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99

1010
import torch
1111
from executorch.backends.arm._passes import ArmPass
12-
from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT
12+
from executorch.backends.arm._passes.fuse_constant_ops_pass import (
13+
ComputeConstantOpsAOTPass,
14+
)
1315
from executorch.exir.pass_base import ExportPass, PassResult
1416

1517

@@ -30,7 +32,7 @@ class ConvertInt64ConstOpsToInt32Pass(ArmPass):
3032
5. `torch.tensor`
3133
"""
3234

33-
_passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOT}
35+
_passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOTPass}
3436

3537
torch_ops = [
3638
torch.ops.aten.full.default,
@@ -47,7 +49,10 @@ def call(self, graph_module: torch.fx.GraphModule):
4749
if node.op != "call_function":
4850
continue
4951

50-
if node.target not in ComputeConstantOpsAOT.targeted_ops + self.torch_ops:
52+
if (
53+
node.target
54+
not in ComputeConstantOpsAOTPass.targeted_ops + self.torch_ops
55+
):
5156
continue
5257

5358
data = node.target(*node.args, **node.kwargs)

backends/arm/_passes/convert_to_clamp.py renamed to backends/arm/_passes/convert_to_clamp_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from executorch.backends.arm._passes import ArmPass
99

1010
from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
11-
QuantizeOperatorArguments,
11+
QuantizeClampArgumentsPass,
1212
)
1313

1414
from executorch.exir.dialects._ops import ops as exir_ops
@@ -30,7 +30,7 @@ def get_clamp_params(op, args) -> Tuple[float | None, float | None]:
3030

3131

3232
class ConvertToClampPass(ArmPass):
33-
_passes_required_after: Set[Type[ExportPass]] = {QuantizeOperatorArguments}
33+
_passes_required_after: Set[Type[ExportPass]] = {QuantizeClampArgumentsPass}
3434

3535
def call_operator(self, op, args, kwargs, meta):
3636
if op not in edge_operators:

backends/arm/_passes/decompose_adaptive_avg_pool2d_pass.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99
import torch
1010

1111
from executorch.backends.arm._passes import ArmPass
12-
from executorch.backends.arm._passes.decompose_avg_pool2d import DecomposeAvgPool2d
12+
from executorch.backends.arm._passes.decompose_avg_pool2d_pass import (
13+
DecomposeAvgPool2dPass,
14+
)
1315

1416
from executorch.exir.dialects._ops import ops as exir_ops
1517
from executorch.exir.pass_base import ExportPass, NodeMetadata
@@ -44,7 +46,7 @@ class DecomposeAdaptiveAvgPool2dPass(ArmPass):
4446
The output is of size output_size_h x output_size_w for any input.
4547
"""
4648

47-
_passes_required_after: Set[Type[ExportPass]] = {DecomposeAvgPool2d}
49+
_passes_required_after: Set[Type[ExportPass]] = {DecomposeAvgPool2dPass}
4850

4951
def call_operator(self, op, args, kwargs, meta, updated=False):
5052
if op not in (edge_ops + aten_ops):

backends/arm/_passes/decompose_avg_pool2d.py renamed to backends/arm/_passes/decompose_avg_pool2d_pass.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88

99
import torch
1010
from executorch.backends.arm._passes.arm_pass import ArmPass
11-
from executorch.backends.arm._passes.fuse_constant_ops_pass import ComputeConstantOpsAOT
11+
from executorch.backends.arm._passes.fuse_constant_ops_pass import (
12+
ComputeConstantOpsAOTPass,
13+
)
1214
from executorch.backends.arm.operators.operator_validation_utils import (
1315
adjust_pooling_pad_if_needed,
1416
)
@@ -37,8 +39,8 @@ def get_decomposition(op) -> tuple:
3739
raise RuntimeError(f"Can't get avg_pool2d decomposition for op {op}")
3840

3941

40-
class DecomposeAvgPool2d(ArmPass):
41-
_passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOT}
42+
class DecomposeAvgPool2dPass(ArmPass):
43+
_passes_required_after: Set[Type[ExportPass]] = {ComputeConstantOpsAOTPass}
4244

4345
def call_operator(self, op, args, kwargs, meta):
4446
if op not in (edge_div_ops + aten_div_ops):

0 commit comments

Comments
 (0)