1 change: 1 addition & 0 deletions backends/nxp/backend/edge_program_converter.py
@@ -41,6 +41,7 @@
exir_ops.edge.aten.permute_copy.default: PermuteCopyConverter, # noqa F405
exir_ops.edge.aten.relu.default: ReLUConverter, # noqa F405
exir_ops.edge.aten._softmax.default: SoftmaxConverter, # noqa F405
exir_ops.edge.aten.slice_copy.Tensor: SliceTensorConverter, # noqa F405
Collaborator: Please maintain alphabetical order.

exir_ops.edge.aten.sub.Tensor: SubTensorConverter, # noqa F405
exir_ops.edge.aten.tanh.default: TanhConverter, # noqa F405
exir_ops.edge.aten.view_copy.default: ViewCopyConverter, # noqa F405
@@ -53,6 +53,9 @@
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.sigmoid_converter import (
SigmoidConverter,
)
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.slice_tensor_converter import (
SliceTensorConverter,
)
from executorch.backends.nxp.backend.ir.converter.node_converters.ops_converters.softmax_converter import (
SoftmaxConverter,
)
@@ -90,4 +93,5 @@
"HardTanhConverter",
"SigmoidConverter",
"TanhConverter",
"SliceTensorConverter",
]
@@ -0,0 +1,128 @@
# Copyright 2025 NXP
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
from backends.nxp.backend.edge_helper import input_tensor
from backends.nxp.backend.ir.converter.conversion.common import OpsList
Collaborator: Please don't use relative imports.

The tests pass because Python finds the backends directory in our development environment, but it won't work in deployment as there is no backends top-level module in the installed Python package.

So, replace all from backends... with from executorch.backends....
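Concretely (an illustrative before/after following the reviewer's suggestion; not part of the diff), the edge_helper import above would become:

# Before: relative import, breaks once the package is installed
from backends.nxp.backend.edge_helper import input_tensor
# After: absolute import
from executorch.backends.nxp.backend.edge_helper import input_tensor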

from executorch.backends.nxp.backend.ir.converter.conversion import translator
from executorch.backends.nxp.backend.ir.converter.conversion.common import (
node_uses_shape_broadcasting,
)
from executorch.backends.nxp.backend.ir.converter.node_converter import (
CustomDelegationOptions,
NodeConverter,
)
from executorch.backends.nxp.backend.ir.tflite_generator.builtin_options import (
slice_options,
)
from executorch.backends.nxp.backend.neutron_target_spec import NeutronTargetSpec
from torch.fx import Node
from torch.nn import Parameter


class SliceTensorConverter(NodeConverter):
@staticmethod
def _is_supported_on_target(
node: Node,
neutron_target_spec: NeutronTargetSpec,
parameters_mapping: dict[str, Parameter],
custom_delegation_options: CustomDelegationOptions,
) -> bool:
if node_uses_shape_broadcasting(node):
Collaborator: Shape broadcasting is something that allows us to perform element-wise operations on two tensors that don't necessarily have the same shape. For example, if you have two tensors with shapes (2, 3) and (2, 3), you can trivially add (or multiply, subtract, divide, ...) their elements together. E.g.:

[[1, 2, 3],
 [4, 5, 6]]
+
[[0, -1, 1],
 [2, -2, 3]]
=
[[1, 1, 4],
 [6, 3, 9]]

If your tensors have different shapes, for example (2, 3) and (5, 4), the element-wise operations are not defined (how would you add them together?). But for some combinations of shapes, they are defined. Going from the right, the dimensions must either match, or one of them must be 1, and the ranks don't even need to be the same. For example, with shapes (2, 3) and (1):

[[1, 2, 3],
 [4, 5, 6]]
+
[-1]
=
[[0, 1, 2],
 [3, 4, 5]]

or for shapes (2, 3) and (3):

[[1, 2, 3],
 [4, 5, 6]]
+
[-1, 0, 1]
=
[[0, 2, 4],
 [3, 5, 7]]

Other combinations that work are, for example:
(2, 4, 6, 8) and (4, 1, 8)
(42, 1) and (2, 4, 1, 8)
...

The fact that the tensors can have different ranks (numbers of dimensions) causes some issues when we are handling the conversion from ExecuTorch's channels_first format to NeutronIR's channels_last format. These issues may require the insertion of Transpose operators to solve, hence the check on line 33.

But this is only the case for operators that support shape broadcasting (Add, Sub, Mul, Div, Pow, ...). Slice is not one of these element-wise operations, therefore this check does not make sense here and should be removed.

Sorry for the long comment, but I believe it is important to understand these issues going forward.
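A minimal NumPy sketch of the rule above (illustrative only, not part of the PR; shapes taken from the examples in the comment):

import numpy as np

# Dims are compared from the right: they must match or be 1; ranks may differ.
a = np.ones((2, 4, 6, 8))
b = np.ones((4, 1, 8))
print((a + b).shape)  # (2, 4, 6, 8)

try:
    np.ones((2, 3)) + np.ones((5, 4))  # incompatible shapes
except ValueError as err:
    print(err)  # operands could not be broadcast together ...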

# Shape broadcasting may require the addition of `Transpose` ops during conversion.
return False

# Provisional solution - slice conversion works for neutron software 2.2.1+
neutron_flavor = neutron_target_spec.neutron_target.__module__.split(".")[0]
if neutron_flavor != "neutron_converter_SDK_25_12":
return False

input_shape = input_tensor(node, 0).shape
dim = node.args[1]

# The rank of the dimension that we want to slice must be divisible by num_macs
MartinPavella (Collaborator), Nov 20, 2025: What does "The rank of the dimension" mean? Generally, by rank we mean the number of dimensions of a tensor.
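To illustrate the distinction (hypothetical values, not part of the diff): the check below constrains the size of the sliced dimension, not the tensor's rank.

import torch

t = torch.zeros(2, 3, 8)
print(t.dim())     # 3 -- the rank (number of dimensions)
print(t.shape[2])  # 8 -- the size of dimension 2; this is what must be divisible by num_macs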

num_macs = neutron_target_spec.get_num_macs()
return input_shape[dim] % num_macs == 0

@staticmethod
def _is_supported_in_IR(
node: Node,
parameters_mapping: dict[str, Parameter],
custom_delegation_options: CustomDelegationOptions,
) -> bool:
args = node.args
if len(args) != 4:
return False

_, start, end = SliceTensorConverter._get_clipped_slice_args(node)
if start >= end:
return False

return True

def _convert_to_slice(self, t_op, main_input, input_rank, dim, start, end) -> None:
# Prepare the TFLite parameters 'begin' and 'size' tensors
begin = [0] * input_rank # By default, start the slice at 0
size = (
main_input.shape.vector.copy()
) # By default, end the slice at the end of the dimension

size[dim] = max(end - start, 0)
begin[dim] = start

# We can slice only the channels dimension
# So we swap the sliced dimension with the channels dimension
begin[-1], begin[dim] = begin[dim], begin[-1]
size[-1], size[dim] = size[dim], size[-1]

# Create permutation for swapping
perm = list(range(0, input_rank))
perm[dim], perm[-1] = perm[-1], perm[dim]

begin_tensor = self.builder.create_tensor_for_data(
np.asarray(begin, np.int32), "begin"
)
size_tensor = self.builder.create_tensor_for_data(
np.asarray(size, np.int32), "size"
)

t_op.tmp_inputs = [main_input, begin_tensor, size_tensor]
t_op.builtin_options = slice_options.Slice()

ops = OpsList(middle_op=t_op)
# Insert forward and backward transpose
ops.add_pre(self.builder.create_transpose_operator_before(t_op, 0, perm))
ops.add_post(self.builder.create_transpose_operator_after(t_op, 0, perm))

self.builder.append_operators(ops.flatten())

Dim = Start = End = int
Collaborator: Nice 👍🏻


@staticmethod
def _get_clipped_slice_args(node: Node) -> (Dim, Start, End):
Collaborator: This type hint is not valid in Python. Use the following instead:

def _get_clipped_slice_args(node: Node) -> tuple[Dim, Start, End]:

input_shape = input_tensor(node, 0).shape
_, dim, start, end = node.args
sliced_tensor_rank = input_shape[dim]

end = int(np.clip(end, 0, sliced_tensor_rank))
start = int(np.clip(start, 0, sliced_tensor_rank))

return dim, start, end

def convert(self, node: Node):
"""Convert 'slice_tensor' operator to NeutronIR 'Slice'."""
self.assert_convertible(node)
t_op = self._create_tflite_op_with_io_tensors(node)
inputs = t_op.tmp_inputs[0]
rank = inputs.rank

dim, start, end = self._get_clipped_slice_args(node)

if t_op.tmp_inputs[0].tensor_format.is_channels_last():
dim = translator.create_channels_last_to_channels_first_permutation(
t_op.tmp_inputs[0].rank
)[dim]

self._convert_to_slice(t_op, inputs, rank, dim, start, end)
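An illustrative walk-through of what _convert_to_slice builds (hypothetical shapes, not part of the diff), for an input of shape (2, 5, 8) sliced along dim=1 from start=1 to end=4:

# begin = [0, 1, 0], size = [2, 3, 8]                      # slice dim 1, keeping 3 elements
# after swapping dim 1 with the channels (last) dimension: begin = [0, 0, 1], size = [2, 8, 3]
# perm = [0, 2, 1]
# Emitted sequence: Transpose(perm) -> Slice(begin, size) -> Transpose(perm),
# i.e. shapes (2, 5, 8) -> (2, 8, 5) -> (2, 8, 3) -> (2, 3, 8).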
9 changes: 7 additions & 2 deletions backends/nxp/backend/neutron_converter_manager.py
@@ -26,7 +26,10 @@ class NeutronConverterManager:
contains NeutronGraph nodes.
"""

def __init__(self, neutron_converter_flavor: str = "SDK_25_09"):
def __init__(
self,
neutron_converter_flavor: str = "SDK_25_09",
):

neutron_converter_modules = [
module.name
@@ -76,7 +79,9 @@ def convert(self, tflite_model: bytes, target: str) -> bytes:
cctx = self.neutron_converter.CompilationContext()
cctx.targetOpts = self.neutron_converter.getNeutronTarget(target)
cctx.compilationOpts.minNumOpsPerGraph = 1
cctx.compilationOpts.excludeGraphPasses = "MergeTranspose"
cctx.compilationOpts.excludeGraphPasses = (
"HoistSliceAboveTranspose,MergeTranspose"
)

logger = multiprocessing.log_to_stderr()
logger.setLevel(logging.WARNING)
1 change: 1 addition & 0 deletions backends/nxp/neutron_partitioner.py
@@ -210,6 +210,7 @@ def tag_qdq_clusters(self, nodes: list[torch.fx.Node]):
exir_ops.edge.aten.mm.default: MMConverter, # noqa F405
exir_ops.edge.aten.permute_copy.default: PermuteCopyConverter, # noqa F405
exir_ops.edge.aten.relu.default: ReLUConverter, # noqa F405
exir_ops.edge.aten.slice_copy.Tensor: SliceTensorConverter, # noqa F405
Collaborator: Minor: Please keep alphabetical order. (I know it was already broken by Sigmoid, but you can fix both ;) )

exir_ops.edge.aten._softmax.default: SoftmaxConverter, # noqa F405
exir_ops.edge.aten.sigmoid.default: SigmoidConverter, # noqa F405
exir_ops.edge.aten.sub.Tensor: SubTensorConverter, # noqa F405
2 changes: 2 additions & 0 deletions backends/nxp/quantizer/neutron_quantizer.py
@@ -38,6 +38,7 @@
ReshapePattern,
SharedSpecPattern,
SigmoidPattern,
SliceTensorPattern,
SoftMaxPattern,
SubTensorPattern,
TanhInPlacePattern,
@@ -214,6 +215,7 @@ def __init__(self, neutron_target_spec: NeutronTargetSpec):
NeutronAtenQuantizer(ReluInPlacePattern(), static_qconfig),
NeutronAtenQuantizer(ReshapePattern(), static_qconfig),
NeutronAtenQuantizer(SigmoidPattern(), static_qconfig),
NeutronAtenQuantizer(SliceTensorPattern(), static_qconfig),
NeutronAtenQuantizer(SoftMaxPattern(), static_qconfig),
NeutronAtenQuantizer(SubTensorPattern(), static_qconfig),
NeutronAtenQuantizer(TanhPattern(), static_qconfig),
9 changes: 9 additions & 0 deletions backends/nxp/quantizer/patterns.py
@@ -736,6 +736,15 @@ def partition_types(self):
return [torch.ops.aten.view.default]


class SliceTensorPattern(SharedSpecPattern):
"""
Quantizer for Slice operator.
"""

def partition_types(self):
return [torch.ops.aten.slice.Tensor]


class SoftMaxPattern(QuantizationPattern):
"""
Quantizer for Softmax operator.