|
42 | 42 | )
|
43 | 43 | from torch.export import export_for_training
|
44 | 44 | from torch.fx import Node
|
45 |
| -from torch.testing._internal.common_device_type import instantiate_device_type_tests |
46 | 45 | from torch.testing._internal.common_quantization import (
|
47 | 46 | NodeSpec as ns,
|
48 | 47 | PT2EQuantizationTestCase,
|
@@ -1865,6 +1864,10 @@ def _get_bn_train_eval_ops(self):
|
1865 | 1864 | torch.ops.aten.batch_norm.default,
|
1866 | 1865 | )
|
1867 | 1866 |
|
| 1867 | + @parametrize( |
| 1868 | + "device", |
| 1869 | + ["cpu"] + (["cuda"] if TEST_CUDA else []) + (["hpu"] if TEST_HPU else []), |
| 1870 | + ) |
1868 | 1871 | def test_move_exported_model_bn(self, device):
|
1869 | 1872 | """
|
1870 | 1873 | Test switching batch_norm behavior between train and eval modes using
|
@@ -2477,9 +2480,90 @@ def check_nn_module(node):
|
2477 | 2480 | check_nn_module(node)
|
2478 | 2481 |
|
2479 | 2482 |
|
2480 |
| -instantiate_parametrized_tests(TestQuantizePT2E) |
| 2483 | +@skipIfNoQNNPACK |
| 2484 | +class TestQuantizePT2EAffineQuantization(PT2EQuantizationTestCase): |
| 2485 | + def test_channel_group_quantization(self): |
| 2486 | + from torch.ao.quantization.observer import MappingType, PerGroup, PerToken |
| 2487 | + from torch.ao.quantization.pt2e._affine_quantization import ( |
| 2488 | + AffineQuantizedMinMaxObserver, |
| 2489 | + ) |
| 2490 | + |
| 2491 | + class BackendAQuantizer(Quantizer): |
| 2492 | + def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule: |
| 2493 | + for node in model.graph.nodes: |
| 2494 | + if ( |
| 2495 | + node.op == "call_function" |
| 2496 | + and node.target == torch.ops.aten.linear.default |
| 2497 | + ): |
| 2498 | + input_act = node.args[0] |
| 2499 | + assert isinstance(input_act, Node) |
| 2500 | + weight = node.args[1] |
| 2501 | + assert isinstance(weight, Node) |
| 2502 | + |
| 2503 | + act_qspec = QuantizationSpec( |
| 2504 | + dtype=torch.uint8, |
| 2505 | + quant_min=0, |
| 2506 | + quant_max=255, |
| 2507 | + qscheme=None, |
| 2508 | + is_dynamic=False, |
| 2509 | + observer_or_fake_quant_ctr=AffineQuantizedMinMaxObserver.with_args( |
| 2510 | + # TODO: maybe align the arg name here |
| 2511 | + target_dtype=torch.uint8, |
| 2512 | + mapping_type=MappingType.SYMMETRIC, |
| 2513 | + granularity=PerToken(), |
| 2514 | + ), |
| 2515 | + ) |
| 2516 | + |
| 2517 | + weight_qspec = QuantizationSpec( |
| 2518 | + dtype=torch.uint8, |
| 2519 | + quant_min=0, |
| 2520 | + quant_max=255, |
| 2521 | + qscheme=None, |
| 2522 | + is_dynamic=False, |
| 2523 | + observer_or_fake_quant_ctr=AffineQuantizedMinMaxObserver.with_args( |
| 2524 | + target_dtype=torch.uint8, |
| 2525 | + mapping_type=MappingType.SYMMETRIC, |
| 2526 | + granularity=PerGroup(group_size=128), |
| 2527 | + ), |
| 2528 | + ) |
| 2529 | + node.meta["quantization_annotation"] = QuantizationAnnotation( |
| 2530 | + input_qspec_map={ |
| 2531 | + input_act: act_qspec, |
| 2532 | + weight: weight_qspec, |
| 2533 | + }, |
| 2534 | + _annotated=True, |
| 2535 | + ) |
| 2536 | + |
| 2537 | + def validate(self, model: torch.fx.GraphModule) -> None: |
| 2538 | + pass |
| 2539 | + |
| 2540 | + class M(torch.nn.Module): |
| 2541 | + def __init__(self): |
| 2542 | + super().__init__() |
| 2543 | + self.linear = torch.nn.Linear(128, 20) |
| 2544 | + |
| 2545 | + def forward(self, x): |
| 2546 | + return self.linear(x) |
2481 | 2547 |
|
2482 |
| -devices = ["cpu", "cuda"] |
2483 |
| -if TEST_HPU: |
2484 |
| - devices.append("hpu") |
2485 |
| -instantiate_device_type_tests(TestQuantizePT2E, globals(), only_for=devices) |
| 2548 | + node_occurrence = { |
| 2549 | + torch.ops.quant.quantize_affine: 2, |
| 2550 | + torch.ops.quant.dequantize_affine: 2, |
| 2551 | + } |
| 2552 | + node_list = [ |
| 2553 | + torch.ops.quant.quantize_affine, |
| 2554 | + torch.ops.quant.dequantize_affine, |
| 2555 | + torch.ops.quant.quantize_affine, |
| 2556 | + torch.ops.quant.dequantize_affine, |
| 2557 | + ] |
| 2558 | + example_inputs = (torch.randn(5, 128),) |
| 2559 | + self._test_quantizer( |
| 2560 | + M().eval(), |
| 2561 | + example_inputs, |
| 2562 | + BackendAQuantizer(), |
| 2563 | + node_occurrence, |
| 2564 | + node_list, |
| 2565 | + is_debug_mode=True, |
| 2566 | + ) |
| 2567 | + |
| 2568 | + |
| 2569 | +instantiate_parametrized_tests(TestQuantizePT2E) |
0 commit comments