@@ -173,14 +173,64 @@ def test_qnn_backend_arange(self):
173173 self .lower_module_and_test_output (module , sample_input )
174174
175175 def test_qnn_backend_argmax (self ):
176- module = Argmax () # noqa: F405
177- sample_input = (torch .randn (16 , 3 , 4 , 4 ),)
178- self .lower_module_and_test_output (module , sample_input )
176+ test_cases = [
177+ {
178+ QCOM_MODULE : Argmax (), # noqa: F405
179+ QCOM_SAMPLE_INPUTS : (torch .randn (16 , 3 , 4 , 4 ),),
180+ },
181+ {
182+ QCOM_MODULE : Argmax (dim = 0 , keepdim = True ), # noqa: F405
183+ QCOM_SAMPLE_INPUTS : (torch .randn (16 , 3 , 4 , 4 ),),
184+ },
185+ {
186+ QCOM_MODULE : Argmax (dim = 1 , keepdim = False ), # noqa: F405
187+ QCOM_SAMPLE_INPUTS : (torch .randn (8 , 5 ),),
188+ },
189+ {
190+ QCOM_MODULE : Argmax (dim = None , keepdim = False ), # noqa: F405
191+ QCOM_SAMPLE_INPUTS : (torch .tensor ([5.0 ]),),
192+ },
193+ {
194+ QCOM_MODULE : Argmax (dim = 2 , keepdim = True ), # noqa: F405
195+ QCOM_SAMPLE_INPUTS : (torch .randn (2 , 3 , 4 ),),
196+ },
197+ ]
198+
199+ for i , case in enumerate (test_cases ):
200+ with self .subTest (i = i ):
201+ self .lower_module_and_test_output (
202+ case [QCOM_MODULE ], case [QCOM_SAMPLE_INPUTS ]
203+ )
179204
180205 def test_qnn_backend_argmin (self ):
181- module = Argmin () # noqa: F405
182- sample_input = (torch .rand (3 , 4 ),)
183- self .lower_module_and_test_output (module , sample_input )
206+ test_cases = [
207+ {
208+ QCOM_MODULE : Argmin (), # noqa: F405
209+ QCOM_SAMPLE_INPUTS : (torch .randn (16 , 3 , 4 , 4 ),),
210+ },
211+ {
212+ QCOM_MODULE : Argmin (dim = 0 , keepdim = True ), # noqa: F405
213+ QCOM_SAMPLE_INPUTS : (torch .randn (16 , 3 , 4 , 4 ),),
214+ },
215+ {
216+ QCOM_MODULE : Argmin (dim = 1 , keepdim = False ), # noqa: F405
217+ QCOM_SAMPLE_INPUTS : (torch .randn (8 , 5 ),),
218+ },
219+ {
220+ QCOM_MODULE : Argmin (dim = None , keepdim = False ), # noqa: F405
221+ QCOM_SAMPLE_INPUTS : (torch .tensor ([5.0 ]),),
222+ },
223+ {
224+ QCOM_MODULE : Argmin (dim = 2 , keepdim = True ), # noqa: F405
225+ QCOM_SAMPLE_INPUTS : (torch .randn (2 , 3 , 4 ),),
226+ },
227+ ]
228+
229+ for i , case in enumerate (test_cases ):
230+ with self .subTest (i = i ):
231+ self .lower_module_and_test_output (
232+ case [QCOM_MODULE ], case [QCOM_SAMPLE_INPUTS ]
233+ )
184234
185235 @unittest .expectedFailure
186236 def test_qnn_backend_asin (self ):
@@ -1740,16 +1790,66 @@ def test_qnn_backend_arange(self):
17401790 self .lower_module_and_test_output (module , sample_input )
17411791
17421792 def test_qnn_backend_argmax (self ):
1743- module = Argmax () # noqa: F405
1744- sample_input = (torch .randn (16 , 3 , 4 , 4 ),)
1745- module = self .get_qdq_module (module , sample_input )
1746- self .lower_module_and_test_output (module , sample_input )
1793+ test_cases = [
1794+ {
1795+ QCOM_MODULE : Argmax (), # noqa: F405
1796+ QCOM_SAMPLE_INPUTS : (torch .randn (16 , 3 , 4 , 4 ),),
1797+ },
1798+ {
1799+ QCOM_MODULE : Argmax (dim = 0 , keepdim = True ), # noqa: F405
1800+ QCOM_SAMPLE_INPUTS : (torch .randn (16 , 3 , 4 , 4 ),),
1801+ },
1802+ {
1803+ QCOM_MODULE : Argmax (dim = 1 , keepdim = False ), # noqa: F405
1804+ QCOM_SAMPLE_INPUTS : (torch .randn (8 , 5 ),),
1805+ },
1806+ {
1807+ QCOM_MODULE : Argmax (dim = None , keepdim = False ), # noqa: F405
1808+ QCOM_SAMPLE_INPUTS : (torch .tensor ([5.0 ]),),
1809+ },
1810+ {
1811+ QCOM_MODULE : Argmax (dim = 2 , keepdim = True ), # noqa: F405
1812+ QCOM_SAMPLE_INPUTS : (torch .randn (2 , 3 , 4 ),),
1813+ },
1814+ ]
1815+
1816+ for i , case in enumerate (test_cases ):
1817+ with self .subTest (i = i ):
1818+ module = self .get_qdq_module (
1819+ case [QCOM_MODULE ], case [QCOM_SAMPLE_INPUTS ]
1820+ )
1821+ self .lower_module_and_test_output (module , case [QCOM_SAMPLE_INPUTS ])
17471822
17481823 def test_qnn_backend_argmin (self ):
1749- module = Argmin () # noqa: F405
1750- sample_input = (torch .randn (16 , 3 , 4 , 4 ),)
1751- module = self .get_qdq_module (module , sample_input )
1752- self .lower_module_and_test_output (module , sample_input )
1824+ test_cases = [
1825+ {
1826+ QCOM_MODULE : Argmin (), # noqa: F405
1827+ QCOM_SAMPLE_INPUTS : (torch .randn (16 , 3 , 4 , 4 ),),
1828+ },
1829+ {
1830+ QCOM_MODULE : Argmin (dim = 0 , keepdim = True ), # noqa: F405
1831+ QCOM_SAMPLE_INPUTS : (torch .randn (16 , 3 , 4 , 4 ),),
1832+ },
1833+ {
1834+ QCOM_MODULE : Argmin (dim = 1 , keepdim = False ), # noqa: F405
1835+ QCOM_SAMPLE_INPUTS : (torch .randn (8 , 5 ),),
1836+ },
1837+ {
1838+ QCOM_MODULE : Argmin (dim = None , keepdim = False ), # noqa: F405
1839+ QCOM_SAMPLE_INPUTS : (torch .tensor ([5.0 ]),),
1840+ },
1841+ {
1842+ QCOM_MODULE : Argmin (dim = 2 , keepdim = True ), # noqa: F405
1843+ QCOM_SAMPLE_INPUTS : (torch .randn (2 , 3 , 4 ),),
1844+ },
1845+ ]
1846+
1847+ for i , case in enumerate (test_cases ):
1848+ with self .subTest (i = i ):
1849+ module = self .get_qdq_module (
1850+ case [QCOM_MODULE ], case [QCOM_SAMPLE_INPUTS ]
1851+ )
1852+ self .lower_module_and_test_output (module , case [QCOM_SAMPLE_INPUTS ])
17531853
17541854 def test_qnn_backend_asin (self ):
17551855 module = Asin () # noqa: F405
0 commit comments