@@ -1188,7 +1188,7 @@ def test_quantized_conv_per_tensor(
11881188 dtype = torch .int8 ,
11891189 ), # weight: 4x4x3
11901190 0.5 , # w_scale
1191- torch .tensor ([2 , 2 , 2 , 2 ], dtype = torch .float32 ), # bias: 4
1191+ torch .tensor ([2 , 2 , 2 , 2 ], dtype = torch .int8 ), # bias: 4
11921192 1.0 , # b_scale
11931193 torch .tensor (
11941194 [
@@ -1236,6 +1236,95 @@ def test_quantized_w8a32_conv(
12361236 f"Output values don't match expected in { name } . Got { output } , expected { expected_output } " ,
12371237 )
12381238
1239+ @expand (
1240+ [
1241+ (
1242+ "multi_input_features" ,
1243+ torch .tensor ([[1.0 , 2.0 , 3.0 ]], dtype = torch .float32 ), # src: 1x3
1244+ torch .tensor ([[2 , 1 , 1 ], [1 , 2 , 1 ]], dtype = torch .int8 ), # weight: 2x3
1245+ 0.5 , # w_scale
1246+ torch .tensor ([0 , 1 ], dtype = torch .int8 ), # bias: 2
1247+ 1.0 , # b_scale
1248+ torch .tensor ([[3.5 , 5.0 ]], dtype = torch .float32 ), # expected
1249+ ),
1250+ (
1251+ "batch_size_2" ,
1252+ torch .tensor (
1253+ [[[1.0 , 2.0 ]], [[3.0 , 4.0 ]]], dtype = torch .float32
1254+ ), # src: 2x2
1255+ torch .tensor ([[1 , 1 ], [2 , - 1 ]], dtype = torch .int8 ), # weight: 2x2
1256+ 1.0 , # w_scale
1257+ torch .tensor ([0 , 0 ], dtype = torch .int8 ), # bias: 2
1258+ 1.0 , # b_scale
1259+ torch .tensor (
1260+ [[[3.0 , 0.0 ]], [[7.0 , 2.0 ]]], dtype = torch .float32
1261+ ), # expected
1262+ ),
1263+ (
1264+ "shape_assertion_error" ,
1265+ torch .tensor (
1266+ [[[1.0 , 2.0 ], [3.0 , 4.0 ]]], dtype = torch .float32
1267+ ), # src: 1x2x2
1268+ torch .tensor ([[1 , 1 ], [2 , - 1 ]], dtype = torch .int8 ), # weight: 2x2
1269+ 1.0 , # w_scale
1270+ torch .tensor ([0 , 1 ], dtype = torch .int8 ), # bias: 2
1271+ 1.0 , # b_scale
1272+ torch .tensor (
1273+ [[[3.0 , 1.0 ], [7.0 , 3.0 ]]], dtype = torch .float32
1274+ ), # expected
1275+ ),
1276+ (
1277+ "negative_weights" ,
1278+ torch .tensor ([[2.0 , 4.0 ]], dtype = torch .float32 ), # src: 1x2
1279+ torch .tensor ([[- 2 , - 1 ], [- 3 , - 2 ]], dtype = torch .int8 ), # weight: 2x2
1280+ 0.5 , # w_scale
1281+ torch .tensor ([2 , 1 ], dtype = torch .int8 ), # bias: 2
1282+ 1.0 , # b_scale
1283+ torch .tensor ([[- 2.0 , - 6.0 ]], dtype = torch .float32 ), # expected
1284+ ),
1285+ ]
1286+ )
1287+ def test_quantized_w8a32_linear (
1288+ self ,
1289+ name : str ,
1290+ src : torch .Tensor ,
1291+ weight : torch .Tensor ,
1292+ w_scale : float ,
1293+ bias : torch .Tensor ,
1294+ b_scale : float ,
1295+ expected_output : torch .Tensor ,
1296+ ) -> None :
1297+ if name == "shape_assertion_error" :
1298+ with self .assertRaisesRegex (
1299+ AssertionError , "Only supporting vector-matrix multiplication"
1300+ ):
1301+ torch .ops .cadence .quantized_w8a32_linear (
1302+ src , weight , w_scale , bias , b_scale
1303+ )
1304+ return
1305+
1306+ output = torch .ops .cadence .quantized_w8a32_linear (
1307+ src , weight , w_scale , bias , b_scale
1308+ )
1309+
1310+ # Verify output properties
1311+ self .assertEqual (
1312+ output .dtype ,
1313+ torch .float32 ,
1314+ f"Output dtype should be float32 in { name } " ,
1315+ )
1316+ self .assertEqual (
1317+ output .shape ,
1318+ expected_output .shape ,
1319+ f"Output shape should match expected shape in { name } " ,
1320+ )
1321+
1322+ # Verify output matches expected values
1323+ self .assertTrue (
1324+ torch .allclose (output , expected_output , rtol = 1e-4 , atol = 1e-4 ),
1325+ f"Output values don't match expected in { name } . Got { output } , expected { expected_output } " ,
1326+ )
1327+
12391328 @expand (
12401329 [
12411330 # Test case 1: Basic int8 case with negative scale
0 commit comments