Skip to content

Commit 18c9985

Browse files
committed
Arm backend: Add special dtype TOSA handling
Add an enum class to handle special dtypes that can't be represented in torch (i.e. int48_t) to avoid leaking serializer types into the pass handling of the backend. Signed-off-by: Per Åstrand <[email protected]> Change-Id: I3388cec3c8a26f28790eedc3f124c336b6724cb4
1 parent aafcede commit 18c9985

File tree

4 files changed

+31
-8
lines changed

4 files changed

+31
-8
lines changed

backends/arm/_passes/add_bias_pass.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import torch
99
from executorch.backends.arm._passes import ArmPass
1010
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
11+
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
1112
from executorch.backends.transforms.utils import create_constant_placeholder
1213

1314
from executorch.exir.dialects._ops import ops as exir_ops
@@ -60,7 +61,9 @@ def call(self, graph_module):
6061
name=f"{node.name}_bias",
6162
)
6263
if node.args[0].meta["val"].dtype == torch.int16:
63-
bias_node.meta["tosa_dtype_48bit"] = True
64+
bias_node.meta[TosaSpecialDtype.meta_key()] = (
65+
TosaSpecialDtype.INT48
66+
)
6467
node.update_arg(2, bias_node)
6568

6669
if modified:

backends/arm/_passes/fuse_equal_placeholders_pass.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,13 @@
88
from typing import Set, Type
99

1010
import torch
11+
1112
from executorch.backends.arm._passes.arm_pass_utils import (
1213
get_constant_placeholder_kind,
1314
get_param_tensor,
1415
is_param_node,
1516
)
17+
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
1618
from executorch.backends.transforms.utils import (
1719
create_constant_placeholder,
1820
delete_constant_placeholder,
@@ -50,7 +52,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
5052

5153
# ensure we don't merge any special case int48_t tensors with int32_t tensors
5254
# since int48_t tensors needs to be instantiated separately.
53-
is_int48 = node.meta.get("tosa_dtype_48bit", False)
55+
is_int48 = node.meta.get(TosaSpecialDtype.meta_key(), None)
5456
t_cpu = tensor.detach().cpu().contiguous()
5557
data_bytes = t_cpu.numpy().tobytes()
5658
key = (

backends/arm/process_node.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import torch
1313
import torch.fx
1414
from executorch.backends.arm.operators.node_visitor import NodeVisitor
15-
from executorch.backends.arm.tosa.mapping import TosaArg
15+
from executorch.backends.arm.tosa.mapping import TosaArg, TosaSpecialDtype
1616
from executorch.backends.arm.tosa.specification import TosaSpecification
1717
from executorch.backends.arm.tosa.utils import tosa_shape
1818
from torch._export.utils import (
@@ -113,8 +113,9 @@ def process_inputs_to_parameters(
113113
assert tosa_spec.support_float(), f"{tosa_spec} doesn't support float"
114114

115115
# Handle special case for INT48 tensors
116-
if node.meta.get("tosa_dtype_48bit", False):
117-
tosa_dtype = ts.DType.INT48
116+
special_type = node.meta.get(TosaSpecialDtype.meta_key(), None)
117+
if isinstance(special_type, TosaSpecialDtype):
118+
tosa_dtype = special_type.get_tosa_dtype()
118119
else:
119120
tosa_dtype = tosa_arg.dtype
120121

backends/arm/tosa/mapping.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
1212
"""
1313

14+
from enum import Enum
1415
from typing import Any, Optional, Sequence
1516

1617
import serializer.tosa_serializer as ts # type: ignore
@@ -31,6 +32,22 @@
3132
)
3233

3334

35+
class TosaSpecialDtype(Enum):
36+
"""
37+
Special TOSA data types that are not natively supported in PyTorch, to be
38+
used in specific scenarios as a value in the key from meta_key().
39+
"""
40+
41+
INT48 = ts.DType.INT48
42+
43+
def get_tosa_dtype(self) -> ts.TosaDType.DType:
44+
return self.value
45+
46+
@staticmethod
47+
def meta_key() -> str:
48+
return "tosa_special_dtype"
49+
50+
3451
def map_dtype(data_type: torch.dtype, tosa_spec: TosaSpecification) -> Any:
3552
"""Map a ``torch.dtype`` to a ``ts.DType``.
3653
@@ -134,9 +151,9 @@ def __process_node(self, argument: torch.fx.Node):
134151
argument.meta, self.tosa_spec
135152
)
136153

137-
# Handle special case of int
138-
if argument.meta.get("tosa_dtype_48bit", False):
139-
output_dtype = ts.DType.INT48
154+
# Handle special case of types not representable in torch (i.e. i48_t)
155+
if special_type := argument.meta.get(TosaSpecialDtype.meta_key(), None):
156+
output_dtype = special_type.get_tosa_dtype()
140157

141158
self.dtype = output_dtype
142159

0 commit comments

Comments
 (0)