Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 13 additions & 19 deletions test/quantization/pt2e/test_quantize_pt2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,29 +157,23 @@ def test_chunked_bn_fusion(self):
n_chunks = 3
in_channels = 1
out_channels = 32
m = ConvWithSharedWeightInExportedModel(n_chunks, in_channels, out_channels)
m.bn.running_var = torch.nn.Parameter(
torch.rand(out_channels) * 1e-2, requires_grad=False
)
for bias in [True, False]:
m = ConvWithSharedWeightInExportedModel(n_chunks, in_channels, out_channels, bias=bias)
m.bn.running_var = torch.nn.Parameter(
torch.rand(out_channels) * 1e-2, requires_grad=False
)

m.eval()
example_inputs = (torch.rand(batch_size, n_chunks, 32, 32),)
ref_outputs = m(*example_inputs)
traced_model = torch.export.export(m, example_inputs, strict=True).module()
traced_outputs = traced_model(*example_inputs)
prepared_model = prepare_pt2e(traced_model, XNNPACKQuantizer())
prepared_outputs = prepared_model(*example_inputs)

if isinstance(ref_outputs, (tuple, list)):
for ref, prepared, traced in zip(
ref_outputs, prepared_outputs, traced_outputs
):
torch.testing.assert_close(ref, traced)
torch.testing.assert_close(traced, prepared)
else:
m.eval()
example_inputs = (torch.rand(batch_size, n_chunks, 32, 32),)
ref_outputs = m(*example_inputs)
traced_model = torch.export.export(m, example_inputs, strict=True).module()
traced_outputs = traced_model(*example_inputs)
prepared_model = prepare_pt2e(traced_model, XNNPACKQuantizer())
prepared_outputs = prepared_model(*example_inputs)
torch.testing.assert_close(ref_outputs, traced_outputs)
torch.testing.assert_close(traced_outputs, prepared_outputs)


def test_wo_annotate_conv_output_quantizer(self):
# TODO: use OP_TO_ANNOTATOR
class BackendAQuantizer(Quantizer):
Expand Down
9 changes: 5 additions & 4 deletions torchao/quantization/pt2e/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,10 +710,11 @@ def fold_bn_weights_into_conv_node(
conv_args.append(None)

if fake_fuse:
fused_weight, fused_bias = (
torch.nn.Parameter(conv_w, conv_w.requires_grad),
torch.nn.Parameter(conv_b, conv_b.requires_grad),
)
fused_weight = torch.nn.Parameter(conv_w, conv_w.requires_grad)
if conv_b is not None:
fused_bias = torch.nn.Parameter(conv_b, conv_b.requires_grad)
else:
fused_bias = torch.nn.Parameter(torch.zeros_like(bn_rm), requires_grad=conv_w.requires_grad)
else:
fused_weight, fused_bias = fuse_conv_bn_weights(
conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, transpose=transpose
Expand Down
4 changes: 2 additions & 2 deletions torchao/testing/model_architectures.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,11 @@ def forward(self, x):

class ConvWithSharedWeightInExportedModel(nn.Module):
def __init__(
self, n_chunks, in_channels, out_channels, kernel_size=3, stride=1, padding=1
self, n_chunks, in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True
) -> None:
super().__init__()
self.n_chunks = n_chunks
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)

Expand Down
Loading