@@ -44,16 +44,6 @@ def forward(self, x, y):
4444 return torch .bmm (x , y )
4545
4646
47- class MatMul (torch .nn .Module ):
48- test_data_generators = {
49- "rand_3d" : lambda : (torch .rand (2 , 3 , 5 ), torch .rand (2 , 5 , 2 )),
50- "rand_4d" : lambda : (torch .rand (1 , 2 , 3 , 5 ), torch .rand (1 , 2 , 5 , 2 )),
51- }
52-
53- def forward (self , x , y ):
54- return torch .matmul (x , y )
55-
56-
5747class BMMSingleInput (torch .nn .Module ):
5848 test_data_generators = {
5949 "rand_3d_1" : lambda : (torch .rand (20 , 3 , 3 ),),
@@ -81,26 +71,14 @@ def test_bmm_tosa_MI_single_input(test_data: input_t1):
8171 pipeline .run ()
8272
8373
84- @common .parametrize ("test_data" , MatMul .test_data_generators )
85- def test_mm_tosa_MI (test_data : input_t1 ):
86- pipeline = TosaPipelineMI [input_t1 ](MatMul (), test_data (), aten_op_mm , exir_op_mm )
87- pipeline .run ()
88-
89-
90- @common .parametrize ("test_data" , MatMul .test_data_generators )
91- def test_mm_tosa_BI (test_data : input_t1 ):
92- pipeline = TosaPipelineBI [input_t1 ](MatMul (), test_data (), aten_op_mm , exir_op_mm )
93- pipeline .run ()
94-
95-
96- @pytest .mark .flaky (reruns = 5 ) # TODO: Investigate flakyness (MLETORCH-534)
9774@common .parametrize ("test_data" , BMM .test_data_generators )
9875def test_bmm_tosa_BI (test_data : input_t1 ):
99- pipeline = TosaPipelineBI [input_t1 ](BMM (), test_data (), aten_op_bmm , exir_op_bmm )
76+ pipeline = TosaPipelineBI [input_t1 ](
77+ BMM (), test_data (), aten_op_bmm , exir_op_bmm , qtol = 1
78+ )
10079 pipeline .run ()
10180
10281
103- @pytest .mark .flaky (reruns = 5 ) # TODO: Investigate flakyness (MLETORCH-534)
10482@common .parametrize ("test_data" , BMMSingleInput .test_data_generators )
10583def test_bmm_tosa_BI_single_input (test_data : input_t1 ):
10684 pipeline = TosaPipelineBI [input_t1 ](
0 commit comments