@@ -1040,6 +1040,116 @@ def test_quantized_conv_per_tensor(
10401040                f"Output values don't match expected. Got { output }  , expected { expected_output }  " ,
10411041            )
10421042
1043+     @expand ( 
1044+         [ 
1045+             # Test case 1: Basic 1D convolution with int8 weights  
1046+             ( 
1047+                 "basic_int8_weights" , 
1048+                 torch .tensor ( 
1049+                     [[[1.0 , 2.0 , 3.0 , 4.0 , 5.0 ]]], dtype = torch .float32  
1050+                 ),  # src: 1x1x5  
1051+                 torch .tensor ([[[1 , - 1 , 2 ]]], dtype = torch .int8 ),  # weight: 1x1x3  
1052+                 0.1 ,  # w_scale  
1053+                 torch .tensor ([1 ], dtype = torch .int8 ),  # bias: 1  
1054+                 0.2 ,  # b_scale  
1055+                 torch .tensor ( 
1056+                     [[[0.7 , 0.9 , 1.1 ]]], dtype = torch .float32  
1057+                 ),  # expected: conv1d result  
1058+             ), 
1059+             # Test case 2: Multiple input channels  
1060+             ( 
1061+                 "multi_input_channels" , 
1062+                 torch .tensor ( 
1063+                     [[[1.0 , 2.0 , 3.0 ], [4.0 , 5.0 , 6.0 ]]], dtype = torch .float32  
1064+                 ),  # src: 1x2x3  
1065+                 torch .tensor ([[[2 , 1 ], [1 , 2 ]]], dtype = torch .int8 ),  # weight: 1x2x2  
1066+                 0.5 ,  # w_scale  
1067+                 torch .tensor ([1 ], dtype = torch .int8 ),  # bias: 1  
1068+                 1.0 ,  # b_scale  
1069+                 torch .tensor ([[[10.0 , 13.0 ]]], dtype = torch .float32 ),  # expected  
1070+             ), 
1071+             # Test case 3: Multiple output channels  
1072+             ( 
1073+                 "multi_output_channels" , 
1074+                 torch .tensor ( 
1075+                     [[[1.0 , 2.0 , 3.0 , 4.0 ]]], dtype = torch .float32  
1076+                 ),  # src: 1x1x4  
1077+                 torch .tensor ([[[1 , - 1 ]], [[2 , 0 ]]], dtype = torch .int8 ),  # weight: 2x1x2  
1078+                 0.25 ,  # w_scale  
1079+                 torch .tensor ([0 , 1 ], dtype = torch .int8 ),  # bias: 2  
1080+                 0.5 ,  # b_scale  
1081+                 torch .tensor ( 
1082+                     [[[- 0.25 , - 0.25 , - 0.25 ], [1.0 , 1.5 , 2.0 ]]], dtype = torch .float32  
1083+                 ),  # expected  
1084+             ), 
1085+             # Test case 4: Batch size > 1  
1086+             ( 
1087+                 "batch_size_2" , 
1088+                 torch .tensor ( 
1089+                     [[[1.0 , 2.0 , 3.0 ]], [[4.0 , 5.0 , 6.0 ]]], dtype = torch .float32  
1090+                 ),  # src: 2x1x3  
1091+                 torch .tensor ([[[1 , 1 ]]], dtype = torch .int8 ),  # weight: 1x1x2  
1092+                 1.0 ,  # w_scale  
1093+                 torch .tensor ([0 ], dtype = torch .int8 ),  # bias: 1  
1094+                 1.0 ,  # b_scale  
1095+                 torch .tensor ( 
1096+                     [[[3.0 , 5.0 ]], [[9.0 , 11.0 ]]], dtype = torch .float32  
1097+                 ),  # expected  
1098+             ), 
1099+             # Test case 5: Zero weights and bias  
1100+             ( 
1101+                 "zero_weights_bias" , 
1102+                 torch .tensor ([[[1.0 , 2.0 , 3.0 ]]], dtype = torch .float32 ),  # src: 1x1x3  
1103+                 torch .tensor ([[[0 , 0 ]]], dtype = torch .int8 ),  # weight: 1x1x2  
1104+                 0.1 ,  # w_scale  
1105+                 torch .tensor ([0 ], dtype = torch .int8 ),  # bias: 1  
1106+                 1.0 ,  # b_scale  
1107+                 torch .tensor ([[[0.0 , 0.0 ]]], dtype = torch .float32 ),  # expected  
1108+             ), 
1109+             # Test case 6: Negative weights  
1110+             ( 
1111+                 "negative_weights" , 
1112+                 torch .tensor ([[[2.0 , 4.0 , 6.0 ]]], dtype = torch .float32 ),  # src: 1x1x3  
1113+                 torch .tensor ([[[- 2 , - 1 ]]], dtype = torch .int8 ),  # weight: 1x1x2  
1114+                 0.5 ,  # w_scale  
1115+                 torch .tensor ([2 ], dtype = torch .float32 ),  # bias: 1  
1116+                 1.0 ,  # b_scale  
1117+                 torch .tensor ([[[- 2.0 , - 5.0 ]]], dtype = torch .float32 ),  # expected  
1118+             ), 
1119+         ] 
1120+     ) 
1121+     def  test_quantized_w8a32_conv (
1122+         self ,
1123+         name : str ,
1124+         src : torch .Tensor ,
1125+         weight : torch .Tensor ,
1126+         w_scale : float ,
1127+         bias : torch .Tensor ,
1128+         b_scale : float ,
1129+         expected_output : torch .Tensor ,
1130+     ) ->  None :
1131+         output  =  torch .ops .cadence .quantized_w8a32_conv (
1132+             src , weight , w_scale , bias , b_scale 
1133+         )
1134+ 
1135+         # Verify output properties 
1136+         self .assertEqual (
1137+             output .dtype ,
1138+             torch .float32 ,
1139+             f"Output dtype should be float32 in { name }  " ,
1140+         )
1141+         self .assertEqual (
1142+             output .shape ,
1143+             expected_output .shape ,
1144+             f"Output shape should match expected shape in { name }  " ,
1145+         )
1146+ 
1147+         # Verify output matches expected values 
1148+         self .assertTrue (
1149+             torch .allclose (output , expected_output , rtol = 1e-4 , atol = 1e-4 ),
1150+             f"Output values don't match expected in { name }  . Got { output }  , expected { expected_output }  " ,
1151+         )
1152+ 
10431153    @expand ( 
10441154        [ 
10451155            # Test case 1: Basic int8 case with negative scale  
0 commit comments