
Commit 05a4d72

Andrew Grebenisan authored and meta-codesync[bot] committed
Cadence ops: Support quantized_w8a32_linear (pytorch#15171)
Summary: Pull Request resolved: pytorch#15171. As titled.

Differential Revision: D84745967
1 parent: b1cf245

File tree: 3 files changed (+111, -3 lines)

  backends/cadence/aot/ops_registrations.py
  backends/cadence/aot/ref_implementations.py
  backends/cadence/aot/tests/test_ref_implementations.py

backends/cadence/aot/ops_registrations.py (4 additions, 2 deletions)

@@ -53,7 +53,6 @@ def _validate_ref_impl_exists() -> None:
 # 1. be removed
 # 2. have a reference implementation added to ref_implementations.py
 _WARN_ONLY = {
-    "cadence::quantized_w8a32_linear",
     "cadence::quantized_add",  # We should only support per_tensor variant, should remove
     "cadence::_softmax_f32_f32",
     "cadence::requantize",  # We should only support per_tensor variant, should remove
@@ -2702,10 +2701,13 @@ def quantized_w8a32_linear_meta(
     b_scale: float,
 ) -> torch.Tensor:
     # src comes in shape [leading_dims, in_dim]
-    # weight comes in shape [in_dim, out_dim]
+    # weight comes in shape [out_dim, in_dim]
     # output comes in empty with shape [leading_dims, out_dim]
     src_shape = list(src.shape)
     weight_shape = weight.shape
+    assert (src_shape[-1] % 4) == 0
+    if len(src_shape) >= 2:
+        assert src_shape[-2] == 1
     assert len(weight_shape) == 2
     assert src_shape[-1] == weight_shape[-1]
     src_shape[-1] = weight_shape[0]
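The updated meta function encodes the kernel's layout requirements: the activation's innermost dimension (in_dim) must be a multiple of 4, any dimension immediately preceding it must be 1, and the weight is laid out as [out_dim, in_dim]. The following is a minimal standalone sketch of that shape logic, not part of the commit; the helper name and example shapes are illustrative only.

```python
import torch

def w8a32_linear_out_shape(src: torch.Tensor, weight: torch.Tensor) -> list[int]:
    # Mirrors the constraints asserted by quantized_w8a32_linear_meta above.
    src_shape = list(src.shape)
    assert src_shape[-1] % 4 == 0        # in_dim must be a multiple of 4
    if len(src_shape) >= 2:
        assert src_shape[-2] == 1        # only a single row per leading batch
    assert weight.dim() == 2             # weight laid out as [out_dim, in_dim]
    assert src_shape[-1] == weight.shape[-1]
    src_shape[-1] = weight.shape[0]      # output keeps leading dims, gets out_dim
    return src_shape

# Example: src [1, 8] against weight [4, 8] -> output shape [1, 4]
print(w8a32_linear_out_shape(torch.zeros(1, 8), torch.zeros(4, 8, dtype=torch.int8)))
```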

backends/cadence/aot/ref_implementations.py (30 additions, 0 deletions)

@@ -866,6 +866,9 @@ def quantized_w8a32_conv(
     # weight comes in shape [out_ch, in_ch, kernel_dim]
     # output comes in empty with shape [batch, out_ch, in_length - kernel_dim + 1]
     # Dequantize weight using scale
+    assert weight.dtype == torch.int8
+    assert bias.dtype == torch.int8
+
     dequant_weight = weight.float() * w_scale

     # Dequantize bias using scale
@@ -884,6 +887,33 @@ def quantized_w8a32_conv(
     return output


+@impl_tracked(m, "quantized_w8a32_linear")
+def quantized_w8a32_linear(
+    src: torch.Tensor,
+    weight: torch.Tensor,
+    w_scale: float,
+    bias: torch.Tensor,
+    b_scale: float,
+) -> torch.Tensor:
+    # src comes in shape [leading_dims, in_dim]
+    # weight comes in shape [out_dim, in_dim]
+    # output comes in empty with shape [leading_dims, out_dim]
+    assert weight.dtype == torch.int8
+    assert bias.dtype == torch.int8
+    if len(src.shape) >= 2:
+        assert src.shape[-2] == 1
+    dequant_weight = weight.float() * w_scale
+    dequant_bias = bias.float() * b_scale
+
+    output = torch.nn.functional.linear(
+        src.float(),
+        dequant_weight,
+        dequant_bias,
+    )
+
+    return output
+
+
 @impl_tracked(m, "quantized_conv2d_nhwc.per_tensor")
 def quantized_conv2d_nhwc_per_tensor(
     input_tensor: torch.Tensor,
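The reference implementation dequantizes the int8 weight and bias with their float scales and then falls back to a plain float linear. As a rough sanity check (not part of the commit), the same math reproduces the expected value of the "multi_input_features" test case added below; the tensor values are taken from that test, only the snippet itself is illustrative.

```python
import torch

# src: [1, 3] float activations; weight: [2, 3] int8; bias: [2] int8
src = torch.tensor([[1.0, 2.0, 3.0]])
weight = torch.tensor([[2, 1, 1], [1, 2, 1]], dtype=torch.int8)
bias = torch.tensor([0, 1], dtype=torch.int8)
w_scale, b_scale = 0.5, 1.0

# Same computation the reference op performs: dequantize, then float linear.
out = torch.nn.functional.linear(src, weight.float() * w_scale, bias.float() * b_scale)
print(out)  # tensor([[3.5000, 5.0000]]) -- matches the expected tensor in the test
```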

backends/cadence/aot/tests/test_ref_implementations.py (77 additions, 1 deletion)

@@ -1188,7 +1188,7 @@ def test_quantized_conv_per_tensor(
                     dtype=torch.int8,
                 ),  # weight: 4x4x3
                 0.5,  # w_scale
-                torch.tensor([2, 2, 2, 2], dtype=torch.float32),  # bias: 4
+                torch.tensor([2, 2, 2, 2], dtype=torch.int8),  # bias: 4
                 1.0,  # b_scale
                 torch.tensor(
                     [
@@ -1236,6 +1236,82 @@ def test_quantized_w8a32_conv(
                f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
            )

+    @expand(
+        [
+            (
+                "multi_input_features",
+                torch.tensor([[1.0, 2.0, 3.0]], dtype=torch.float32),  # src: 1x3
+                torch.tensor([[2, 1, 1], [1, 2, 1]], dtype=torch.int8),  # weight: 2x3
+                0.5,  # w_scale
+                torch.tensor([0, 1], dtype=torch.int8),  # bias: 2
+                1.0,  # b_scale
+                torch.tensor([[3.5, 5.0]], dtype=torch.float32),  # expected
+            ),
+            (
+                "batch_size_2",
+                torch.tensor([[1.0, 2.0], [3.0, 4.0]], dtype=torch.float32),  # src: 2x2
+                torch.tensor([[1, 1], [2, -1]], dtype=torch.int8),  # weight: 2x2
+                1.0,  # w_scale
+                torch.tensor([0, 0], dtype=torch.int8),  # bias: 2
+                1.0,  # b_scale
+                torch.tensor([[3.0, 0.0], [7.0, 2.0]], dtype=torch.float32),  # expected
+            ),
+            (
+                "3d_input",
+                torch.tensor(
+                    [[[1.0, 2.0], [3.0, 4.0]]], dtype=torch.float32
+                ),  # src: 1x2x2
+                torch.tensor([[1, 1], [2, -1]], dtype=torch.int8),  # weight: 2x2
+                1.0,  # w_scale
+                torch.tensor([0, 1], dtype=torch.int8),  # bias: 2
+                1.0,  # b_scale
+                torch.tensor(
+                    [[[3.0, 1.0], [7.0, 3.0]]], dtype=torch.float32
+                ),  # expected
+            ),
+            (
+                "negative_weights",
+                torch.tensor([[2.0, 4.0]], dtype=torch.float32),  # src: 1x2
+                torch.tensor([[-2, -1], [-3, -2]], dtype=torch.int8),  # weight: 2x2
+                0.5,  # w_scale
+                torch.tensor([2, 1], dtype=torch.int8),  # bias: 2
+                1.0,  # b_scale
+                torch.tensor([[-2.0, -6.0]], dtype=torch.float32),  # expected
+            ),
+        ]
+    )
+    def test_quantized_w8a32_linear(
+        self,
+        name: str,
+        src: torch.Tensor,
+        weight: torch.Tensor,
+        w_scale: float,
+        bias: torch.Tensor,
+        b_scale: float,
+        expected_output: torch.Tensor,
+    ) -> None:
+        output = torch.ops.cadence.quantized_w8a32_linear(
+            src, weight, w_scale, bias, b_scale
+        )
+
+        # Verify output properties
+        self.assertEqual(
+            output.dtype,
+            torch.float32,
+            f"Output dtype should be float32 in {name}",
+        )
+        self.assertEqual(
+            output.shape,
+            expected_output.shape,
+            f"Output shape should match expected shape in {name}",
+        )
+
+        # Verify output matches expected values
+        self.assertTrue(
+            torch.allclose(output, expected_output, rtol=1e-4, atol=1e-4),
+            f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
+        )
+
     @expand(
         [
             # Test case 1: Basic int8 case with negative scale
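To exercise just the new cases locally, something like `python -m pytest backends/cadence/aot/tests/test_ref_implementations.py -k quantized_w8a32_linear` should work, assuming the suite is collected by pytest from the repository root.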
