Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion backends/arm/operators/op_avg_pool2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
)
from executorch.backends.arm.tosa import TosaSpecification
from executorch.backends.arm.tosa.mapping import TosaArg
from executorch.backends.arm.tosa.specification import Tosa_1_00


@register_node_visitor
Expand Down Expand Up @@ -115,10 +116,15 @@ def define_node(
) -> None:
validate_num_inputs(self.target, inputs, [3, 4, 5, 6, 7])
validate_same_dtype(self.target, [inputs[0], output], ts)
supported_dtypes = [ts.DType.INT8, ts.DType.FP32]
if isinstance(self.tosa_spec, Tosa_1_00) and self.tosa_spec.support_extension(
"int16"
):
supported_dtypes.append(ts.DType.INT16)
validate_valid_dtype(
self.target,
[inputs[0], output],
[ts.DType.INT8, ts.DType.INT16, ts.DType.FP32],
supported_dtypes,
output.tosa_spec,
)

Expand Down
19 changes: 17 additions & 2 deletions backends/arm/operators/op_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@
)
from executorch.backends.arm.operators.operator_validation_utils import (
validate_num_inputs,
validate_same_dtype,
validate_valid_dtype,
)
from executorch.backends.arm.tosa.mapping import TosaArg
from executorch.backends.arm.tosa.specification import Tosa_1_00
from torch.fx import Node


Expand All @@ -35,9 +38,21 @@ def define_node(
inputs: List[TosaArg],
output: TosaArg,
) -> None:
supported_dtypes = [ts.DType.BOOL, ts.DType.INT8, ts.DType.INT32, ts.DType.FP32]
if isinstance(self.tosa_spec, Tosa_1_00) and self.tosa_spec.support_extension(
"int16"
):
supported_dtypes.append(ts.DType.INT16)
validate_num_inputs(self.target, inputs, [1, 2])
input_tosa_args = [TosaArg(arg, output.tosa_spec) for arg in inputs[0].special]
validate_same_dtype(self.target, [*input_tosa_args, output], ts)
validate_valid_dtype(
self.target,
[*input_tosa_args, output],
supported_dtypes,
output.tosa_spec,
)

tensors = inputs[0].special
dim = 0 if len(inputs) < 2 else inputs[1].number
rank = len(output.shape)
dim = (dim + rank) % rank
Expand All @@ -50,7 +65,7 @@ def define_node(
node,
tosa_graph,
ts.Op.CONCAT,
[tensor.name for tensor in tensors],
[tensor.name for tensor in input_tosa_args],
[output.name],
attr,
)
8 changes: 7 additions & 1 deletion backends/arm/operators/op_clamp.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from executorch.backends.arm.tosa import TosaSpecification

from executorch.backends.arm.tosa.mapping import TosaArg
from executorch.backends.arm.tosa.specification import Tosa_1_00
from torch.fx import Node


Expand Down Expand Up @@ -88,10 +89,15 @@ def define_node(
) -> None:
validate_num_inputs(self.target, inputs, [2, 3])
validate_same_dtype(self.target, [inputs[0], output], ts)
supported_dtypes = [ts.DType.INT8, ts.DType.FP16, ts.DType.FP32]
if isinstance(self.tosa_spec, Tosa_1_00) and self.tosa_spec.support_extension(
"int16"
):
supported_dtypes.append(ts.DType.INT16)
validate_valid_dtype(
self.target,
[inputs[0], output],
[ts.DType.INT8, ts.DType.INT16, ts.DType.FP16, ts.DType.FP32],
supported_dtypes,
output.tosa_spec,
)

Expand Down
2 changes: 1 addition & 1 deletion backends/arm/operators/op_eq.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def define_node(
validate_valid_dtype(
self.target,
inputs,
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32],
[ts.DType.INT32, ts.DType.FP32],
output.tosa_spec,
)
validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec)
Expand Down
2 changes: 1 addition & 1 deletion backends/arm/operators/op_ge.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def define_node(
validate_valid_dtype(
self.target,
inputs,
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32],
[ts.DType.INT32, ts.DType.FP32],
output.tosa_spec,
)
validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec)
Expand Down
2 changes: 1 addition & 1 deletion backends/arm/operators/op_gt.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def define_node(
validate_valid_dtype(
self.target,
inputs,
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32],
[ts.DType.INT32, ts.DType.FP32],
output.tosa_spec,
)
validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec)
Expand Down
17 changes: 14 additions & 3 deletions backends/arm/operators/op_index_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.operators.operator_validation_utils import (
validate_num_inputs,
validate_same_dtype,
validate_valid_dtype,
)
from executorch.backends.arm.tosa.mapping import TosaArg

from executorch.backends.arm.tosa.utils import build_reshape_tosa_1_0
Expand Down Expand Up @@ -45,10 +50,16 @@ def define_node(
inputs: List[TosaArg],
output: TosaArg,
) -> None:
if len(inputs) != 3:
raise ValueError(f"Number of inputs are not 3: {len(inputs)}")
validate_num_inputs(self.target, inputs, 3)
validate_same_dtype(self.target, [inputs[0], output], ts)
validate_valid_dtype(
self.target,
[inputs[0], output],
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32],
output.tosa_spec,
)

weights, index, indices = inputs
weights, _, indices = inputs

if len(weights.shape) == 2:
weights_new_shape = [1, weights.shape[0], weights.shape[1]]
Expand Down
2 changes: 1 addition & 1 deletion backends/arm/operators/op_le.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def define_node(
validate_valid_dtype(
self.target,
inputs,
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32],
[ts.DType.INT32, ts.DType.FP32],
output.tosa_spec,
)
validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec)
Expand Down
2 changes: 1 addition & 1 deletion backends/arm/operators/op_lt.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def define_node(
validate_valid_dtype(
self.target,
inputs,
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32],
[ts.DType.INT32, ts.DType.FP32],
output.tosa_spec,
)
validate_valid_dtype(self.target, output, ts.DType.BOOL, output.tosa_spec)
Expand Down
8 changes: 7 additions & 1 deletion backends/arm/operators/op_max_pool2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
)
from executorch.backends.arm.tosa import TosaSpecification
from executorch.backends.arm.tosa.mapping import TosaArg
from executorch.backends.arm.tosa.specification import Tosa_1_00


@register_node_visitor
Expand All @@ -44,10 +45,15 @@ def define_node(
) -> None:
validate_num_inputs(self.target, inputs, [3, 4, 5, 6])
validate_same_dtype(self.target, [inputs[0], output], ts)
supported_dtypes = [ts.DType.INT8, ts.DType.FP32]
if isinstance(
output.tosa_spec, Tosa_1_00
) and output.tosa_spec.support_extension("int16"):
supported_dtypes.append(ts.DType.INT16)
validate_valid_dtype(
self.target,
[inputs[0], output],
[ts.DType.INT8, ts.DType.INT16, ts.DType.FP32],
supported_dtypes,
output.tosa_spec,
)

Expand Down
2 changes: 1 addition & 1 deletion backends/arm/operators/op_mul.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def define_node(
validate_valid_dtype(
self.target,
[*inputs, output],
[ts.DType.INT32, ts.DType.FP32],
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32],
output.tosa_spec,
)

Expand Down
8 changes: 7 additions & 1 deletion backends/arm/operators/op_permute.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,13 @@ def define_node(
validate_valid_dtype(
self.target,
[inputs[0], output],
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32],
[
ts.DType.BOOL,
ts.DType.INT8,
ts.DType.INT16,
ts.DType.INT32,
ts.DType.FP32,
],
output.tosa_spec,
)

Expand Down
8 changes: 7 additions & 1 deletion backends/arm/operators/op_repeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,13 @@ def define_node(
validate_valid_dtype(
self.target,
[inputs[0], output],
[ts.DType.INT8, ts.DType.INT32, ts.DType.INT16, ts.DType.FP32],
[
ts.DType.BOOL,
ts.DType.INT8,
ts.DType.INT16,
ts.DType.INT32,
ts.DType.FP32,
],
output.tosa_spec,
)

Expand Down
8 changes: 7 additions & 1 deletion backends/arm/operators/op_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,13 @@ def define_node(
validate_valid_dtype(
self.target,
[inputs[0], output],
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.FP32],
[
ts.DType.BOOL,
ts.DType.INT8,
ts.DType.INT16,
ts.DType.INT32,
ts.DType.FP32,
],
output.tosa_spec,
)

Expand Down
7 changes: 7 additions & 0 deletions backends/arm/operators/op_sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from executorch.backends.arm.operators.operator_validation_utils import (
validate_num_inputs,
validate_same_dtype,
validate_valid_dtype,
)
from executorch.backends.arm.tosa import TosaSpecification
from executorch.backends.arm.tosa.mapping import TosaArg
Expand All @@ -39,6 +40,12 @@ def define_node(
) -> None:
validate_num_inputs(self.target, inputs, 3)
validate_same_dtype(self.target, [inputs[0], output], ts)
validate_valid_dtype(
self.target,
[inputs[0], output],
[ts.DType.INT32, ts.DType.FP32],
output.tosa_spec,
)

tensor = inputs[0]
input_shape = list(tensor.shape)
Expand Down
15 changes: 13 additions & 2 deletions backends/arm/operators/op_tosa_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
)
from executorch.backends.arm.tosa import TosaSpecification
from executorch.backends.arm.tosa.mapping import TosaArg
from executorch.backends.arm.tosa.specification import Tosa_1_00


@register_node_visitor
Expand All @@ -51,16 +52,26 @@ def define_node(
"""Define the TOSA ``MATMUL`` operator."""
validate_num_inputs(self.target, inputs, 2)
validate_same_dtype(self.target, [*inputs], ts)
supported_input_dtypes = [ts.DType.INT8, ts.DType.INT32, ts.DType.FP32]
if isinstance(self.tosa_spec, Tosa_1_00) and self.tosa_spec.support_extension(
"int16"
):
supported_input_dtypes.append(ts.DType.INT16)
validate_valid_dtype(
self.target,
[*inputs],
[ts.DType.INT8, ts.DType.INT16, ts.DType.FP32],
supported_input_dtypes,
output.tosa_spec,
)
supported_output_dtypes = [ts.DType.INT32, ts.DType.FP32]
if isinstance(self.tosa_spec, Tosa_1_00) and self.tosa_spec.support_extension(
"int16"
):
supported_output_dtypes.append(ts.DType.INT48)
validate_valid_dtype(
self.target,
[output],
[ts.DType.INT32, ts.DType.INT48, ts.DType.FP32],
supported_output_dtypes,
output.tosa_spec,
)

Expand Down
43 changes: 24 additions & 19 deletions backends/arm/operators/op_tosa_resize.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
validate_valid_dtype,
)
from executorch.backends.arm.tosa.mapping import TosaArg
from executorch.backends.arm.tosa.specification import Tosa_1_00
from executorch.backends.arm.tosa.utils import get_resize_parameters


Expand All @@ -39,33 +40,37 @@ def define_node(
output: TosaArg,
) -> None:
validate_num_inputs(self.target, inputs, [3, 4])
supported_input_dtypes = [ts.DType.INT8, ts.DType.FP32]
if isinstance(self.tosa_spec, Tosa_1_00) and self.tosa_spec.support_extension(
"int16"
):
supported_input_dtypes.append(ts.DType.INT16)
validate_valid_dtype(
self.target,
[inputs[0]],
supported_input_dtypes,
output.tosa_spec,
)
supported_output_dtypes = [ts.DType.FP32]
if node.kwargs.get("resize_mode") == "bilinear":
resize_mode = ts.ResizeMode.BILINEAR
align_corners = bool(node.args[2])
supported_output_dtypes.append(ts.DType.INT32)
if isinstance(
self.tosa_spec, Tosa_1_00
) and self.tosa_spec.support_extension("int16"):
supported_output_dtypes.append(ts.DType.INT48)
else:
resize_mode = ts.ResizeMode.NEAREST
align_corners = False
validate_same_dtype(self.target, [inputs[0], output], ts)

valid_dtypes = []
if self.tosa_spec.support_integer():
valid_dtypes.extend(
[ts.DType.INT8, ts.DType.INT16, ts.DType.INT32, ts.DType.INT48]
)

if self.tosa_spec.support_float():
valid_dtypes.extend(
[
ts.DType.FP16,
ts.DType.FP32,
]
)

supported_output_dtypes.append(ts.DType.INT8)
if isinstance(
self.tosa_spec, Tosa_1_00
) and self.tosa_spec.support_extension("int16"):
supported_output_dtypes.append(ts.DType.INT16)
validate_valid_dtype(
self.target,
[inputs[0], output],
valid_dtypes,
output.tosa_spec,
self.target, [output], supported_output_dtypes, output.tosa_spec
)
# tosa_shape output is NHWC, take HW
input_size_yx = tuple([inputs[0].shape[dim] for dim in inputs[0].dim_order])[
Expand Down
Loading
Loading