 # mypy: allow-untyped-defs
 import itertools
-import typing
-from dataclasses import dataclass
-from typing import Callable, NamedTuple, Optional
+from typing import Callable, Optional
 
 import torch
 import torch.nn.functional as F
 from executorch.backends.xnnpack.utils.utils import is_depthwise_conv
 from torch._subclasses import FakeTensor
-from torch.ao.quantization.fx.utils import get_new_attr_name_with_prefix
-from torch.ao.quantization.pt2e.export_utils import _WrapperModule
-from torch.ao.quantization.pt2e.utils import (
-    _get_aten_graph_module_for_pattern,
-    _is_conv_node,
-    _is_conv_transpose_node,
+from torch.fx import Node
+from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
+    SubgraphMatcherWithNameNodeMap,
 )
-from torch.ao.quantization.quantizer import (
+from torchao.quantization.pt2e import WrapperModule
+from torchao.quantization.pt2e.graph_utils import get_source_partitions
+from torchao.quantization.pt2e.quantizer import (
+    annotate_input_qspec_map,
+    annotate_output_qspec,
+    get_bias_qspec,
+    get_input_act_qspec,
+    get_output_act_qspec,
+    get_weight_qspec,
+    OperatorConfig,
+    OperatorPatternType,
     QuantizationAnnotation,
+    QuantizationConfig,
     QuantizationSpec,
     SharedQuantizationSpec,
 )
-from torch.ao.quantization.quantizer.utils import (
-    _annotate_input_qspec_map,
-    _annotate_output_qspec,
-)
-from torch.fx import Node
-from torch.fx.passes.utils.matcher_with_name_node_map_utils import (
-    SubgraphMatcherWithNameNodeMap,
+from torchao.quantization.pt2e.utils import (
+    _get_aten_graph_module_for_pattern,
+    _is_conv_node,
+    _is_conv_transpose_node,
+    get_new_attr_name_with_prefix,
 )
-from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
 
 __all__ = [
     "OperatorConfig",
     "OperatorPatternType",
     "QuantizationConfig",
+    "QuantizationSpec",
     "get_input_act_qspec",
     "get_output_act_qspec",
     "get_weight_qspec",
 ]
 
 
-# In the absence of better name, just winging it with QuantizationConfig
-@dataclass(eq=True, frozen=True)
-class QuantizationConfig:
-    input_activation: Optional[QuantizationSpec]
-    output_activation: Optional[QuantizationSpec]
-    weight: Optional[QuantizationSpec]
-    bias: Optional[QuantizationSpec]
-    # TODO: remove, since we can use observer_or_fake_quant_ctr to express this
-    is_qat: bool = False
-
-
-# Use Annotated because list[Callable].__module__ is read-only.
-OperatorPatternType = typing.Annotated[list[Callable], None]
-OperatorPatternType.__module__ = (
-    "executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils"
-)
-
 AnnotatorType = Callable[
     [
         torch.fx.GraphModule,
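For orientation on what the hunk above changes: the locally defined QuantizationConfig dataclass, OperatorPatternType alias, and OperatorConfig (removed below) are replaced by re-imports from torchao.quantization.pt2e.quantizer, and they stay listed in __all__. Downstream imports of this module should therefore keep working unchanged. A minimal sketch, assuming the torchao-backed names behave like the local definitions removed in this diff:

# Sketch only: these names are still re-exported from this module after the
# migration, now backed by torchao.quantization.pt2e.quantizer.
from executorch.backends.xnnpack.quantizer.xnnpack_quantizer_utils import (
    OperatorConfig,
    OperatorPatternType,
    QuantizationConfig,
    QuantizationSpec,
    get_input_act_qspec,
)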
@@ -78,19 +65,6 @@ def decorator(annotator: AnnotatorType) -> None:
     return decorator
 
 
-class OperatorConfig(NamedTuple):
-    # fix List[str] with List[List[Union[nn.Module, FunctionType, BuiltinFunctionType]]]
-    # Basically we are mapping a quantization config to some list of patterns.
-    # a pattern is defined as a list of nn module, function or builtin function names
-    # e.g. [nn.Conv2d, torch.relu, torch.add]
-    # We have not resolved whether fusion can be considered internal details of the
-    # quantizer hence it does not need communication to user.
-    # Note this pattern is not really informative since it does not really
-    # tell us the graph structure resulting from the list of ops.
-    config: QuantizationConfig
-    operators: list[OperatorPatternType]
-
-
 def is_relu_node(node: Node) -> bool:
     """
     Check if a given node is a relu node
@@ -124,63 +98,6 @@ def _mark_nodes_as_annotated(nodes: list[Node]):
             node.meta["quantization_annotation"]._annotated = True
 
 
-def get_input_act_qspec(quantization_config: Optional[QuantizationConfig]):
-    if quantization_config is None:
-        return None
-    if quantization_config.input_activation is None:
-        return None
-    quantization_spec: QuantizationSpec = quantization_config.input_activation
-    assert quantization_spec.qscheme in [
-        torch.per_tensor_affine,
-        torch.per_tensor_symmetric,
-    ]
-    return quantization_spec
-
-
-def get_output_act_qspec(quantization_config: Optional[QuantizationConfig]):
-    if quantization_config is None:
-        return None
-    if quantization_config.output_activation is None:
-        return None
-    quantization_spec: QuantizationSpec = quantization_config.output_activation
-    assert quantization_spec.qscheme in [
-        torch.per_tensor_affine,
-        torch.per_tensor_symmetric,
-    ]
-    return quantization_spec
-
-
-def get_weight_qspec(quantization_config: Optional[QuantizationConfig]):
-    if quantization_config is None:
-        return None
-    assert quantization_config is not None
-    if quantization_config.weight is None:
-        return None
-    quantization_spec: QuantizationSpec = quantization_config.weight
-    if quantization_spec.qscheme not in [
-        torch.per_tensor_symmetric,
-        torch.per_channel_symmetric,
-        None,
-    ]:
-        raise ValueError(
-            f"Unsupported quantization_spec {quantization_spec} for weight"
-        )
-    return quantization_spec
-
-
-def get_bias_qspec(quantization_config: Optional[QuantizationConfig]):
-    if quantization_config is None:
-        return None
-    assert quantization_config is not None
-    if quantization_config.bias is None:
-        return None
-    quantization_spec: QuantizationSpec = quantization_config.bias
-    assert (
-        quantization_spec.dtype == torch.float
-    ), "Only float dtype for bias is supported for bias right now"
-    return quantization_spec
-
-
 @register_annotator("linear")
 def _annotate_linear(
     gm: torch.fx.GraphModule,
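The four get_*_qspec helpers deleted above validated a config's fields (per-tensor qschemes for activations, per-tensor or per-channel symmetric for weights, float-only bias) and returned the matching QuantizationSpec, or None when unset. The torchao replacements imported at the top are assumed to keep the same helper(config) -> Optional[QuantizationSpec] shape; a small self-contained sketch of how an annotator resolves its specs under that assumption:

from typing import Optional

from torchao.quantization.pt2e.quantizer import (
    QuantizationConfig,
    QuantizationSpec,
    get_bias_qspec,
    get_input_act_qspec,
    get_output_act_qspec,
    get_weight_qspec,
)


def resolve_qspecs(
    config: Optional[QuantizationConfig],
) -> tuple[Optional[QuantizationSpec], ...]:
    # Each helper is assumed to return None when the config (or the matching
    # field) is unset, mirroring the deleted local definitions above.
    return (
        get_input_act_qspec(config),
        get_output_act_qspec(config),
        get_weight_qspec(config),
        get_bias_qspec(config),
    )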
@@ -204,25 +121,25 @@ def _annotate_linear(
             bias_node = node.args[2]
 
         if _is_annotated([node]) is False:  # type: ignore[list-item]
-            _annotate_input_qspec_map(
+            annotate_input_qspec_map(
                 node,
                 act_node,
                 input_act_qspec,
             )
-            _annotate_input_qspec_map(
+            annotate_input_qspec_map(
                 node,
                 weight_node,
                 weight_qspec,
             )
             nodes_to_mark_annotated = [node, weight_node]
             if bias_node:
-                _annotate_input_qspec_map(
+                annotate_input_qspec_map(
                     node,
                     bias_node,
                     bias_qspec,
                 )
                 nodes_to_mark_annotated.append(bias_node)
-            _annotate_output_qspec(node, output_act_qspec)
+            annotate_output_qspec(node, output_act_qspec)
             _mark_nodes_as_annotated(nodes_to_mark_annotated)
             annotated_partitions.append(nodes_to_mark_annotated)
 
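The only change in the hunk above is the switch from the private torch.ao helpers to the public torchao ones; the call shape is identical. What these helpers do, roughly, is attach a QuantizationAnnotation to the node's meta so a later pass knows how each consumed input and the produced output should be quantized. A hedged sketch of that behavior, based on the torch.ao versions this diff migrates away from and assuming torchao's are equivalent:

from torch.fx import Node
from torchao.quantization.pt2e.quantizer import (
    QuantizationAnnotation,
    QuantizationSpec,
)


def sketch_annotate(
    node: Node,
    input_node: Node,
    input_spec: QuantizationSpec,
    output_spec: QuantizationSpec,
) -> None:
    # Approximation of annotate_input_qspec_map + annotate_output_qspec:
    # record the qspec for one consumed edge and for the produced value.
    annotation = node.meta.setdefault(
        "quantization_annotation", QuantizationAnnotation()
    )
    annotation.input_qspec_map[input_node] = input_spec
    annotation.output_qspec = output_spec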
@@ -572,7 +489,7 @@ def _conv_bn(x, conv_weight, conv_bias, bn_weight, bn_bias, bn_rm, bn_rv):
572489 "output" : output ,
573490 }
574491
575- return _WrapperModule (_conv_bn )
492+ return WrapperModule (_conv_bn )
576493
577494 # Needed for matching, otherwise the matches gets filtered out due to unused
578495 # nodes returned by batch norm
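Finally, the private _WrapperModule becomes the public WrapperModule from torchao.quantization.pt2e. Its role is simply to let a plain pattern function such as _conv_bn be traced like an nn.Module when building the match pattern. A minimal stand-in for illustration only (FnWrapper is hypothetical, not the torchao class):

import torch


class FnWrapper(torch.nn.Module):
    # Hypothetical stand-in for torchao's WrapperModule: forward delegates
    # to the wrapped pattern function so the function can be traced as a module.
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, *args, **kwargs):
        return self.fn(*args, **kwargs)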