@@ -1188,7 +1188,7 @@ def test_quantized_conv_per_tensor(
                     dtype=torch.int8,
                 ),  # weight: 4x4x3
                 0.5,  # w_scale
-                torch.tensor([2, 2, 2, 2], dtype=torch.float32),  # bias: 4
+                torch.tensor([2, 2, 2, 2], dtype=torch.int8),  # bias: 4
                 1.0,  # b_scale
                 torch.tensor(
                     [
@@ -1236,6 +1236,82 @@ def test_quantized_w8a32_conv(
             f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
         )
 
+    @expand(
+        [
+            (
+                "multi_input_features",
+                torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.float32),  # src: 1x3
+                torch.tensor([[2, 1, 1], [1, 2, 1]], dtype=torch.int8),  # weight: 2x3
+                0.5,  # w_scale
+                torch.tensor([0, 1], dtype=torch.int8),  # bias: 2
+                1.0,  # b_scale
+                torch.tensor([[3.5, 5.0]], dtype=torch.float32),  # expected
+            ),
+            (
+                "batch_size_2",
+                torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32),  # src: 2x2
+                torch.tensor([[1, 1], [2, -1]], dtype=torch.int8),  # weight: 2x2
+                1.0,  # w_scale
+                torch.tensor([0, 0], dtype=torch.int8),  # bias: 2
+                1.0,  # b_scale
+                torch.tensor([[3.0, 0.0], [7.0, 2.0]], dtype=torch.float32),  # expected
+            ),
+            (
+                "3d_input",
+                torch.tensor(
+                    [[[1.0, 2.0], [3.0, 4.0]]], dtype=torch.float32
+                ),  # src: 1x2x2
+                torch.tensor([[1, 1], [2, -1]], dtype=torch.int8),  # weight: 2x2
+                1.0,  # w_scale
+                torch.tensor([0, 1], dtype=torch.int8),  # bias: 2
+                1.0,  # b_scale
+                torch.tensor(
+                    [[[3.0, 1.0], [7.0, 3.0]]], dtype=torch.float32
+                ),  # expected
+            ),
+            (
+                "negative_weights",
+                torch.tensor([[2.0, 4.0]], dtype=torch.float32),  # src: 1x2
+                torch.tensor([[-2, -1], [-3, -2]], dtype=torch.int8),  # weight: 2x2
+                0.5,  # w_scale
+                torch.tensor([2, 1], dtype=torch.int8),  # bias: 2
+                1.0,  # b_scale
+                torch.tensor([[-2.0, -6.0]], dtype=torch.float32),  # expected
+            ),
+        ]
+    )
+    def test_quantized_w8a32_linear(
+        self,
+        name: str,
+        src: torch.Tensor,
+        weight: torch.Tensor,
+        w_scale: float,
+        bias: torch.Tensor,
+        b_scale: float,
+        expected_output: torch.Tensor,
+    ) -> None:
+        output = torch.ops.cadence.quantized_w8a32_linear(
+            src, weight, w_scale, bias, b_scale
+        )
+
+        # Verify output properties
+        self.assertEqual(
+            output.dtype,
+            torch.float32,
+            f"Output dtype should be float32 in {name}",
+        )
+        self.assertEqual(
+            output.shape,
+            expected_output.shape,
+            f"Output shape should match expected shape in {name}",
+        )
+
+        # Verify output matches expected values
+        self.assertTrue(
+            torch.allclose(output, expected_output, rtol=1e-4, atol=1e-4),
+            f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
+        )
+
     @expand(
         [
             # Test case 1: Basic int8 case with negative scale
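The expected outputs in the new test cases are consistent with a dequantize-then-linear model: scale the int8 weight and bias back to float32, then apply a standard linear layer. A minimal sketch of that semantics follows; the helper name `w8a32_linear_reference` is hypothetical and not part of the cadence op library, which implements the op natively.

```python
import torch
import torch.nn.functional as F


def w8a32_linear_reference(
    src: torch.Tensor,     # float32 activations, shape (..., in_features)
    weight: torch.Tensor,  # int8 weights, shape (out_features, in_features)
    w_scale: float,
    bias: torch.Tensor,    # int8 bias, shape (out_features,)
    b_scale: float,
) -> torch.Tensor:
    # Dequantize weight and bias, then run a plain float32 linear:
    # out = src @ (weight * w_scale).T + bias * b_scale
    return F.linear(
        src,
        weight.to(torch.float32) * w_scale,
        bias.to(torch.float32) * b_scale,
    )


# e.g. the "multi_input_features" case above:
# src=[[1, 2, 3]], weight=[[2, 1, 1], [1, 2, 1]] with w_scale=0.5,
# bias=[0, 1] with b_scale=1.0 -> [[3.5, 5.0]]
```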