Skip to content

Commit edb2d3f

Browse files
committed
NXP backend: Add conversion and quantization support for dim_order_ops._clone_dim_order.default
1 parent df5bfd5 commit edb2d3f

File tree

5 files changed

+129
-54
lines changed

5 files changed

+129
-54
lines changed

backends/nxp/backend/edge_program_converter.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
exir_ops.edge.aten.avg_pool2d.default: AvgPool2dConverter, # noqa F405
3434
exir_ops.edge.aten.cat.default: CatConverter, # noqa F405
3535
exir_ops.edge.aten.clone.default: CloneConverter, # noqa F405
36+
exir_ops.edge.dim_order_ops._clone_dim_order.default: CloneConverter, # noqa F405
3637
exir_ops.edge.aten.constant_pad_nd.default: ConstantPadNDConverter, # noqa F405
3738
exir_ops.edge.aten.convolution.default: ConvolutionConverter, # noqa F405
3839
exir_ops.edge.aten.hardtanh.default: HardTanhConverter, # noqa F405

backends/nxp/backend/ir/converter/node_converters/ops_converters/clone_converter.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,11 @@ def _has_supported_memory_format(node: Node) -> bool:
2020

2121

2222
class CloneConverter(NodeConverter):
23+
"""
24+
This converter is responsible for converting both edge operators:
25+
- aten.clone.default
26+
- dim_order_ops._clone_dim_order.default
27+
"""
2328

2429
@staticmethod
2530
def _is_supported_in_IR(

backends/nxp/neutron_partitioner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ def tag_qdq_clusters(self, nodes: List[torch.fx.Node]):
197197
exir_ops.edge.aten.avg_pool2d.default: AvgPool2dConverter, # noqa F405
198198
exir_ops.edge.aten.cat.default: CatConverter, # noqa F405
199199
exir_ops.edge.aten.clone.default: CloneConverter, # noqa F405
200+
exir_ops.edge.dim_order_ops._clone_dim_order.default: CloneConverter, # noqa F405
200201
exir_ops.edge.aten.constant_pad_nd.default: ConstantPadNDConverter, # noqa F405
201202
exir_ops.edge.aten.convolution.default: ConvolutionConverter, # noqa F405
202203
exir_ops.edge.aten.hardtanh.default: HardTanhConverter, # noqa F405

backends/nxp/tests/executors.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,13 @@ def convert_run_compare(
370370

371371

372372
def graph_contains_any_of_ops(graph: Graph, ops: list) -> bool:
373-
return any(node.target in ops for node in graph.nodes)
373+
return graph_contains_any(
374+
graph, condition=lambda n: hasattr(n, "target") and n.target in ops
375+
)
376+
377+
378+
def graph_contains_any(graph: Graph, condition: Callable[[Node], bool]) -> bool:
379+
return any(map(condition, graph.nodes))
374380

375381

376382
target_support_check_function = Callable[[Node, Target], bool]

backends/nxp/tests/ir/converter/node_converter/test_clone_converter.py

Lines changed: 115 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -4,31 +4,33 @@
44
# LICENSE file in the root directory of this source tree.
55

66

7+
import itertools
8+
import unittest
9+
10+
import kgb
711
import numpy as np
8-
import pytest
912
import torch
1013

1114
from executorch.backends.nxp.backend.edge_program_converter import (
1215
EdgeProgramToIRConverter,
1316
)
14-
from executorch.backends.nxp.tests.executorch_pipeline import to_quantized_edge_program
17+
from executorch.backends.nxp.tests.executorch_pipeline import (
18+
to_edge_program,
19+
to_quantized_edge_program,
20+
)
1521
from executorch.backends.nxp.tests.executors import (
1622
convert_run_compare,
23+
graph_contains_any,
1724
graph_contains_any_of_ops,
1825
ToNCHWPreprocess,
1926
ToNHWCPreprocess,
2027
)
2128
from executorch.exir.dialects._ops import ops as exir_ops
29+
from parameterized import parameterized
2230
from torch import nn
2331
from torch.export import ExportedProgram
2432

2533

26-
@pytest.fixture(autouse=True)
27-
def reseed_model_per_test_run():
28-
torch.manual_seed(23)
29-
np.random.seed(23)
30-
31-
3234
class SingleConvBlockWithDropout(torch.nn.Module):
3335
def __init__(
3436
self, conv_in_channels: int = 3, perform_inplace_dropout: bool = False
@@ -74,57 +76,117 @@ def forward(self, x):
7476
return self.block(x)
7577

7678

77-
@pytest.mark.parametrize("inplace_dropout", [False, True])
78-
@pytest.mark.parametrize("input_shape", [(1, 3, 128, 128), (1, 3, 256, 256)])
79-
def test_conv_dropout_quant(mocker, inplace_dropout: bool, input_shape: tuple[int]):
80-
model = SingleConvBlockWithDropout(
81-
conv_in_channels=input_shape[1], perform_inplace_dropout=inplace_dropout
82-
).eval()
79+
class TestCloneConverter(unittest.TestCase):
80+
__test__ = False # Prevent interfering with PyTest tests
8381

84-
converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
82+
@staticmethod
83+
def _node_is_clone(node) -> bool:
84+
clone_ops = [
85+
exir_ops.edge.aten.clone.default,
86+
exir_ops.edge.dim_order_ops._clone_dim_order.default,
87+
]
8588

86-
quantized_program = to_quantized_edge_program(model, input_shape).exported_program()
89+
def target_can_be_clone(node):
90+
if hasattr(node, "op") and node.op == "call_function":
91+
return "clone" in node.target.__name__
8792

88-
tflite_flatbuffers_model, io_formats = converter_spy.spy_return
89-
exported_program: ExportedProgram = converter_spy.call_args.args[1]
93+
return False
9094

91-
assert not graph_contains_any_of_ops(
92-
graph=quantized_program.graph, ops=[exir_ops.edge.aten.clone.default]
93-
)
95+
return node in clone_ops or target_can_be_clone(node)
9496

95-
input_data = (np.random.random(input_shape) * 50).astype(np.int8)
96-
convert_run_compare(
97-
exported_program,
98-
tfl_model=tflite_flatbuffers_model,
99-
tflite_input_preprocess=ToNHWCPreprocess(),
100-
tflite_output_preprocess=ToNCHWPreprocess(),
101-
input_data=input_data,
102-
atol=1.0,
97+
@parameterized.expand(
98+
list(itertools.product([True, False], [(1, 3, 128, 128), (1, 3, 256, 256)]))
10399
)
100+
def test_conv_dropout_quant(self, inplace_dropout: bool, input_shape: tuple[int]):
101+
model = SingleConvBlockWithDropout(
102+
conv_in_channels=input_shape[1], perform_inplace_dropout=inplace_dropout
103+
).eval()
104+
105+
with kgb.spy_on(
106+
EdgeProgramToIRConverter.convert_program, call_original=True
107+
) as converter_spy:
108+
quantized_program = to_quantized_edge_program(
109+
model, input_shape
110+
).exported_program()
111+
112+
tflite_flatbuffers_model, _ = converter_spy.calls[-1].return_value
113+
exported_program: ExportedProgram = converter_spy.calls[-1].args[0]
114+
115+
assert not graph_contains_any(
116+
graph=quantized_program.graph,
117+
condition=TestCloneConverter._node_is_clone,
118+
)
119+
120+
input_data = (np.random.random(input_shape) * 50).astype(np.int8)
121+
convert_run_compare(
122+
exported_program,
123+
tfl_model=tflite_flatbuffers_model,
124+
tflite_input_preprocess=ToNHWCPreprocess(),
125+
tflite_output_preprocess=ToNCHWPreprocess(),
126+
input_data=input_data,
127+
atol=1.0,
128+
)
129+
130+
@parameterized.expand(
131+
list(itertools.product([True, False], [(1, 3, 128, 128), (1, 3, 256, 256)]))
132+
)
133+
def test_conv_dropout_no_quant(
134+
self, inplace_dropout: bool, input_shape: tuple[int]
135+
):
136+
model = SingleConvBlockWithDropout(
137+
conv_in_channels=input_shape[1], perform_inplace_dropout=inplace_dropout
138+
).eval()
139+
140+
edge_program = to_edge_program(model, input_shape).exported_program()
141+
142+
has_clone = graph_contains_any_of_ops(
143+
graph=edge_program.graph,
144+
ops=[
145+
exir_ops.edge.aten.clone.default,
146+
exir_ops.edge.dim_order_ops._clone_dim_order.default,
147+
],
148+
)
104149

150+
# Clone with inplace=True should not produce clone edge op and vice versa
151+
assert inplace_dropout ^ has_clone
105152

106-
@pytest.mark.parametrize("inplace_dropout", [False, True])
107-
def test_clone_pool_view_copy_quant(
108-
mocker, inplace_dropout: bool, input_shape: tuple[int] = (1, 64, 25, 5)
109-
):
110-
model = KWSFinalBlock(input_shape).eval()
111-
112-
converter_spy = mocker.spy(EdgeProgramToIRConverter, "convert_program")
113-
114-
quantized_program = to_quantized_edge_program(model, input_shape).exported_program()
115-
116-
tflite_flatbuffers_model, io_formats = converter_spy.spy_return
117-
exported_program: ExportedProgram = converter_spy.call_args.args[1]
118-
119-
assert not graph_contains_any_of_ops(
120-
graph=quantized_program.graph, ops=[exir_ops.edge.aten.clone.default]
121-
)
153+
input_data = np.random.random(input_shape).astype(np.float32)
154+
convert_run_compare(
155+
edge_program,
156+
input_data,
157+
tflite_input_preprocess=ToNHWCPreprocess(),
158+
tflite_output_preprocess=ToNCHWPreprocess(),
159+
atol=1.0e-7,
160+
)
122161

123-
input_data = (np.random.random(input_shape) * 50).astype(np.int8)
124-
convert_run_compare(
125-
exported_program,
126-
tfl_model=tflite_flatbuffers_model,
127-
tflite_input_preprocess=ToNHWCPreprocess(),
128-
input_data=input_data,
129-
atol=1.0,
130-
)
162+
def test_clone_pool_view_copy_quant(self, input_shape: tuple[int] = (1, 64, 25, 5)):
163+
model = KWSFinalBlock(input_shape).eval()
164+
165+
with kgb.spy_on(
166+
EdgeProgramToIRConverter.convert_program, call_original=True
167+
) as converter_spy:
168+
quantized_program = to_quantized_edge_program(
169+
model, input_shape
170+
).exported_program()
171+
172+
tflite_flatbuffers_model, _ = converter_spy.calls[-1].return_value
173+
exported_program: ExportedProgram = converter_spy.calls[-1].args[0]
174+
175+
assert not graph_contains_any(
176+
graph=quantized_program.graph,
177+
condition=TestCloneConverter._node_is_clone,
178+
)
179+
180+
input_data = (np.random.random(input_shape) * 50).astype(np.int8)
181+
convert_run_compare(
182+
exported_program,
183+
tfl_model=tflite_flatbuffers_model,
184+
tflite_input_preprocess=ToNHWCPreprocess(),
185+
input_data=input_data,
186+
atol=1.0,
187+
)
188+
189+
@classmethod
190+
def setUpClass(cls):
191+
torch.manual_seed(23)
192+
np.random.seed(23)

0 commit comments

Comments
 (0)