Commit c5106f0

Andrew Grebenisan authored and facebook-github-bot committed
Cadence: Support quantized_w8a32_conv (pytorch#15137)

Summary: As titled. The new op takes float32 activations and int8 weights ("w8a32": 8-bit weights, 32-bit float activations), dequantizing the weights and bias with per-tensor scales before running a float conv1d.

Reviewed By: skrtskrtfb
Differential Revision: D84658444

1 parent: 41baaa9

File tree: 3 files changed, +140 −1 lines changed

backends/cadence/aot/ops_registrations.py

Lines changed: 0 additions & 1 deletion. With a reference implementation now provided, the op is removed from the skip list in _validate_ref_impl_exists:

@@ -67,7 +67,6 @@ def _validate_ref_impl_exists() -> None:
         "cadence::dequantize_per_tensor_asym16u",
         "cadence::linalg_vector_norm",
         "cadence::quantized_conv2d_nchw",  # We should only support per_tensor variant, should remove
-        "cadence::quantized_w8a32_conv",
         "cadence::quantize_per_tensor_asym32s",
         "cadence::quantized_relu",  # We should only support per_tensor variant, should remove
         "cadence::linalg_svd",

backends/cadence/aot/ref_implementations.py

Lines changed: 30 additions & 0 deletions

@@ -703,6 +703,36 @@ def quantized_conv2d_nchw_per_tensor(
     )
 
 
+@impl_tracked(m, "quantized_w8a32_conv")
+def quantized_w8a32_conv(
+    src: torch.Tensor,
+    weight: torch.Tensor,
+    w_scale: float,
+    bias: torch.Tensor,
+    b_scale: float,
+) -> torch.Tensor:
+    # src has shape [batch, in_channel, in_length]
+    # weight has shape [out_ch, in_ch, kernel_dim]
+    # output has shape [batch, out_ch, in_length - kernel_dim + 1]
+    # Dequantize weight using its scale (the weight zero_point is assumed to be 0)
+    dequant_weight = weight.float() * w_scale
+
+    # Dequantize bias using its scale
+    dequant_bias = bias.float() * b_scale
+
+    # Perform a 1D convolution in float
+    # src: [batch, in_channel, in_length]
+    # weight: [out_ch, in_ch, kernel_dim]
+    # bias: [out_ch]
+    output = torch.nn.functional.conv1d(
+        src.float(),
+        dequant_weight,
+        dequant_bias,
+    )
+
+    return output
+
+
 @impl_tracked(m, "quantized_conv2d_nhwc.per_tensor")
 def quantized_conv2d_nhwc_per_tensor(
     input_tensor: torch.Tensor,
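For context, here is a minimal standalone sketch of the same semantics in plain PyTorch (no Cadence extension loaded; the tensor values and sizes are illustrative, not part of the commit): the int8 weight and bias are dequantized with their per-tensor scales, and the convolution itself runs in float32.

import torch
import torch.nn.functional as F

# Shapes follow the reference implementation above.
batch, in_ch, out_ch, in_length, kernel_dim = 2, 3, 4, 16, 5
src = torch.randn(batch, in_ch, in_length)  # float32 activations ("a32")
weight = torch.randint(-128, 128, (out_ch, in_ch, kernel_dim), dtype=torch.int8)  # "w8"
bias = torch.randint(-128, 128, (out_ch,), dtype=torch.int8)
w_scale, b_scale = 0.02, 0.5  # per-tensor scales; zero_point assumed 0

out = F.conv1d(src, weight.float() * w_scale, bias.float() * b_scale)
# Default stride 1 and no padding, so the length shrinks by kernel_dim - 1.
assert out.shape == (batch, out_ch, in_length - kernel_dim + 1)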

backends/cadence/aot/tests/test_ref_implementations.py

Lines changed: 110 additions & 0 deletions

@@ -1040,6 +1040,116 @@ def test_quantized_conv_per_tensor(
             f"Output values don't match expected. Got {output}, expected {expected_output}",
         )
 
+    @expand(
+        [
+            # Test case 1: Basic 1D convolution with int8 weights
+            (
+                "basic_int8_weights",
+                torch.tensor(
+                    [[[1.0, 2.0, 3.0, 4.0, 5.0]]], dtype=torch.float32
+                ),  # src: 1x1x5
+                torch.tensor([[[1, -1, 2]]], dtype=torch.int8),  # weight: 1x1x3
+                0.1,  # w_scale
+                torch.tensor([1], dtype=torch.int8),  # bias: 1
+                0.2,  # b_scale
+                torch.tensor(
+                    [[[0.7, 0.9, 1.1]]], dtype=torch.float32
+                ),  # expected: conv1d result
+            ),
+            # Test case 2: Multiple input channels
+            (
+                "multi_input_channels",
+                torch.tensor(
+                    [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]], dtype=torch.float32
+                ),  # src: 1x2x3
+                torch.tensor([[[2, 1], [1, 2]]], dtype=torch.int8),  # weight: 1x2x2
+                0.5,  # w_scale
+                torch.tensor([1], dtype=torch.int8),  # bias: 1
+                1.0,  # b_scale
+                torch.tensor([[[10.0, 13.0]]], dtype=torch.float32),  # expected
+            ),
+            # Test case 3: Multiple output channels
+            (
+                "multi_output_channels",
+                torch.tensor(
+                    [[[1.0, 2.0, 3.0, 4.0]]], dtype=torch.float32
+                ),  # src: 1x1x4
+                torch.tensor([[[1, -1]], [[2, 0]]], dtype=torch.int8),  # weight: 2x1x2
+                0.25,  # w_scale
+                torch.tensor([0, 1], dtype=torch.int8),  # bias: 2
+                0.5,  # b_scale
+                torch.tensor(
+                    [[[-0.25, -0.25, -0.25], [1.0, 1.5, 2.0]]], dtype=torch.float32
+                ),  # expected
+            ),
+            # Test case 4: Batch size > 1
+            (
+                "batch_size_2",
+                torch.tensor(
+                    [[[1.0, 2.0, 3.0]], [[4.0, 5.0, 6.0]]], dtype=torch.float32
+                ),  # src: 2x1x3
+                torch.tensor([[[1, 1]]], dtype=torch.int8),  # weight: 1x1x2
+                1.0,  # w_scale
+                torch.tensor([0], dtype=torch.int8),  # bias: 1
+                1.0,  # b_scale
+                torch.tensor(
+                    [[[3.0, 5.0]], [[9.0, 11.0]]], dtype=torch.float32
+                ),  # expected
+            ),
+            # Test case 5: Zero weights and bias
+            (
+                "zero_weights_bias",
+                torch.tensor([[[1.0, 2.0, 3.0]]], dtype=torch.float32),  # src: 1x1x3
+                torch.tensor([[[0, 0]]], dtype=torch.int8),  # weight: 1x1x2
+                0.1,  # w_scale
+                torch.tensor([0], dtype=torch.int8),  # bias: 1
+                1.0,  # b_scale
+                torch.tensor([[[0.0, 0.0]]], dtype=torch.float32),  # expected
+            ),
+            # Test case 6: Negative weights
+            (
+                "negative_weights",
+                torch.tensor([[[2.0, 4.0, 6.0]]], dtype=torch.float32),  # src: 1x1x3
+                torch.tensor([[[-2, -1]]], dtype=torch.int8),  # weight: 1x1x2
+                0.5,  # w_scale
+                torch.tensor([2], dtype=torch.float32),  # bias: 1
+                1.0,  # b_scale
+                torch.tensor([[[-2.0, -5.0]]], dtype=torch.float32),  # expected
+            ),
+        ]
+    )
+    def test_quantized_w8a32_conv(
+        self,
+        name: str,
+        src: torch.Tensor,
+        weight: torch.Tensor,
+        w_scale: float,
+        bias: torch.Tensor,
+        b_scale: float,
+        expected_output: torch.Tensor,
+    ) -> None:
+        output = torch.ops.cadence.quantized_w8a32_conv(
+            src, weight, w_scale, bias, b_scale
+        )
+
+        # Verify output properties
+        self.assertEqual(
+            output.dtype,
+            torch.float32,
+            f"Output dtype should be float32 in {name}",
+        )
+        self.assertEqual(
+            output.shape,
+            expected_output.shape,
+            f"Output shape should match expected shape in {name}",
+        )
+
+        # Verify output matches expected values
+        self.assertTrue(
+            torch.allclose(output, expected_output, rtol=1e-4, atol=1e-4),
+            f"Output values don't match expected in {name}. Got {output}, expected {expected_output}",
+        )
+
     @expand(
         [
             # Test case 1: Basic int8 case with negative scale
