Skip to content

Commit 8b4c381

Browse files
pytorchbot and cccclai authored
support argmax/argmin without dim kwargs and fix adaptive_max_pool3d (#14868)
Summary: As the title says: in PyTorch, when `dim` is not set, argmax flattens the input and computes the argmax along dim=0. Add a pass that reshapes the input when `dim` is not set, and consolidate the test cases. Edit: 1. Apply the same rewrite to argmin. 2. Add `exir_ops.edge.aten.adaptive_max_pool3d.default` to the to-be-implemented op list to get past the error. Differential Revision: D83606497 Co-authored-by: cccclai <[email protected]>
1 parent 3bb1399 commit 8b4c381

File tree

8 files changed

+255
-21
lines changed

8 files changed

+255
-21
lines changed

backends/qualcomm/_passes/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from .i64_to_i32 import I64toI32
3232
from .insert_io_qdq import InsertIOQDQ
3333
from .insert_requantize import InsertRequantize
34+
from .insert_reshape_for_reduce_ops import InsertReshapeForReduceOps
3435
from .layout_transform import LayoutTransform
3536
from .lift_constant_scalar_operands import LiftConstantScalarOperands
3637
from .recompose_pixel_unshuffle import RecomposePixelUnshuffle
@@ -43,7 +44,6 @@
4344
from .seq_mse import SeqMSE
4445
from .tag_quant_io import TagQuantIO
4546

46-
4747
__all__ = [
4848
AnnotateAdaptiveAvgPool1D,
4949
AnnotateQuantAttrs,
@@ -71,6 +71,7 @@
7171
FuseConsecutiveTranspose,
7272
I64toI32,
7373
InsertIOQDQ,
74+
InsertReshapeForReduceOps,
7475
InsertRequantize,
7576
LayoutTransform,
7677
LiftConstantScalarOperands,
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch
8+
from executorch.exir.pass_base import ExportPass, PassResult
9+
from executorch.exir.passes import dead_code_elimination_pass
10+
11+
12+
class InsertReshapeForReduceOps(ExportPass):
    """
    Rewrite `aten.argmax.default` / `aten.argmin.default` with `dim=None` into
    a reshape-to-1D followed by the same op with `dim=0`.

    PyTorch semantics:
        torch.argmax(x, dim=None) -> flatten(x) then argmax along axis=0
        (torch.argmin behaves the same way)

    QNN requires an explicit axis, so we insert the reshape.
    """

    def __init__(self):
        super().__init__()
        # Despite the name, this is a set of the op targets this pass rewrites.
        self.op_map = {torch.ops.aten.argmax.default, torch.ops.aten.argmin.default}

    @staticmethod
    def _get_dim(node):
        """Return the `dim` argument of `node` (positional or keyword), or None if unset."""
        if len(node.args) > 1:
            return node.args[1]
        return node.kwargs.get("dim")

    def call(self, graph_module: torch.fx.GraphModule):
        """
        Insert a flattening reshape in front of every argmax/argmin node whose
        `dim` is unset and point the node at `dim=0`.

        Args:
            graph_module: The graph module to rewrite in place.

        Returns:
            PassResult wrapping `graph_module`, flagged as modified only when
            at least one node was rewritten.
        """
        graph = graph_module.graph
        modified = False

        for n in graph.nodes:
            if n.target not in self.op_map:
                continue
            if self._get_dim(n) is not None:
                # An explicit axis is already present; nothing to rewrite.
                continue

            inp = n.args[0]

            # Insert reshape before the reduce op
            with graph.inserting_before(n):
                reshape_node = graph.create_node(
                    "call_function",
                    torch.ops.aten.reshape.default,
                    (inp, [-1]),
                    {},
                )
                # Start from a shallow copy of the input's metadata, then fix
                # up the recorded value so it reflects the flattened shape.
                reshape_node.meta = dict(inp.meta)
                if "val" in inp.meta:
                    reshape_node.meta["val"] = inp.meta["val"].reshape(-1)

            # Rewrite the reduce op: take reshape_node as input, set dim=0.
            # NOTE(review): if keepdim=True is combined with dim=None, the
            # rewritten op produces shape [1]; confirm whether eager PyTorch
            # returns an all-ones shape of the input's rank in that case --
            # this pass does not restore it.
            n.args = (reshape_node, 0, *n.args[2:])
            if "dim" in n.kwargs:
                # `dim` is now positional; keeping the keyword as well would
                # pass the argument twice.
                n.kwargs = {k: v for k, v in n.kwargs.items() if k != "dim"}

            modified = True

        if modified:
            graph_module.recompile()
            dead_code_elimination_pass(graph_module)

        return PassResult(graph_module, modified)

backends/qualcomm/_passes/qnn_pass_manager.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
I64toI32,
3737
InsertIOQDQ,
3838
InsertRequantize,
39+
InsertReshapeForReduceOps,
3940
LayoutTransform,
4041
LiftConstantScalarOperands,
4142
RecomposePixelUnshuffle,
@@ -205,6 +206,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
205206
self.add_pass(DecomposeLinalgVectorNorm(quantization_capture=True))
206207
self.add_pass(ReplaceInfValues())
207208
self.add_pass(LiftConstantScalarOperands())
209+
self.add_pass(InsertReshapeForReduceOps())
208210
return self._transform(graph_module)
209211

210212
def transform_for_export_pipeline(
@@ -224,6 +226,7 @@ def transform_for_export_pipeline(
224226
self.add_pass(ConvertLinearToConv2d(exported_program))
225227
self.add_pass(ConvertSquareToPow())
226228
self.add_pass(LiftConstantScalarOperands())
229+
self.add_pass(InsertReshapeForReduceOps())
227230
self._transform(exported_program.graph_module)
228231
ep = lift_constant_tensor_pass(exported_program)
229232
return ep

backends/qualcomm/partition/common_defs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
to_be_implemented_operator = [
1818
exir_ops.edge.aten._adaptive_avg_pool3d.default,
1919
exir_ops.edge.aten.adaptive_max_pool2d.default,
20+
exir_ops.edge.aten.adaptive_max_pool3d.default,
2021
exir_ops.edge.aten.avg_pool3d.default,
2122
exir_ops.edge.aten.div.Tensor_mode,
2223
exir_ops.edge.aten.log10.default,

backends/qualcomm/tests/TARGETS

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,17 @@ runtime.python_library(
4747
":test_qnn_delegate"
4848
]
4949
)
50+
51+
runtime.python_test(
52+
name = "test_passes",
53+
srcs = [
54+
"test_passes.py",
55+
],
56+
deps = [
57+
"fbsource//third-party/pypi/expecttest:expecttest", # @manual
58+
"//caffe2:torch",
59+
"//executorch/exir:lib",
60+
"//executorch/backends/qualcomm/_passes:passes",
61+
"//executorch/backends/qualcomm/builders:builders",
62+
],
63+
)

backends/qualcomm/tests/models.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -171,21 +171,23 @@ def forward(self, y):
171171

172172

173173
class Argmax(torch.nn.Module):
174-
def __init__(self):
174+
def __init__(self, dim: Optional[int] = None, keepdim: bool = False):
175175
super().__init__()
176+
self.dim = dim
177+
self.keepdim = keepdim
176178

177179
def forward(self, x):
178-
x = torch.argmax(x, dim=0, keepdim=True)
179-
return x
180+
return torch.argmax(x, dim=self.dim, keepdim=self.keepdim)
180181

181182

182183
class Argmin(torch.nn.Module):
183-
def __init__(self):
184+
def __init__(self, dim: Optional[int] = None, keepdim: bool = False):
184185
super().__init__()
186+
self.dim = dim
187+
self.keepdim = keepdim
185188

186189
def forward(self, x):
187-
x = torch.argmin(x, dim=0, keepdim=True)
188-
return x
190+
return torch.argmin(x, dim=self.dim, keepdim=self.keepdim)
189191

190192

191193
class ArgminViewSqueezeConv2D(torch.nn.Module):
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import unittest
2+
3+
import torch
4+
from executorch.backends.qualcomm._passes import InsertReshapeForReduceOps
5+
6+
7+
class TestPasses(unittest.TestCase):
    """Unit tests for graph passes in the Qualcomm backend."""

    @staticmethod
    def _nodes_with_target(graph, target):
        """Return every node in `graph` whose target is `target`."""
        return [n for n in graph.nodes if n.target == target]

    def test_insert_reshape_for_argmax(self):
        """`argmax(x, dim=None)` is rewritten to `argmax(reshape(x, [-1]), 0)`."""

        class ArgmaxModule(torch.nn.Module):
            def forward(self, x):
                return torch.argmax(x, dim=None)

        mod = ArgmaxModule()

        x = torch.tensor([[1.0, 5.0], [3.0, 2.0]])
        ep = torch.export.export(mod, (x,))
        # Run original module for reference
        ref = mod(x)

        reshape_op = torch.ops.aten.reshape.default
        argmax_op = torch.ops.aten.argmax.default

        # Before the pass: a lone argmax and no reshape.
        self.assertEqual(
            len(self._nodes_with_target(ep.graph, reshape_op)),
            0,
            "Reshape node should not exist before the pass",
        )
        self.assertEqual(
            len(self._nodes_with_target(ep.graph, argmax_op)),
            1,
            "Argmax node missing",
        )

        InsertReshapeForReduceOps()(ep.graph_module)

        # Check graph structure: argmax should take a reshape as input
        reshape_nodes = self._nodes_with_target(ep.graph, reshape_op)
        argmax_nodes = self._nodes_with_target(ep.graph, argmax_op)
        self.assertEqual(len(reshape_nodes), 1, "Reshape node should be inserted")
        self.assertEqual(len(argmax_nodes), 1, "Argmax node missing")

        argmax_node = argmax_nodes[0]
        self.assertIs(
            argmax_node.args[0], reshape_nodes[0], "Argmax input is not the reshape"
        )
        self.assertEqual(argmax_node.args[1], 0, "Argmax dim not set to 0")

        # Execute new graph and compare with reference
        out = ep.graph_module(x)
        self.assertTrue(
            torch.equal(*out, ref), f"Output mismatch: got {out}, expected {ref}"
        )
51+
52+
53+
if __name__ == "__main__":
54+
unittest.main()

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 114 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -173,14 +173,64 @@ def test_qnn_backend_arange(self):
173173
self.lower_module_and_test_output(module, sample_input)
174174

175175
def test_qnn_backend_argmax(self):
176-
module = Argmax() # noqa: F405
177-
sample_input = (torch.randn(16, 3, 4, 4),)
178-
self.lower_module_and_test_output(module, sample_input)
176+
test_cases = [
177+
{
178+
QCOM_MODULE: Argmax(), # noqa: F405
179+
QCOM_SAMPLE_INPUTS: (torch.randn(16, 3, 4, 4),),
180+
},
181+
{
182+
QCOM_MODULE: Argmax(dim=0, keepdim=True), # noqa: F405
183+
QCOM_SAMPLE_INPUTS: (torch.randn(16, 3, 4, 4),),
184+
},
185+
{
186+
QCOM_MODULE: Argmax(dim=1, keepdim=False), # noqa: F405
187+
QCOM_SAMPLE_INPUTS: (torch.randn(8, 5),),
188+
},
189+
{
190+
QCOM_MODULE: Argmax(dim=None, keepdim=False), # noqa: F405
191+
QCOM_SAMPLE_INPUTS: (torch.tensor([5.0]),),
192+
},
193+
{
194+
QCOM_MODULE: Argmax(dim=2, keepdim=True), # noqa: F405
195+
QCOM_SAMPLE_INPUTS: (torch.randn(2, 3, 4),),
196+
},
197+
]
198+
199+
for i, case in enumerate(test_cases):
200+
with self.subTest(i=i):
201+
self.lower_module_and_test_output(
202+
case[QCOM_MODULE], case[QCOM_SAMPLE_INPUTS]
203+
)
179204

180205
def test_qnn_backend_argmin(self):
181-
module = Argmin() # noqa: F405
182-
sample_input = (torch.rand(3, 4),)
183-
self.lower_module_and_test_output(module, sample_input)
206+
test_cases = [
207+
{
208+
QCOM_MODULE: Argmin(), # noqa: F405
209+
QCOM_SAMPLE_INPUTS: (torch.randn(16, 3, 4, 4),),
210+
},
211+
{
212+
QCOM_MODULE: Argmin(dim=0, keepdim=True), # noqa: F405
213+
QCOM_SAMPLE_INPUTS: (torch.randn(16, 3, 4, 4),),
214+
},
215+
{
216+
QCOM_MODULE: Argmin(dim=1, keepdim=False), # noqa: F405
217+
QCOM_SAMPLE_INPUTS: (torch.randn(8, 5),),
218+
},
219+
{
220+
QCOM_MODULE: Argmin(dim=None, keepdim=False), # noqa: F405
221+
QCOM_SAMPLE_INPUTS: (torch.tensor([5.0]),),
222+
},
223+
{
224+
QCOM_MODULE: Argmin(dim=2, keepdim=True), # noqa: F405
225+
QCOM_SAMPLE_INPUTS: (torch.randn(2, 3, 4),),
226+
},
227+
]
228+
229+
for i, case in enumerate(test_cases):
230+
with self.subTest(i=i):
231+
self.lower_module_and_test_output(
232+
case[QCOM_MODULE], case[QCOM_SAMPLE_INPUTS]
233+
)
184234

185235
@unittest.expectedFailure
186236
def test_qnn_backend_asin(self):
@@ -1740,16 +1790,66 @@ def test_qnn_backend_arange(self):
17401790
self.lower_module_and_test_output(module, sample_input)
17411791

17421792
def test_qnn_backend_argmax(self):
1743-
module = Argmax() # noqa: F405
1744-
sample_input = (torch.randn(16, 3, 4, 4),)
1745-
module = self.get_qdq_module(module, sample_input)
1746-
self.lower_module_and_test_output(module, sample_input)
1793+
test_cases = [
1794+
{
1795+
QCOM_MODULE: Argmax(), # noqa: F405
1796+
QCOM_SAMPLE_INPUTS: (torch.randn(16, 3, 4, 4),),
1797+
},
1798+
{
1799+
QCOM_MODULE: Argmax(dim=0, keepdim=True), # noqa: F405
1800+
QCOM_SAMPLE_INPUTS: (torch.randn(16, 3, 4, 4),),
1801+
},
1802+
{
1803+
QCOM_MODULE: Argmax(dim=1, keepdim=False), # noqa: F405
1804+
QCOM_SAMPLE_INPUTS: (torch.randn(8, 5),),
1805+
},
1806+
{
1807+
QCOM_MODULE: Argmax(dim=None, keepdim=False), # noqa: F405
1808+
QCOM_SAMPLE_INPUTS: (torch.tensor([5.0]),),
1809+
},
1810+
{
1811+
QCOM_MODULE: Argmax(dim=2, keepdim=True), # noqa: F405
1812+
QCOM_SAMPLE_INPUTS: (torch.randn(2, 3, 4),),
1813+
},
1814+
]
1815+
1816+
for i, case in enumerate(test_cases):
1817+
with self.subTest(i=i):
1818+
module = self.get_qdq_module(
1819+
case[QCOM_MODULE], case[QCOM_SAMPLE_INPUTS]
1820+
)
1821+
self.lower_module_and_test_output(module, case[QCOM_SAMPLE_INPUTS])
17471822

17481823
def test_qnn_backend_argmin(self):
1749-
module = Argmin() # noqa: F405
1750-
sample_input = (torch.randn(16, 3, 4, 4),)
1751-
module = self.get_qdq_module(module, sample_input)
1752-
self.lower_module_and_test_output(module, sample_input)
1824+
test_cases = [
1825+
{
1826+
QCOM_MODULE: Argmin(), # noqa: F405
1827+
QCOM_SAMPLE_INPUTS: (torch.randn(16, 3, 4, 4),),
1828+
},
1829+
{
1830+
QCOM_MODULE: Argmin(dim=0, keepdim=True), # noqa: F405
1831+
QCOM_SAMPLE_INPUTS: (torch.randn(16, 3, 4, 4),),
1832+
},
1833+
{
1834+
QCOM_MODULE: Argmin(dim=1, keepdim=False), # noqa: F405
1835+
QCOM_SAMPLE_INPUTS: (torch.randn(8, 5),),
1836+
},
1837+
{
1838+
QCOM_MODULE: Argmin(dim=None, keepdim=False), # noqa: F405
1839+
QCOM_SAMPLE_INPUTS: (torch.tensor([5.0]),),
1840+
},
1841+
{
1842+
QCOM_MODULE: Argmin(dim=2, keepdim=True), # noqa: F405
1843+
QCOM_SAMPLE_INPUTS: (torch.randn(2, 3, 4),),
1844+
},
1845+
]
1846+
1847+
for i, case in enumerate(test_cases):
1848+
with self.subTest(i=i):
1849+
module = self.get_qdq_module(
1850+
case[QCOM_MODULE], case[QCOM_SAMPLE_INPUTS]
1851+
)
1852+
self.lower_module_and_test_output(module, case[QCOM_SAMPLE_INPUTS])
17531853

17541854
def test_qnn_backend_asin(self):
17551855
module = Asin() # noqa: F405

0 commit comments

Comments
 (0)