
Commit afd98fe

NXP backend: Add conversion and quantization support for dim_order_ops._clone_dim_order.default (#14535)
### Summary
- Adds support for conversion and quantization of the `dim_order_ops._clone_dim_order.default` operator and fixes problems with some variations of `nn.Dropout`.
- Adds more robust test cases for clone operators.

### Test plan
All changes should be covered by unit tests.

cc @robert-kalmar @JakeStevens @digantdesai
1 parent f32cdc3 · commit afd98fe
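For context, `_clone_dim_order.default` is the dim-order-aware variant of `aten.clone.default` that ExecuTorch emits during edge lowering when dim-order ops are enabled. A minimal sketch of how the op can show up in an edge program; the toy `CloneModel` and the explicit `_skip_dim_order=False` flag are illustrative assumptions, not part of this commit:

```python
import torch
from executorch.exir import EdgeCompileConfig, to_edge


class CloneModel(torch.nn.Module):  # hypothetical toy model
    def forward(self, x):
        return x.clone()


ep = torch.export.export(CloneModel().eval(), (torch.randn(1, 3, 8, 8),))
# With dim-order ops enabled, aten.clone.default is rewritten to
# dim_order_ops._clone_dim_order.default during edge lowering.
edge = to_edge(ep, compile_config=EdgeCompileConfig(_skip_dim_order=False))
print(edge.exported_program().graph)
```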

File tree

- backends/nxp/backend/edge_program_converter.py
- backends/nxp/backend/ir/converter/node_converters/ops_converters/clone_converter.py
- backends/nxp/neutron_partitioner.py
- backends/nxp/tests/executors.py
- backends/nxp/tests/ir/converter/node_converter/test_clone_converter.py

5 files changed: +123 −57 lines changed


backends/nxp/backend/edge_program_converter.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -34,6 +34,7 @@
     exir_ops.edge.aten.avg_pool2d.default: AvgPool2dConverter,  # noqa F405
     exir_ops.edge.aten.cat.default: CatConverter,  # noqa F405
     exir_ops.edge.aten.clone.default: CloneConverter,  # noqa F405
+    exir_ops.edge.dim_order_ops._clone_dim_order.default: CloneConverter,  # noqa F405
     exir_ops.edge.aten.constant_pad_nd.default: ConstantPadNDConverter,  # noqa F405
     exir_ops.edge.aten.convolution.default: ConvolutionConverter,  # noqa F405
     exir_ops.edge.aten.hardtanh.default: HardTanhConverter,  # noqa F405
```

backends/nxp/backend/ir/converter/node_converters/ops_converters/clone_converter.py

Lines changed: 5 additions & 0 deletions
```diff
@@ -20,6 +20,11 @@ def _has_supported_memory_format(node: Node) -> bool:
 
 
 class CloneConverter(NodeConverter):
+    """
+    This converter is responsible for converting both edge operators:
+      - aten.clone.default
+      - dim_order_ops._clone_dim_order.default
+    """
 
     @staticmethod
     def _is_supported_in_IR(
```

backends/nxp/neutron_partitioner.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -201,6 +201,7 @@ def tag_qdq_clusters(self, nodes: list[torch.fx.Node]):
     exir_ops.edge.aten.avg_pool2d.default: AvgPool2dConverter,  # noqa F405
     exir_ops.edge.aten.cat.default: CatConverter,  # noqa F405
     exir_ops.edge.aten.clone.default: CloneConverter,  # noqa F405
+    exir_ops.edge.dim_order_ops._clone_dim_order.default: CloneConverter,  # noqa F405
     exir_ops.edge.aten.constant_pad_nd.default: ConstantPadNDConverter,  # noqa F405
     exir_ops.edge.aten.convolution.default: ConvolutionConverter,  # noqa F405
     exir_ops.edge.aten.hardtanh.default: HardTanhConverter,  # noqa F405
```

backends/nxp/tests/executors.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -368,7 +368,13 @@ def convert_run_compare(
 
 
 def graph_contains_any_of_ops(graph: Graph, ops: list) -> bool:
-    return any(node.target in ops for node in graph.nodes)
+    return graph_contains_any(
+        graph, condition=lambda n: hasattr(n, "target") and n.target in ops
+    )
+
+
+def graph_contains_any(graph: Graph, condition: Callable[[Node], bool]) -> bool:
+    return any(map(condition, graph.nodes))
 
 
 target_support_check_function = Callable[[Node, NeutronTargetSpec], bool]
```
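The refactor keeps `graph_contains_any_of_ops` as a thin wrapper while the new `graph_contains_any` accepts an arbitrary node predicate. A hypothetical usage sketch; the `graph` variable is assumed to be a `torch.fx.Graph` taken from an exported program:

```python
from executorch.exir.dialects._ops import ops as exir_ops

clone_ops = [
    exir_ops.edge.aten.clone.default,
    exir_ops.edge.dim_order_ops._clone_dim_order.default,
]

# Op-list form: any node whose target is one of the listed edge ops.
has_clone = graph_contains_any_of_ops(graph, ops=clone_ops)

# Predicate form: any node satisfying an arbitrary condition.
has_call_function = graph_contains_any(
    graph, condition=lambda n: n.op == "call_function"
)
```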

backends/nxp/tests/ir/converter/node_converter/test_clone_converter.py

Lines changed: 109 additions & 56 deletions
```diff
@@ -4,31 +4,33 @@
 # LICENSE file in the root directory of this source tree.
 
 
+import itertools
+import unittest
+
+import kgb
 import numpy as np
-import pytest
 import torch
 
 from executorch.backends.nxp.backend.edge_program_converter import (
     EdgeProgramToIRConverter,
 )
-from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program
+from executorch.backends.nxp.tests.executorch_pipeline import (
+    to_edge_program,
+    to_quantized_edge_program,
+)
 from executorch.backends.nxp.tests.executors import (
     convert_run_compare,
+    graph_contains_any,
     graph_contains_any_of_ops,
-    ToNCHWPreprocess,
-    ToNHWCPreprocess,
+    ToChannelFirstPreprocess,
+    ToChannelLastPreprocess,
 )
 from executorch.exir.dialects._ops import ops as exir_ops
+from parameterized import parameterized
 from torch import nn
 from torch.export import ExportedProgram
 
 
-@pytest.fixture(autouse=True)
-def reseed_model_per_test_run():
-    torch.manual_seed(23)
-    np.random.seed(23)
-
-
 class SingleConvBlockWithDropout(torch.nn.Module):
     def __init__(
         self, conv_in_channels: int = 3, perform_inplace_dropout: bool = False
@@ -74,57 +76,108 @@ def forward(self, x):
         return self.block(x)
 
 
-@pytest.mark.parametrize("inplace_dropout", [False, True])
-@pytest.mark.parametrize("input_shape", [(1, 3, 128, 128), (1, 3, 256, 256)])
-def test_conv_dropout_quant(mocker, inplace_dropout: bool, input_shape: tuple[int]):
-    model = SingleConvBlockWithDropout(
-        conv_in_channels=input_shape[1], perform_inplace_dropout=inplace_dropout
-    ).eval()
+class TestCloneConverter(unittest.TestCase):
+    __test__ = False  # Prevent interfering with PyTest tests
 
-    converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
+    @classmethod
+    def setUpClass(cls):
+        torch.manual_seed(23)
+        np.random.seed(23)
 
-    quantized_program = to_quantized_edge_program(model, input_shape).exported_program()
+    @staticmethod
+    def _node_is_clone(node) -> bool:
+        clone_ops = [
+            exir_ops.edge.aten.clone.default,
+            exir_ops.edge.dim_order_ops._clone_dim_order.default,
+        ]
 
-    tflite_flatbuffers_model, io_formats = converter_spy.spy_return
-    exported_program: ExportedProgram = converter_spy.call_args.args[1]
-
-    assert not graph_contains_any_of_ops(
-        graph=quantized_program.graph, ops=[exir_ops.edge.aten.clone.default]
-    )
-
-    input_data = (np.random.random(input_shape) * 50).astype(np.int8)
-    convert_run_compare(
-        exported_program,
-        tfl_model=tflite_flatbuffers_model,
-        tflite_input_preprocess=ToNHWCPreprocess(),
-        tflite_output_preprocess=ToNCHWPreprocess(),
-        input_data=input_data,
-        atol=1.0,
-    )
+        def target_can_be_clone(node):
+            if hasattr(node, "op") and node.op == "call_function":
+                return "clone" in node.target.__name__
 
+            return False
 
-@pytest.mark.parametrize("inplace_dropout", [False, True])
-def test_clone_pool_view_copy_quant(
-    mocker, inplace_dropout: bool, input_shape: tuple[int] = (1, 64, 25, 5)
-):
-    model = KWSFinalBlock(input_shape).eval()
+        return node in clone_ops or target_can_be_clone(node)
 
-    converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
-
-    quantized_program = to_quantized_edge_program(model, input_shape).exported_program()
-
-    tflite_flatbuffers_model, io_formats = converter_spy.spy_return
-    exported_program: ExportedProgram = converter_spy.call_args.args[1]
-
-    assert not graph_contains_any_of_ops(
-        graph=quantized_program.graph, ops=[exir_ops.edge.aten.clone.default]
+    @parameterized.expand(
+        list(itertools.product([True, False], [(1, 3, 128, 128), (1, 3, 256, 256)]))
     )
-
-    input_data = (np.random.random(input_shape) * 50).astype(np.int8)
-    convert_run_compare(
-        exported_program,
-        tfl_model=tflite_flatbuffers_model,
-        tflite_input_preprocess=ToNHWCPreprocess(),
-        input_data=input_data,
-        atol=1.0,
+    def test_conv_dropout_quant(self, inplace_dropout: bool, input_shape: tuple[int]):
+        model = SingleConvBlockWithDropout(
+            conv_in_channels=input_shape[1], perform_inplace_dropout=inplace_dropout
+        ).eval()
+
+        with kgb.spy_on(
+            EdgeProgramToIRConverter.convert_program, call_original=True
+        ) as converter_spy:
+            quantized_program = to_quantized_edge_program(
+                model, input_shape
+            ).exported_program()
+
+            tflite_flatbuffers_model, _ = converter_spy.calls[-1].return_value
+            exported_program: ExportedProgram = converter_spy.calls[-1].args[0]
+
+            assert not graph_contains_any(
+                graph=quantized_program.graph,
+                condition=TestCloneConverter._node_is_clone,
+            )
+
+            input_data = (np.random.random(input_shape) * 50).astype(np.int8)
+            convert_run_compare(
+                exported_program,
+                tfl_model=tflite_flatbuffers_model,
+                tflite_input_preprocess=ToChannelLastPreprocess(),
+                tflite_output_preprocess=ToChannelFirstPreprocess(),
+                input_data=input_data,
+                atol=1.0,
+            )
+
+    @parameterized.expand(
+        list(itertools.product([True, False], [(1, 3, 128, 128), (1, 3, 256, 256)]))
     )
+    def test_conv_dropout_no_quant(
+        self, inplace_dropout: bool, input_shape: tuple[int]
+    ):
+        model = SingleConvBlockWithDropout(
+            conv_in_channels=input_shape[1], perform_inplace_dropout=inplace_dropout
+        ).eval()
+
+        edge_program = to_edge_program(model, input_shape).exported_program()
+
+        has_clone = graph_contains_any_of_ops(
+            graph=edge_program.graph,
+            ops=[
+                exir_ops.edge.aten.clone.default,
+                exir_ops.edge.dim_order_ops._clone_dim_order.default,
+            ],
+        )
+
+        # Clone with inplace=True should not produce clone edge op and vice versa
+        assert inplace_dropout ^ has_clone
+
+    def test_clone_pool_view_copy_quant(self, input_shape: tuple[int] = (1, 64, 25, 5)):
+        model = KWSFinalBlock(input_shape).eval()
+
+        with kgb.spy_on(
+            EdgeProgramToIRConverter.convert_program, call_original=True
+        ) as converter_spy:
+            quantized_program = to_quantized_edge_program(
+                model, input_shape
+            ).exported_program()
+
+            tflite_flatbuffers_model, _ = converter_spy.calls[-1].return_value
+            exported_program: ExportedProgram = converter_spy.calls[-1].args[0]
+
+            assert not graph_contains_any(
+                graph=quantized_program.graph,
+                condition=TestCloneConverter._node_is_clone,
+            )
+
+            input_data = (np.random.random(input_shape) * 50).astype(np.int8)
+            convert_run_compare(
+                exported_program,
+                tfl_model=tflite_flatbuffers_model,
+                tflite_input_preprocess=ToChannelLastPreprocess(),
+                input_data=input_data,
+                atol=1.0,
+            )
```
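The tests swap pytest-mock's `mocker.spy` for the `kgb` spy library, which records each call's arguments and return value. A minimal standalone sketch of the pattern, assuming kgb's documented `spy_on` context-manager API; the `add` function is a toy stand-in for `EdgeProgramToIRConverter.convert_program`:

```python
import kgb


def add(a: int, b: int) -> int:
    return a + b


with kgb.spy_on(add, call_original=True) as spy:
    add(1, 2)

# Each recorded call exposes its positional args and return value,
# mirroring converter_spy.calls[-1].args / .return_value above.
assert spy.calls[-1].args == (1, 2)
assert spy.calls[-1].return_value == 3
```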
