@@ -121,6 +121,24 @@ def get_anchors(
         )
 
 
+class SingleInputBasicPattern(QuantizationPattern):
+    @abstractmethod
+    def partition_types(self) -> list[OpOverload]:
+        pass
+
+    def get_anchors(
+        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
+    ) -> PartitionAnchors | None:
+        node = fused_partition[0].nodes[-1]
+
+        return PartitionAnchors(
+            inputs=[(node, NodeArgsIdx(0))],
+            weights=[],
+            biases=[],
+            output=[(node,)],
+        )
+
+
 def get_anchors_for_fixed_quant_specs(
     fused_partition: list[fx.GraphModule],
     scale: float,
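
For reference, a minimal sketch (not part of this commit) of how a new single-input activation pattern could be built on top of the new base class; `TanhExamplePattern` is a hypothetical name, and the imports (`torch`, `fx`, `OpOverload`, `PartitionAnchors`, `NodeArgsIdx`) are assumed to match the ones already used in this file:

```python
# Hypothetical subclass, for illustration only: SingleInputBasicPattern supplies
# get_anchors() (anchoring input 0 and the output of the matched node), so a new
# single-input pattern only has to declare which ops it matches.
class TanhExamplePattern(SingleInputBasicPattern):
    def partition_types(self) -> list[OpOverload]:
        return [torch.ops.aten.tanh.default]

    def replacement_op(self):
        # Same convention as the HardTanh patterns below: annotation-only,
        # never actually replaced.
        raise AssertionError()
```
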
@@ -376,7 +394,7 @@ def partition_types(self):
         return [torch.ops.aten.flatten.using_ints]
 
 
-class HardTanhPattern(QuantizationPattern):
+class HardTanhPattern(SingleInputBasicPattern):
     """
     Quantizer for HardTanh operator. Shared quantization spec is selected, as activation functions usually
     follow a computation layer.
@@ -385,23 +403,12 @@ class HardTanhPattern(QuantizationPattern):
     def partition_types(self):
         return [torch.ops.aten.hardtanh.default]
 
-    def get_anchors(
-        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
-    ) -> PartitionAnchors | None:
-        node = fused_partition[0].nodes[-1]
-
-        return PartitionAnchors(
-            inputs=[(node, NodeArgsIdx(0))],
-            weights=[],
-            biases=[],
-            output=[(node,)],
-        )
 
     def replacement_op(self):
         raise AssertionError()
 
 
-class HardTanhInPlacePattern(QuantizationPattern):
+class HardTanhInPlacePattern(SingleInputBasicPattern):
     """
     Quantizer for HardTanh operator with param inplace=True. Shared quantization spec is selected, as activation
     functions usually follow a computation layer.
@@ -410,18 +417,6 @@ class HardTanhInPlacePattern(QuantizationPattern):
     def partition_types(self):
         return [torch.ops.aten.hardtanh_.default]
 
-    def get_anchors(
-        self, gm: fx.GraphModule, fused_partition: list[fx.GraphModule]
-    ) -> PartitionAnchors | None:
-        node = fused_partition[0].nodes[-1]
-
-        return PartitionAnchors(
-            inputs=[(node, NodeArgsIdx(0))],
-            weights=[],
-            biases=[],
-            output=[(node,)],
-        )
-
     def replacement_op(self):
         raise AssertionError()
 
@@ -513,19 +508,18 @@ def partition_types(self):
         return [torch.ops.aten.permute.default]
 
 
-class ReluPattern(SharedSpecPattern):
+class ReluPattern(SingleInputBasicPattern):
     """
-    Quantizer for Relu operator. Shared quantization spec is selected, as ReLU usually follows computation layer.
+    Quantizer for Relu operator.
     """
 
     def partition_types(self):
         return [torch.ops.aten.relu.default]
 
 
-class ReluInPlacePattern(SharedSpecPattern):
+class ReluInPlacePattern(SingleInputBasicPattern):
     """
-    Quantizer for Relu operator with param inplace=True. Shared quantization spec is selected, as ReLU usually
-    follows computation layer.
+    Quantizer for Relu operator with param inplace=True.
     """
 
     def partition_types(self):