Skip to content

Commit 39b154d

Browse files
committed
Update on "Arm backend: Add 16A8W support and test for mul operation"
Add 16A8W quantization support and test for the mul operation in ExecuTorch ARM backend. This follows the pattern established for linear operations, extending int16 support to mul operations. Changes: - Add INT16 dtype validation support in op_mul.py - Add test_mul_tensor_16a8w_tosa_INT test function - Enable test_mul.py in test targets configuration The 16A8W configuration uses 16-bit activations with 8-bit weights, enabling higher precision for activations while maintaining weight efficiency. Differential Revision: [D80510628](https://our.internmc.facebook.com/intern/diff/D80510628/) cc digantdesai freddan80 per zingo oscarandersson8218 [ghstack-poisoned]
2 parents f29a8f6 + d008093 commit 39b154d

File tree

52 files changed

+2308
-1751
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

52 files changed

+2308
-1751
lines changed

.github/workflows/pull.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -971,6 +971,13 @@ jobs:
971971
./cmake-out/backends/vulkan/test/custom_ops/q4gsw_linear
972972
./cmake-out/backends/vulkan/test/custom_ops/choose_qparams_per_row
973973
974+
# "Classic" Operator tests
975+
PYTHON_EXECUTABLE=python bash backends/vulkan/test/scripts/test_op.sh --build
976+
# TODO(ssjia): figure out how to run custom op tests in CI. Currently, they are
977+
# failing due to the libstdc++.so.6 installed with conda not supporting
978+
# GLIBCXX_3.4.30. These tests are still run in Meta internal CI.
979+
# ./cmake-out/backends/vulkan/test/op_tests/vulkan_sdpa_test
980+
974981
# Run e2e testing for selected operators. More operators will be tested via this
975982
# route in the future.
976983
python -m unittest backends/vulkan/test/test_vulkan_delegate.py -k "*pt2e*"

backends/arm/README.md

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -206,14 +206,6 @@ The current TOSA version does not support int64. However, int64 is commonly used
206206
- For quantized models, these transformations will be automatically handled during annotation before the export stage.
207207

208208
List of model specific and optional passes:
209-
- InsertCastForOpsWithInt64InputPass
210-
- Functionality:
211-
- For LLMs such as Llama, some operators like aten.embedding have int64 input. In order to lower these operators to TOSA, this pass will insert a casting node that converts the input from int64 to int32.
212-
- Supported Ops:
213-
- aten.embedding.default, aten.slice_copy.Tensor
214-
- Example usage:
215-
- backends/arm/test/models/test_llama.py
216-
217209
- ConvertInt64ConstOpsToInt32Pass
218210
- Functionalities:
219211
- Rewrites constant-producing ops that output int64 to instead output int32, when values are within int32 bounds.
@@ -244,3 +236,16 @@ List of model specific and optional passes:
244236
- Example usage:
245237
- (Functionality 1) backends/arm/test/models/stable_diffusion/test_T5EncoderModel.py
246238
- (Functionality 2) backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py
239+
240+
- InsertInt32CastsAfterInt64PlaceholdersPass
241+
- Functionalities:
242+
- Inserts an int64 -> int32 cast immediately after each int64 placeholder (graph input).
243+
- Redirects all uses of each int64 placeholder to its int32 cast output.
244+
- Inserts local int32 -> int64 casts at call sites where an operator requires int64 inputs, e.g. `torch.nn.functional.one_hot`
245+
- Pass ordering:
246+
- When used with `ConvertInt64ConstOpsToInt32Pass` and `ConvertInt64OutputOpsToInt32Pass`, run this pass last.
247+
- Rationale: Those passes may cause retracing to re-infer some int64 placeholders as int32. Running this pass last casts only inputs that remain int64, minimizing inserted casts.
248+
- Example usage:
249+
- backends/arm/test/models/test_llama.py
250+
- backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py
251+
- backends/arm/test/models/stable_diffusion/test_T5EncoderModel.py

backends/arm/_passes/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,8 @@
7575
from .fuse_constant_ops_pass import ComputeConstantOpsAOT, FuseConstantArgsPass # noqa
7676
from .fuse_equal_placeholders_pass import FuseEqualPlaceholdersPass # noqa
7777
from .fuse_quantized_activation_pass import FuseQuantizedActivationPass # noqa
78-
from .insert_int64_input_cast_pass import ( # noqa # noqa
79-
InsertCastForOpsWithInt64InputPass,
78+
from .insert_int32_casts_after_int64_placeholders import ( # noqa
79+
InsertInt32CastsAfterInt64PlaceholdersPass,
8080
)
8181
from .insert_rescales_pass import InsertRescalePass # noqa
8282
from .insert_table_ops import InsertTableOpsPass # noqa

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@
7676
FuseConstantArgsPass,
7777
FuseEqualPlaceholdersPass,
7878
FuseQuantizedActivationPass,
79-
InsertCastForOpsWithInt64InputPass,
79+
InsertInt32CastsAfterInt64PlaceholdersPass,
8080
InsertRescalePass,
8181
InsertTableOpsPass,
8282
MatchArgDtypePass,
@@ -277,7 +277,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
277277
) # ConvertInt64ConstOpsToInt32Pass requires this pass to remove the assertion in Graph
278278
self.add_pass(ConvertInt64ConstOpsToInt32Pass())
279279
self.add_pass(ConvertInt64OutputOpsToInt32Pass())
280-
self.add_pass(InsertCastForOpsWithInt64InputPass())
280+
self.add_pass(InsertInt32CastsAfterInt64PlaceholdersPass())
281281
self.add_pass(DecomposeEmbeddingPass())
282282
self.add_pass(DecomposeScaledDotProductAttention())
283283
self.add_pass(DecomposeRoundPass())
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-unsafe
7+
8+
9+
import logging
10+
11+
import torch
12+
from executorch.backends.arm._passes.arm_pass_utils import create_node
13+
from executorch.exir.dialects._ops import ops as exir_ops
14+
from executorch.exir.pass_base import EdgeOpOverload, ExportPass, PassResult
15+
from torch._subclasses.fake_tensor import FakeTensor
16+
17+
18+
logger = logging.getLogger(__name__)
19+
20+
21+
class InsertInt32CastsAfterInt64PlaceholdersPass(ExportPass):
22+
"""
23+
Insert an int64->int32 cast after each int64 placeholder.
24+
25+
Note: Overflow checks are not applied in this pass. It is the user's responsibility to ensure that values fit within
26+
the int32 range.
27+
"""
28+
29+
# Ops that require i64 inputs → positions of args to upcast.
30+
# Key: op overload; Value: zero-based indices of positional args that must be i64.
31+
I64_INPUT_ARG_POSITIONS = {
32+
torch.ops.aten.one_hot.default: (0,),
33+
}
34+
35+
def _insert_callsite_i32_to_i64_casts(self, graph_module: torch.fx.GraphModule):
36+
"""
37+
If an operator requires int64 inputs but dtype propagation (via call_operator)
38+
produced int32, insert a local int32→int64 cast at the call site to satisfy
39+
PyTorch's operator input validation.
40+
"""
41+
modified = False
42+
graph = graph_module.graph
43+
for node in graph.nodes:
44+
if node.op != "call_function":
45+
continue
46+
if node.target not in self.I64_INPUT_ARG_POSITIONS:
47+
continue
48+
49+
with graph.inserting_before(node):
50+
arg_positions = self.I64_INPUT_ARG_POSITIONS.get(node.target)
51+
args_list = list(node.args)
52+
for pos in arg_positions: # type: ignore[union-attr]
53+
input_arg = args_list[pos]
54+
to_copy_op = self._get_decomposition(graph)
55+
cast_node = graph_module.graph.create_node(
56+
"call_function",
57+
to_copy_op,
58+
(input_arg,),
59+
{"dtype": torch.int64},
60+
)
61+
cast_node.meta["val"] = node.meta["val"].to(torch.int64)
62+
args_list[pos] = cast_node
63+
node.args = tuple(args_list)
64+
modified = True
65+
return modified
66+
67+
def _graph_uses_edge_ops(self, graph: torch.fx.Graph) -> bool:
68+
for n in graph.nodes:
69+
if n.op == "call_function":
70+
if isinstance(n.target, EdgeOpOverload):
71+
return True
72+
return False
73+
74+
def _get_decomposition(self, graph: torch.fx.Graph):
75+
if self._graph_uses_edge_ops(graph):
76+
return exir_ops.edge.dim_order_ops._to_dim_order_copy.default
77+
else:
78+
return torch.ops.dim_order_ops._to_dim_order_copy.default
79+
80+
def _is_tensor_of_dtype(self, node_val, dtype: torch.dtype) -> bool:
81+
return isinstance(node_val, FakeTensor) and node_val.dtype == dtype
82+
83+
def _insert_placeholder_i64_to_i32_casts(self, graph_module: torch.fx.GraphModule):
84+
modified = False
85+
graph = graph_module.graph
86+
for node in graph.nodes:
87+
if node.op != "placeholder":
88+
continue
89+
node_val = node.meta["val"]
90+
if not self._is_tensor_of_dtype(node_val, torch.int64):
91+
continue
92+
93+
to_copy_op = self._get_decomposition(graph)
94+
with graph.inserting_after(node):
95+
cast_after = create_node(
96+
graph,
97+
to_copy_op,
98+
args=(node,),
99+
kwargs={
100+
"dtype": torch.int32,
101+
},
102+
)
103+
users = [user for user in node.users if user != cast_after]
104+
for user in users:
105+
user.replace_input_with(node, cast_after)
106+
logger.warning(
107+
f"Inserting a casting node {cast_after.name} after {node.name} to cast int64 placeholder"
108+
f" to int32 for {node.name} defined in {node.meta.get('stack_trace','[no stack trace found]')}"
109+
)
110+
modified = True
111+
return modified
112+
113+
def call(self, graph_module: torch.fx.GraphModule):
114+
modified = False
115+
modified |= self._insert_placeholder_i64_to_i32_casts(graph_module)
116+
modified |= self._insert_callsite_i32_to_i64_casts(graph_module)
117+
118+
if modified:
119+
graph_module.graph.eliminate_dead_code()
120+
graph_module.recompile()
121+
graph_module = super().call(graph_module).graph_module
122+
return PassResult(graph_module, modified)

backends/arm/_passes/insert_int64_input_cast_pass.py

Lines changed: 0 additions & 109 deletions
This file was deleted.

backends/arm/operators/op_abs.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,9 @@ def define_node(
7373
abs_output = output
7474

7575
# Do the INT32 Abs
76-
tosa_graph.addOperator(
76+
self._serialize_operator(
77+
node,
78+
tosa_graph,
7779
ts.TosaOp.Op().ABS,
7880
[
7981
rescaled_inputs[0].name,

backends/arm/operators/op_sum.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,9 @@ def define_node(
6767
dtype=ts.DType.INT32,
6868
)
6969

70-
tosa_graph.addOperator(
70+
self._serialize_operator(
71+
node,
72+
tosa_graph,
7173
ts.TosaOp.Op().REDUCE_SUM,
7274
[rescaled_inputs[0].name],
7375
[intermediate.name],
@@ -111,7 +113,9 @@ def define_node(
111113
attr = ts.TosaSerializerAttribute()
112114
attr.ReduceSumAttribute(tensor.dim_order.index(dim))
113115

114-
tosa_graph.addOperator(
116+
self._serialize_operator(
117+
node,
118+
tosa_graph,
115119
ts.TosaOp.Op().REDUCE_SUM,
116120
[tensor.name],
117121
[output.name],

backends/arm/test/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ python_library(
4444
"//executorch/backends/arm:ethosu_partitioner",
4545
"//executorch/backends/arm/quantizer:lib",
4646
"//executorch/backends/arm/tosa:mapping",
47+
"//executorch/backends/arm:vgf_partitioner",
4748
"//executorch/devtools/backend_debug:delegation_info",
4849
"//executorch/exir/backend:operator_support",
4950
"fbsource//third-party/pypi/tabulate:tabulate",

backends/arm/test/models/stable_diffusion/test_CLIPTextModelWithProjection.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from executorch.backends.arm._passes import (
1212
ConvertInt64ConstOpsToInt32Pass,
1313
ConvertInt64OutputOpsToInt32Pass,
14-
InsertCastForOpsWithInt64InputPass,
14+
InsertInt32CastsAfterInt64PlaceholdersPass,
1515
)
1616

1717
from executorch.backends.arm.test import common
@@ -33,10 +33,9 @@ class TestCLIPTextModelWithProjection(unittest.TestCase):
3333
# for that is some assert ops are removed by passes in the
3434
# .to_executorch step, i.e. after Arm partitioner.
3535
ops_after_partitioner = {
36-
"executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 3,
37-
"executorch_exir_dialects_edge__ops_aten_view_copy_default": 1,
3836
"executorch_exir_dialects_edge__ops_aten_argmax_default": 1,
39-
"torch.ops.higher_order.executorch_call_delegate": 1,
37+
"executorch_exir_dialects_edge__ops_dim_order_ops__to_dim_order_copy_default": 2,
38+
"torch.ops.higher_order.executorch_call_delegate": 2,
4039
}
4140

4241
def _prepare_inputs(
@@ -71,9 +70,9 @@ def test_CLIPTextModelWithProjection_tosa_FP(self):
7170
example_inputs=text_encoder_model_inputs,
7271
compile_spec=common.get_tosa_compile_spec(tosa_spec="TOSA-1.0+FP"),
7372
transform_passes=[
74-
InsertCastForOpsWithInt64InputPass(),
7573
ConvertInt64ConstOpsToInt32Pass(),
7674
ConvertInt64OutputOpsToInt32Pass(),
75+
InsertInt32CastsAfterInt64PlaceholdersPass(),
7776
],
7877
)
7978
.export()

0 commit comments

Comments
 (0)