
Commit 0e6cd2b

Andrew Grebenisan authored and facebook-github-bot committed
Cadence ops: Support quantized_w8a32_linear (pytorch#15171)
Summary: As titled.

Reviewed By: hsharma35
Differential Revision: D84745967
1 parent bd34e74

File tree: 3 files changed, +124 −3 lines


backends/cadence/aot/ops_registrations.py

Lines changed: 4 additions & 2 deletions
@@ -53,7 +53,6 @@ def _validate_ref_impl_exists() -> None:
     # 1. be removed
     # 2. have a reference implementation added to ref_implementations.py
     _WARN_ONLY = {
-        "cadence::quantized_w8a32_linear",
         "cadence::quantized_add",  # We should only support per_tensor variant, should remove
         "cadence::_softmax_f32_f32",
         "cadence::requantize",  # We should only support per_tensor variant, should remove
@@ -2702,10 +2701,13 @@ def quantized_w8a32_linear_meta(
     b_scale: float,
 ) -> torch.Tensor:
     # src comes in shape [leading_dims, in_dim]
-    # weight comes in shape [in_dim, out_dim]
+    # weight comes in shape [out_dim, in_dim]
     # output comes in empty with shape [leading_dims, out_dim]
     src_shape = list(src.shape)
     weight_shape = weight.shape
+    assert (src_shape[-1] % 4) == 0
+    if len(src_shape) >= 2:
+        assert src_shape[-2] == 1
     assert len(weight_shape) == 2
     assert src_shape[-1] == weight_shape[-1]
     src_shape[-1] = weight_shape[0]
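Taken together, the new checks restrict the op to vector-matrix products whose input width is a multiple of 4, and the inferred output shape swaps in_dim for out_dim. A minimal standalone sketch of that shape rule (the helper name infer_w8a32_linear_shape is hypothetical, not part of this commit):

def infer_w8a32_linear_shape(src_shape, weight_shape):
    # Mirrors the asserts in quantized_w8a32_linear_meta above.
    out_shape = list(src_shape)
    assert (out_shape[-1] % 4) == 0           # in_dim must be a multiple of 4
    if len(out_shape) >= 2:
        assert out_shape[-2] == 1             # vector-matrix multiplication only
    assert len(weight_shape) == 2             # weight is [out_dim, in_dim]
    assert out_shape[-1] == weight_shape[-1]  # in_dim of src and weight must match
    out_shape[-1] = weight_shape[0]           # output replaces in_dim with out_dim
    return out_shape

print(infer_w8a32_linear_shape([1, 8], [16, 8]))  # -> [1, 16]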

backends/cadence/aot/ref_implementations.py

Lines changed: 30 additions & 0 deletions
@@ -866,6 +866,9 @@ def quantized_w8a32_conv(
     # weight comes in shape [out_ch, in_ch, kernel_dim]
     # output comes in empty with shape [batch, out_ch, in_length - kernel_dim + 1]
     # Dequantize weight using scale
+    assert weight.dtype == torch.int8
+    assert bias.dtype == torch.int8
+
     dequant_weight = weight.float() * w_scale

     # Dequantize bias using scale
@@ -884,6 +887,33 @@ def quantized_w8a32_conv(
     return output


+@impl_tracked(m, "quantized_w8a32_linear")
+def quantized_w8a32_linear(
+    src: torch.Tensor,
+    weight: torch.Tensor,
+    w_scale: float,
+    bias: torch.Tensor,
+    b_scale: float,
+) -> torch.Tensor:
+    # src comes in shape [leading_dims, in_dim]
+    # weight comes in shape [out_dim, in_dim]
+    # output comes in empty with shape [leading_dims, out_dim]
+    assert weight.dtype == torch.int8
+    assert bias.dtype == torch.int8
+    if len(src.shape) >= 2:
+        assert src.shape[-2] == 1, "Only supporting vector-matrix multiplication"
+    dequant_weight = weight.float() * w_scale
+    dequant_bias = bias.float() * b_scale
+
+    output = torch.nn.functional.linear(
+        src.float(),
+        dequant_weight,
+        dequant_bias,
+    )
+
+    return output
+
+
 @impl_tracked(m, "quantized_conv2d_nhwc.per_tensor")
 def quantized_conv2d_nhwc_per_tensor(
     input_tensor: torch.Tensor,
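The reference implementation is simply dequantize-then-linear in fp32. A self-contained check of that math in plain PyTorch (example tensors are made up; this mirrors the implementation above rather than calling the registered op):

import torch

src = torch.tensor([[1.0, 2.0, 3.0]])  # fp32 activations, shape [1, 3]
weight = torch.tensor([[2, 1, 1], [1, 2, 1]], dtype=torch.int8)  # int8, [out_dim, in_dim] = [2, 3]
bias = torch.tensor([0, 1], dtype=torch.int8)  # int8, shape [2]
w_scale, b_scale = 0.5, 1.0

dequant_weight = weight.float() * w_scale  # [[1.0, 0.5, 0.5], [0.5, 1.0, 0.5]]
dequant_bias = bias.float() * b_scale      # [0.0, 1.0]
out = torch.nn.functional.linear(src, dequant_weight, dequant_bias)
print(out)  # tensor([[3.5000, 5.0000]])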

backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 90 additions & 1 deletion
@@ -1188,7 +1188,7 @@ def test_quantized_conv_per_tensor(
                 dtype=torch.int8,
             ),  # weight: 4x4x3
             0.5,  # w_scale
-            torch.tensor([2, 2, 2, 2], dtype=torch.float32),  # bias: 4
+            torch.tensor([2, 2, 2, 2], dtype=torch.int8),  # bias: 4
             1.0,  # b_scale
             torch.tensor(
                 [
@@ -1236,6 +1236,95 @@ def test_quantized_w8a32_conv(
             f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
         )

+    @expand(
+        [
+            (
+                "multi_input_features",
+                torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.float32),  # src: 1x3
+                torch.tensor([[2, 1, 1], [1, 2, 1]], dtype=torch.int8),  # weight: 2x3
+                0.5,  # w_scale
+                torch.tensor([0, 1], dtype=torch.int8),  # bias: 2
+                1.0,  # b_scale
+                torch.tensor([[3.5, 5.0]], dtype=torch.float32),  # expected
+            ),
+            (
+                "batch_size_2",
+                torch.tensor(
+                    [[[1.0, 2.0]], [[3.0, 4.0]]], dtype=torch.float32
+                ),  # src: 2x1x2
+                torch.tensor([[1, 1], [2, -1]], dtype=torch.int8),  # weight: 2x2
+                1.0,  # w_scale
+                torch.tensor([0, 0], dtype=torch.int8),  # bias: 2
+                1.0,  # b_scale
+                torch.tensor(
+                    [[[3.0, 0.0]], [[7.0, 2.0]]], dtype=torch.float32
+                ),  # expected
+            ),
+            (
+                "shape_assertion_error",
+                torch.tensor(
+                    [[[1.0, 2.0], [3.0, 4.0]]], dtype=torch.float32
+                ),  # src: 1x2x2
+                torch.tensor([[1, 1], [2, -1]], dtype=torch.int8),  # weight: 2x2
+                1.0,  # w_scale
+                torch.tensor([0, 1], dtype=torch.int8),  # bias: 2
+                1.0,  # b_scale
+                torch.tensor(
+                    [[[3.0, 1.0], [7.0, 3.0]]], dtype=torch.float32
+                ),  # expected (never checked; this case asserts)
+            ),
+            (
+                "negative_weights",
+                torch.tensor([[2.0, 4.0]], dtype=torch.float32),  # src: 1x2
+                torch.tensor([[-2, -1], [-3, -2]], dtype=torch.int8),  # weight: 2x2
+                0.5,  # w_scale
+                torch.tensor([2, 1], dtype=torch.int8),  # bias: 2
+                1.0,  # b_scale
+                torch.tensor([[-2.0, -6.0]], dtype=torch.float32),  # expected
+            ),
+        ]
+    )
+    def test_quantized_w8a32_linear(
+        self,
+        name: str,
+        src: torch.Tensor,
+        weight: torch.Tensor,
+        w_scale: float,
+        bias: torch.Tensor,
+        b_scale: float,
+        expected_output: torch.Tensor,
+    ) -> None:
+        if name == "shape_assertion_error":
+            with self.assertRaisesRegex(
+                AssertionError, "Only supporting vector-matrix multiplication"
+            ):
+                torch.ops.cadence.quantized_w8a32_linear(
+                    src, weight, w_scale, bias, b_scale
+                )
+            return
+
+        output = torch.ops.cadence.quantized_w8a32_linear(
+            src, weight, w_scale, bias, b_scale
+        )
+
+        # Verify output properties
+        self.assertEqual(
+            output.dtype,
+            torch.float32,
+            f"Output dtype should be float32 in {name}",
+        )
+        self.assertEqual(
+            output.shape,
+            expected_output.shape,
+            f"Output shape should match expected shape in {name}",
+        )
+
+        # Verify output matches expected values
+        self.assertTrue(
+            torch.allclose(output, expected_output, rtol=1e-4, atol=1e-4),
+            f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
+        )
+
     @expand(
         [
             # Test case 1: Basic int8 case with negative scale