Skip to content

Commit 54ad8e2

Browse files
committed
Arm backend: Add special dtype TOSA handling
Add an enum class to handle special dtypes that can't be represented in torch (i.e. int48_t) to avoid leaking serializer types into the pass handling of the backend. Signed-off-by: Per Åstrand <[email protected]> Change-Id: I3388cec3c8a26f28790eedc3f124c336b6724cb4
1 parent 3dd50a2 commit 54ad8e2

File tree

4 files changed

+31
-8
lines changed

4 files changed

+31
-8
lines changed

backends/arm/_passes/add_bias_pass.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import torch
77
from executorch.backends.arm._passes import ArmPass
88
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
9+
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
910
from executorch.backends.transforms.utils import create_constant_placeholder
1011

1112
from executorch.exir.dialects._ops import ops as exir_ops
@@ -56,7 +57,9 @@ def call(self, graph_module):
5657
name=f"{node.name}_bias",
5758
)
5859
if node.args[0].meta["val"].dtype == torch.int16:
59-
bias_node.meta["tosa_dtype_48bit"] = True
60+
bias_node.meta[TosaSpecialDtype.meta_key()] = (
61+
TosaSpecialDtype.INT48
62+
)
6063
node.update_arg(2, bias_node)
6164

6265
if modified:

backends/arm/_passes/fuse_equal_placeholders_pass.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,13 @@
77
from collections import defaultdict
88

99
import torch
10+
1011
from executorch.backends.arm._passes.arm_pass_utils import (
1112
get_constant_placeholder_kind,
1213
get_param_tensor,
1314
is_param_node,
1415
)
16+
from executorch.backends.arm.tosa.mapping import TosaSpecialDtype
1517
from executorch.backends.transforms.utils import (
1618
create_constant_placeholder,
1719
delete_constant_placeholder,
@@ -47,7 +49,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
4749

4850
# ensure we don't merge any special case int48_t tensors with int32_t tensors
4951
# since int48_t tensors needs to be instantiated separately.
50-
is_int48 = node.meta.get("tosa_dtype_48bit", False)
52+
is_int48 = node.meta.get(TosaSpecialDtype.meta_key(), None)
5153
t_cpu = tensor.detach().cpu().contiguous()
5254
data_bytes = t_cpu.numpy().tobytes()
5355
key = (

backends/arm/process_node.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import torch
1313
import torch.fx
1414
from executorch.backends.arm.operators.node_visitor import NodeVisitor
15-
from executorch.backends.arm.tosa.mapping import TosaArg
15+
from executorch.backends.arm.tosa.mapping import TosaArg, TosaSpecialDtype
1616
from executorch.backends.arm.tosa.specification import TosaSpecification
1717
from executorch.backends.arm.tosa.utils import tosa_shape
1818
from torch._export.utils import (
@@ -120,8 +120,9 @@ def process_inputs_to_parameters(
120120
assert tosa_spec.support_float(), f"{tosa_spec} doesn't support float"
121121

122122
# Handle special case for INT48 tensors
123-
if node.meta.get("tosa_dtype_48bit", False):
124-
tosa_dtype = ts.DType.INT48
123+
special_type = node.meta.get(TosaSpecialDtype.meta_key(), None)
124+
if isinstance(special_type, TosaSpecialDtype):
125+
tosa_dtype = special_type.get_tosa_dtype()
125126
else:
126127
tosa_dtype = tosa_arg.dtype
127128

backends/arm/tosa/mapping.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# the standardised TOSA representation.
1212
#
1313

14+
from enum import Enum
1415
from typing import Any, Optional, Sequence
1516

1617
import serializer.tosa_serializer as ts # type: ignore
@@ -31,6 +32,22 @@
3132
)
3233

3334

35+
class TosaSpecialDtype(Enum):
36+
"""
37+
Special TOSA data types that are not natively supported in PyTorch, to be
38+
used in specific scenarios as a value in the key from meta_key().
39+
"""
40+
41+
INT48 = ts.DType.INT48
42+
43+
def get_tosa_dtype(self) -> ts.TosaDType.DType:
44+
return self.value
45+
46+
@staticmethod
47+
def meta_key() -> str:
48+
return "tosa_special_dtype"
49+
50+
3451
def map_dtype(data_type: torch.dtype, tosa_spec: TosaSpecification) -> Any:
3552
if data_type in UNSUPPORTED_DTYPES:
3653
raise ValueError(f"Unsupported type: {data_type}")
@@ -85,9 +102,9 @@ def __process_node(self, argument: torch.fx.Node):
85102
argument.meta, self.tosa_spec
86103
)
87104

88-
# Handle special case of int
89-
if argument.meta.get("tosa_dtype_48bit", False):
90-
output_dtype = ts.DType.INT48
105+
# Handle special case of types not representable in torch (i.e. i48_t)
106+
if special_type := argument.meta.get(TosaSpecialDtype.meta_key(), None):
107+
output_dtype = special_type.get_tosa_dtype()
91108

92109
self.dtype = output_dtype
93110

0 commit comments

Comments (0)