Commit ace645a

jerryzh168 authored and pytorchmergebot committed
Add support for prototype affine quantization in pt2e flow (pytorch#141421)
Summary:
Duplicates affine quantization functionality from torchao, including the observer (https://github.com/pytorch/ao/blob/main/torchao/quantization/observer.py) and some quant_primitive ops (https://github.com/pytorch/ao/blob/7c3c51fd0de33307e43a1769883a348861d6f7c9/torchao/quantization/quant_primitives.py#L26-L30), to allow a per-group quantization min/max observer in the pt2e flow.

Next: we can follow up to add a moving-average min/max observer.

Test Plan:
python test/test_quantization.py -k test_channel_group_quantization

Pull Request resolved: pytorch#141421
Approved by: https://github.com/cccclai
1 parent 60a0d53 commit ace645a
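For context, here is a minimal sketch (not part of this diff) of exercising the copied observer standalone. The constructor arguments mirror the with_args(...) call in the new test below; calculate_qparams() is assumed from the usual torch.ao observer API.

import torch
from torch.ao.quantization.observer import MappingType, PerGroup
from torch.ao.quantization.pt2e._affine_quantization import (
    AffineQuantizedMinMaxObserver,
)

# One (scale, zero_point) pair per group of 128 elements along the last
# dimension; symmetric mapping into uint8.
obs = AffineQuantizedMinMaxObserver(
    mapping_type=MappingType.SYMMETRIC,
    target_dtype=torch.uint8,
    granularity=PerGroup(group_size=128),
)
obs(torch.randn(5, 128))                     # record per-group min/max
scale, zero_point = obs.calculate_qparams()  # assumed observer API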

File tree

8 files changed (+1172 -7 lines)

docs/source/quantization-support.rst

+12
@@ -250,6 +250,18 @@ the values observed during calibration (PTQ) or training (QAT).
     default_per_channel_weight_observer
     default_dynamic_quant_observer
     default_float_qparams_observer
+    AffineQuantizedObserverBase
+    Granularity
+    MappingType
+    PerAxis
+    PerBlock
+    PerGroup
+    PerRow
+    PerTensor
+    PerToken
+    TorchAODType
+    ZeroPointDomain
+    get_block_size
 
 torch.ao.quantization.fake_quantize
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
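The names documented above come from torchao's affine quantization prototype. As a rough illustration, assuming get_block_size keeps its torchao semantics, a Granularity resolves to a per-block quantization shape like this:

import torch
from torch.ao.quantization.observer import (
    PerGroup,
    PerTensor,
    PerToken,
    get_block_size,
)

shape = (20, 128)  # e.g. the Linear weight in the new test
get_block_size(shape, PerTensor())               # (20, 128): one block for the whole tensor
get_block_size(shape, PerGroup(group_size=128))  # (1, 128): one group of 128 per row
get_block_size(shape, PerToken())                # (1, 128): one block per token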

mypy.ini

+3
@@ -79,6 +79,9 @@ ignore_missing_imports = True
 [mypy-torch.ao.quantization.experimental.fake_quantize]
 ignore_missing_imports = True
 
+[mypy-torch.ao.quantization.pt2e._affine_quantization]
+ignore_errors = True
+
 #
 # Files with various errors. Mostly real errors, possibly some false
 # positives as well.

test/quantization/pt2e/test_quantize_pt2e.py

+90 -6

@@ -42,7 +42,6 @@
 )
 from torch.export import export_for_training
 from torch.fx import Node
-from torch.testing._internal.common_device_type import instantiate_device_type_tests
 from torch.testing._internal.common_quantization import (
     NodeSpec as ns,
     PT2EQuantizationTestCase,

@@ -1865,6 +1864,10 @@ def _get_bn_train_eval_ops(self):
             torch.ops.aten.batch_norm.default,
         )
 
+    @parametrize(
+        "device",
+        ["cpu"] + (["cuda"] if TEST_CUDA else []) + (["hpu"] if TEST_HPU else []),
+    )
     def test_move_exported_model_bn(self, device):
         """
         Test switching batch_norm behavior between train and eval modes using

@@ -2477,9 +2480,90 @@ def check_nn_module(node):
                 check_nn_module(node)
 
 
-instantiate_parametrized_tests(TestQuantizePT2E)
+@skipIfNoQNNPACK
+class TestQuantizePT2EAffineQuantization(PT2EQuantizationTestCase):
+    def test_channel_group_quantization(self):
+        from torch.ao.quantization.observer import MappingType, PerGroup, PerToken
+        from torch.ao.quantization.pt2e._affine_quantization import (
+            AffineQuantizedMinMaxObserver,
+        )
+
+        class BackendAQuantizer(Quantizer):
+            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+                for node in model.graph.nodes:
+                    if (
+                        node.op == "call_function"
+                        and node.target == torch.ops.aten.linear.default
+                    ):
+                        input_act = node.args[0]
+                        assert isinstance(input_act, Node)
+                        weight = node.args[1]
+                        assert isinstance(weight, Node)
+
+                        act_qspec = QuantizationSpec(
+                            dtype=torch.uint8,
+                            quant_min=0,
+                            quant_max=255,
+                            qscheme=None,
+                            is_dynamic=False,
+                            observer_or_fake_quant_ctr=AffineQuantizedMinMaxObserver.with_args(
+                                # TODO: maybe align the arg name here
+                                target_dtype=torch.uint8,
+                                mapping_type=MappingType.SYMMETRIC,
+                                granularity=PerToken(),
+                            ),
+                        )
+
+                        weight_qspec = QuantizationSpec(
+                            dtype=torch.uint8,
+                            quant_min=0,
+                            quant_max=255,
+                            qscheme=None,
+                            is_dynamic=False,
+                            observer_or_fake_quant_ctr=AffineQuantizedMinMaxObserver.with_args(
+                                target_dtype=torch.uint8,
+                                mapping_type=MappingType.SYMMETRIC,
+                                granularity=PerGroup(group_size=128),
+                            ),
+                        )
+                        node.meta["quantization_annotation"] = QuantizationAnnotation(
+                            input_qspec_map={
+                                input_act: act_qspec,
+                                weight: weight_qspec,
+                            },
+                            _annotated=True,
+                        )
+
+            def validate(self, model: torch.fx.GraphModule) -> None:
+                pass
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = torch.nn.Linear(128, 20)
+
+            def forward(self, x):
+                return self.linear(x)
 
-devices = ["cpu", "cuda"]
-if TEST_HPU:
-    devices.append("hpu")
-instantiate_device_type_tests(TestQuantizePT2E, globals(), only_for=devices)
+        node_occurrence = {
+            torch.ops.quant.quantize_affine: 2,
+            torch.ops.quant.dequantize_affine: 2,
+        }
+        node_list = [
+            torch.ops.quant.quantize_affine,
+            torch.ops.quant.dequantize_affine,
+            torch.ops.quant.quantize_affine,
+            torch.ops.quant.dequantize_affine,
+        ]
+        example_inputs = (torch.randn(5, 128),)
+        self._test_quantizer(
+            M().eval(),
+            example_inputs,
+            BackendAQuantizer(),
+            node_occurrence,
+            node_list,
+            is_debug_mode=True,
+        )
+
+
+instantiate_parametrized_tests(TestQuantizePT2E)
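For reference, the PT2EQuantizationTestCase._test_quantizer helper used above wraps the standard pt2e steps. Roughly, as a sketch assuming the usual export/prepare/convert flow from this file, the new test amounts to:

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.export import export_for_training

example_inputs = (torch.randn(5, 128),)
m = export_for_training(M().eval(), example_inputs).module()
m = prepare_pt2e(m, BackendAQuantizer())  # inserts the AffineQuantizedMinMaxObservers
m(*example_inputs)                        # calibration: observers record min/max
m = convert_pt2e(m)                       # lowers to quantize_affine / dequantize_affine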

test/test_quantization.py

+1
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@
8787
from quantization.pt2e.test_metadata_porting import TestMetaDataPorting # noqa: F401
8888
from quantization.pt2e.test_numeric_debugger import TestNumericDebugger # noqa: F401
8989
from quantization.pt2e.test_quantize_pt2e import TestQuantizePT2E # noqa: F401
90+
from quantization.pt2e.test_quantize_pt2e import TestQuantizePT2EAffineQuantization # noqa: F401
9091
from quantization.pt2e.test_representation import TestPT2ERepresentation # noqa: F401
9192
from quantization.pt2e.test_xnnpack_quantizer import TestXNNPACKQuantizer # noqa: F401
9293
from quantization.pt2e.test_xnnpack_quantizer import TestXNNPACKQuantizerModels # noqa: F401

torch/ao/quantization/__init__.py

+14
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,20 @@
168168
"prepare_for_propagation_comparison",
169169
"extract_results_from_loggers",
170170
"compare_results",
171+
# from torchao, should be merged with torchao
172+
# in the future
173+
"AffineQuantizedObserverBase",
174+
"Granularity",
175+
"MappingType",
176+
"PerAxis",
177+
"PerBlock",
178+
"PerGroup",
179+
"PerRow",
180+
"PerTensor",
181+
"PerToken",
182+
"TorchAODType",
183+
"ZeroPointDomain",
184+
"get_block_size",
171185
]
172186

173187
