
Commit e866b73

Arm backend: Handle i48 special case for bias tensor
For the case when the activation is 16-bit, the bias in TOSA must be an int48_t tensor. Since that can't be represented using torch.dtypes, the corresponding node.meta is set with the key 'tosa_dtype_48bit' to pass the information through to the creation of the TOSA tensor. Also make sure to distinguish between int32 and int48 tensors in the fuse equal placeholders pass.

Signed-off-by: Per Åstrand <[email protected]>
Change-Id: Iefe64f2b02f388c905c9c818ee7d2a6af40bc9e3
1 parent c9bf166 commit e866b73
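
Background note (not part of the commit): torch has no 48-bit integer dtype, so the flag travels in node.meta, the free-form dict attached to every torch.fx node. Below is a minimal sketch of that pattern using only standard torch.fx APIs; the traced function is made up, and only the key name 'tosa_dtype_48bit' comes from this commit:

import torch
import torch.fx

def model(x):
    return x * 2

gm = torch.fx.symbolic_trace(model)

# Producer side (what add_bias_pass.py does): stash a boolean in node.meta.
for node in gm.graph.nodes:
    if node.op == "placeholder":
        node.meta["tosa_dtype_48bit"] = True

# Consumer side (what process_node.py and mapping.py do): read it back
# with a defaulted get() so untagged nodes are unaffected.
for node in gm.graph.nodes:
    if node.meta.get("tosa_dtype_48bit", False):
        print(f"{node.name}: serialize as TOSA INT48")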

4 files changed: 21 additions & 2 deletions

backends/arm/_passes/add_bias_pass.py

Lines changed: 2 additions & 0 deletions
@@ -55,6 +55,8 @@ def call(self, graph_module):
                 persistent_buffer=True,
                 name=f"{node.name}_bias",
             )
+            if node.args[0].meta["val"].dtype == torch.int16:
+                bias_node.meta["tosa_dtype_48bit"] = True
             node.update_arg(2, bias_node)
 
         if modified:

backends/arm/_passes/fuse_equal_placeholders_pass.py

Lines changed: 5 additions & 0 deletions
@@ -44,9 +44,14 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
                 continue
             # Create a lightweight fingerprint: dtype + shape + SHA1 of raw bytes
             # Ensure tensor is on CPU and contiguous
+
+            # ensure we don't merge any special case int48_t tensors with int32_t tensors
+            # since int48_t tensors needs to be instantiated separately.
+            is_int48 = node.meta.get("tosa_dtype_48bit", False)
             t_cpu = tensor.detach().cpu().contiguous()
             data_bytes = t_cpu.numpy().tobytes()
             key = (
+                is_int48,
                 str(t_cpu.dtype),
                 tuple(t_cpu.shape),
                 hashlib.sha1(data_bytes).hexdigest(),
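
Why the flag must join the fingerprint: an int48-tagged bias and an ordinary int32 tensor can agree on dtype, shape, and raw bytes, so without the extra tuple element the pass would fuse them into one placeholder. A self-contained sketch of the collision (the fingerprint function mirrors the pass; the tensor is invented):

import hashlib
import torch

def fingerprint(t: torch.Tensor, is_int48: bool):
    t_cpu = t.detach().cpu().contiguous()
    return (
        is_int48,  # dropping this element makes the two keys below identical
        str(t_cpu.dtype),
        tuple(t_cpu.shape),
        hashlib.sha1(t_cpu.numpy().tobytes()).hexdigest(),
    )

bias = torch.zeros(8, dtype=torch.int32)
# Same bytes and shape, but only one will become an INT48 TOSA constant:
assert fingerprint(bias, is_int48=True) != fingerprint(bias, is_int48=False)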

backends/arm/process_node.py

Lines changed: 7 additions & 1 deletion
@@ -119,10 +119,16 @@ def process_inputs_to_parameters(
     if tosa_arg.dtype == torch.float32:
         assert tosa_spec.support_float(), f"{tosa_spec} doesn't support float"
 
+    # Handle special case for INT48 tensors
+    if node.meta.get("tosa_dtype_48bit", False):
+        tosa_dtype = ts.DType.INT48
+    else:
+        tosa_dtype = tosa_arg.dtype
+
     parameter_values = np.transpose(parameter_values, tosa_arg.dim_order)
 
     tosa_graph.addConst(
-        parameter_values.shape, tosa_arg.dtype, parameter_values, name=tosa_arg.name
+        parameter_values.shape, tosa_dtype, parameter_values, name=tosa_arg.name
     )
 
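Note that only the dtype declared to the serializer changes: parameter_values still holds the int32 payload coming from the PyTorch graph, which is safe since every int32 value fits in int48. A condensed, illustrative sketch of the selection logic (the DType stand-ins are invented; in the real code they are ts.DType members):

INT32, INT48 = "INT32", "INT48"  # stand-ins for ts.DType.INT32 / ts.DType.INT48

def pick_const_dtype(meta: dict, derived_dtype):
    # Mirrors process_inputs_to_parameters: the meta flag overrides the
    # dtype derived from torch, because torch cannot express int48.
    return INT48 if meta.get("tosa_dtype_48bit", False) else derived_dtype

assert pick_const_dtype({"tosa_dtype_48bit": True}, INT32) == INT48
assert pick_const_dtype({}, INT32) == INT32
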
backends/arm/tosa/mapping.py

Lines changed: 7 additions & 1 deletion
@@ -81,10 +81,16 @@ def extract_tensor_meta(meta, tosa_spec: TosaSpecification):
 class TosaArg:
     def __process_node(self, argument: torch.fx.Node):
         self.name: str = argument.name
-        self.dtype, self.shape, self.dim_order = extract_tensor_meta(
+        output_dtype, self.shape, self.dim_order = extract_tensor_meta(
             argument.meta, self.tosa_spec
         )
 
+        # Handle special case of int
+        if argument.meta.get("tosa_dtype_48bit", False):
+            output_dtype = ts.DType.INT48
+
+        self.dtype = output_dtype
+
     def __process_list(self, argument):
         self.special: list = list(argument)
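
With both sides patched the override stays consistent: process_node.py emits the constant itself as INT48, and TosaArg reports INT48 to any operator visitor that inspects the bias argument, so the serialized TOSA graph sees one coherent type for the tensor.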
